JiantaoLin commited on
Commit
2fe3da0
·
1 Parent(s): e2cc5f8
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. README copy.md +111 -0
  3. app.py +499 -0
  4. configs/PRM.yaml +71 -0
  5. configs/PRM_inference.yaml +22 -0
  6. light2map.py +95 -0
  7. obj2mesh.py +121 -0
  8. requirements.txt +21 -0
  9. run.py +355 -0
  10. run.sh +7 -0
  11. run_hpc.sh +16 -0
  12. src/__init__.py +0 -0
  13. src/__pycache__/__init__.cpython-310.pyc +0 -0
  14. src/data/__init__.py +0 -0
  15. src/data/__pycache__/__init__.cpython-310.pyc +0 -0
  16. src/data/__pycache__/objaverse.cpython-310.pyc +0 -0
  17. src/data/bsdf_256_256.bin +0 -0
  18. src/data/objaverse.py +509 -0
  19. src/model_mesh.py +642 -0
  20. src/models/__init__.py +0 -0
  21. src/models/__pycache__/__init__.cpython-310.pyc +0 -0
  22. src/models/__pycache__/lrm_mesh.cpython-310.pyc +0 -0
  23. src/models/decoder/__init__.py +0 -0
  24. src/models/decoder/__pycache__/__init__.cpython-310.pyc +0 -0
  25. src/models/decoder/__pycache__/transformer.cpython-310.pyc +0 -0
  26. src/models/decoder/transformer.py +123 -0
  27. src/models/encoder/__init__.py +0 -0
  28. src/models/encoder/__pycache__/__init__.cpython-310.pyc +0 -0
  29. src/models/encoder/__pycache__/dino.cpython-310.pyc +0 -0
  30. src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc +0 -0
  31. src/models/encoder/dino.py +550 -0
  32. src/models/encoder/dino_wrapper.py +80 -0
  33. src/models/geometry/__init__.py +7 -0
  34. src/models/geometry/__pycache__/__init__.cpython-310.pyc +0 -0
  35. src/models/geometry/camera/__init__.py +16 -0
  36. src/models/geometry/camera/__pycache__/__init__.cpython-310.pyc +0 -0
  37. src/models/geometry/camera/__pycache__/perspective_camera.cpython-310.pyc +0 -0
  38. src/models/geometry/camera/perspective_camera.py +35 -0
  39. src/models/geometry/render/__init__.py +8 -0
  40. src/models/geometry/render/__pycache__/__init__.cpython-310.pyc +0 -0
  41. src/models/geometry/render/__pycache__/neural_render.cpython-310.pyc +0 -0
  42. src/models/geometry/render/__pycache__/util.cpython-310.pyc +0 -0
  43. src/models/geometry/render/neural_render.py +293 -0
  44. src/models/geometry/render/renderutils/__init__.py +11 -0
  45. src/models/geometry/render/renderutils/__pycache__/__init__.cpython-310.pyc +0 -0
  46. src/models/geometry/render/renderutils/__pycache__/bsdf.cpython-310.pyc +0 -0
  47. src/models/geometry/render/renderutils/__pycache__/loss.cpython-310.pyc +0 -0
  48. src/models/geometry/render/renderutils/__pycache__/ops.cpython-310.pyc +0 -0
  49. src/models/geometry/render/renderutils/bsdf.py +151 -0
  50. src/models/geometry/render/renderutils/c_src/bsdf.cu +710 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
README copy.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <div align="center">
4
+
5
+ # PRM: Photometric Stereo based Large Reconstruction Model
6
+
7
+ <a href="https://tau-yihouxiang.github.io/projects/X-Ray/X-Ray.html"><img src="https://img.shields.io/badge/Project_Page-Online-EA3A97"></a>
8
+ <a href="https://arxiv.org/abs/2404.07191"><img src="https://img.shields.io/badge/ArXiv-2404.07191-brightgreen"></a>
9
+ <a href="https://huggingface.co/LTT/PRM"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <br>
10
+ <a href="https://huggingface.co/spaces/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
11
+ <a href="https://github.com/jtydhr88/ComfyUI-InstantMesh"><img src="https://img.shields.io/badge/Demo-ComfyUI-8A2BE2"></a>
12
+
13
+ </div>
14
+
15
+ ---
16
+
17
+ An official implementation of PRM, a feed-forward framework for high-quality 3D mesh generation with photometric stereo images.
18
+
19
+
20
+ ![image](https://github.com/g3956/PRM/blob/main/assets/teaser.png)
21
+
22
+ # 🚩 Features
23
+ - [x] Release inference and training code.
24
+ - [x] Release model weights.
25
+ - [x] Release huggingface gradio demo. Please try it at [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link.
26
+ - [x] Release ComfyUI demo.
27
+
28
+ # ⚙️ Dependencies and Installation
29
+
30
+ We recommend using `Python>=3.10`, `PyTorch>=2.1.0`, and `CUDA>=12.1`.
31
+ ```bash
32
+ conda create --name PRM python=3.10
33
+ conda activate PRM
34
+ pip install -U pip
35
+
36
+ # Ensure Ninja is installed
37
+ conda install Ninja
38
+
39
+ # Install the correct version of CUDA
40
+ conda install cuda -c nvidia/label/cuda-12.1.0
41
+
42
+ # Install PyTorch and xformers
43
+ # You may need to install another xformers version if you use a different PyTorch version
44
+ pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
45
+ pip install xformers==0.0.22.post7
46
+
47
+ # Install Triton
48
+ pip install triton
49
+
50
+ # Install other requirements
51
+ pip install -r requirements.txt
52
+ ```
53
+
54
+ # 💫 Inference
55
+
56
+ ## Download the pretrained model
57
+
58
+ The pretrained model can be found [model card](https://huggingface.co/LTT/PRM).
59
+
60
+ Our inference script will download the models automatically. Alternatively, you can manually download the models and put them under the `ckpts/` directory.
61
+
62
+ # 💻 Training
63
+
64
+ We provide our training code to facilitate future research.
65
+ For training data, we used filtered Objaverse for training. Before training, you need to pre-processe the environment maps and GLB files into formats that fit our dataloader.
66
+ For preprocessing GLB files, please run
67
+ ```bash
68
+ # GLB files to OBJ files
69
+ python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
70
+ ```
71
+ then
72
+ ```bash
73
+ # OBJ files to mesh files that can be readed
74
+ python obj2mesh.py path_to_obj save_path
75
+ ```
76
+ For preprocessing environment maps, please run
77
+ ```bash
78
+ # Pre-process environment maps
79
+ python light2map.py path_to_env save_path
80
+ ```
81
+
82
+
83
+ To train the sparse-view reconstruction models, please run:
84
+ ```bash
85
+ # Training on Mesh representation
86
+ python train.py --base configs/PRM.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
87
+ ```
88
+ Note that you need to change to root_dir and light_dir to pathes that you save the preprocessed GLB files and environment maps.
89
+
90
+ # :books: Citation
91
+
92
+ If you find our work useful for your research or applications, please cite using this BibTeX:
93
+
94
+ ```BibTeX
95
+ @article{xu2024instantmesh,
96
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
97
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
98
+ journal={arXiv preprint arXiv:2404.07191},
99
+ year={2024}
100
+ }
101
+ ```
102
+
103
+ # 🤗 Acknowledgements
104
+
105
+ We thank the authors of the following projects for their excellent contributions to 3D generative AI!
106
+
107
+ - [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes)
108
+ - [InstantMesh]([https://instant-3d.github.io/](https://github.com/TencentARC/InstantMesh))
109
+
110
+
111
+
app.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ import glm
13
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
14
+
15
+ from src.data.objaverse import load_mipmap
16
+ from src.utils import render_utils
17
+ from src.utils.train_util import instantiate_from_config
18
+ from src.utils.camera_util import (
19
+ FOV_to_intrinsics,
20
+ get_zero123plus_input_cameras,
21
+ get_circular_camera_poses,
22
+ )
23
+ from src.utils.mesh_util import save_obj, save_glb
24
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
25
+
26
+ import tempfile
27
+ from huggingface_hub import hf_hub_download
28
+
29
+
30
+ if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
31
+ device0 = torch.device('cuda:0')
32
+ device1 = torch.device('cuda:0')
33
+ else:
34
+ device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ device1 = device0
36
+
37
+ # Define the cache directory for model files
38
+ model_cache_dir = './ckpts/'
39
+ os.makedirs(model_cache_dir, exist_ok=True)
40
+
41
+ def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
42
+ """
43
+ Get the rendering camera parameters.
44
+ """
45
+ train_res = [512, 512]
46
+ cam_near_far = [0.1, 1000.0]
47
+ fovy = np.deg2rad(fov)
48
+ proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
49
+ all_mv = []
50
+ all_mvp = []
51
+ all_campos = []
52
+ if isinstance(elevation, tuple):
53
+ elevation_0 = np.deg2rad(elevation[0])
54
+ elevation_1 = np.deg2rad(elevation[1])
55
+ for i in range(M//2):
56
+ azimuth = 2 * np.pi * i / (M // 2)
57
+ z = radius * np.cos(azimuth) * np.sin(elevation_0)
58
+ x = radius * np.sin(azimuth) * np.sin(elevation_0)
59
+ y = radius * np.cos(elevation_0)
60
+
61
+ eye = glm.vec3(x, y, z)
62
+ at = glm.vec3(0.0, 0.0, 0.0)
63
+ up = glm.vec3(0.0, 1.0, 0.0)
64
+ view_matrix = glm.lookAt(eye, at, up)
65
+ mv = torch.from_numpy(np.array(view_matrix))
66
+ mvp = proj_mtx @ (mv) #w2c
67
+ campos = torch.linalg.inv(mv)[:3, 3]
68
+ all_mv.append(mv[None, ...].cuda())
69
+ all_mvp.append(mvp[None, ...].cuda())
70
+ all_campos.append(campos[None, ...].cuda())
71
+ for i in range(M//2):
72
+ azimuth = 2 * np.pi * i / (M // 2)
73
+ z = radius * np.cos(azimuth) * np.sin(elevation_1)
74
+ x = radius * np.sin(azimuth) * np.sin(elevation_1)
75
+ y = radius * np.cos(elevation_1)
76
+
77
+ eye = glm.vec3(x, y, z)
78
+ at = glm.vec3(0.0, 0.0, 0.0)
79
+ up = glm.vec3(0.0, 1.0, 0.0)
80
+ view_matrix = glm.lookAt(eye, at, up)
81
+ mv = torch.from_numpy(np.array(view_matrix))
82
+ mvp = proj_mtx @ (mv) #w2c
83
+ campos = torch.linalg.inv(mv)[:3, 3]
84
+ all_mv.append(mv[None, ...].cuda())
85
+ all_mvp.append(mvp[None, ...].cuda())
86
+ all_campos.append(campos[None, ...].cuda())
87
+ else:
88
+ # elevation = 90 - elevation
89
+ for i in range(M):
90
+ azimuth = 2 * np.pi * i / M
91
+ z = radius * np.cos(azimuth) * np.sin(elevation)
92
+ x = radius * np.sin(azimuth) * np.sin(elevation)
93
+ y = radius * np.cos(elevation)
94
+
95
+ eye = glm.vec3(x, y, z)
96
+ at = glm.vec3(0.0, 0.0, 0.0)
97
+ up = glm.vec3(0.0, 1.0, 0.0)
98
+ view_matrix = glm.lookAt(eye, at, up)
99
+ mv = torch.from_numpy(np.array(view_matrix))
100
+ mvp = proj_mtx @ (mv) #w2c
101
+ campos = torch.linalg.inv(mv)[:3, 3]
102
+ all_mv.append(mv[None, ...].cuda())
103
+ all_mvp.append(mvp[None, ...].cuda())
104
+ all_campos.append(campos[None, ...].cuda())
105
+ all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
106
+ all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
107
+ all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
108
+ return all_mv, all_mvp, all_campos
109
+
110
+
111
+ def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1, is_flexicubes=False):
112
+ """
113
+ Render frames from triplanes.
114
+ """
115
+ frames = []
116
+ albedos = []
117
+ pbr_spec_lights = []
118
+ pbr_diffuse_lights = []
119
+ normals = []
120
+ alphas = []
121
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
122
+ if is_flexicubes:
123
+ out = model.forward_geometry(
124
+ planes,
125
+ render_cameras[:, i:i+chunk_size],
126
+ camera_pos[:, i:i+chunk_size],
127
+ [[env]*chunk_size],
128
+ [[materials]*chunk_size],
129
+ render_size=render_size,
130
+ )
131
+ frame = out['pbr_img']
132
+ albedo = out['albedo']
133
+ pbr_spec_light = out['pbr_spec_light']
134
+ pbr_diffuse_light = out['pbr_diffuse_light']
135
+ normal = out['normal']
136
+ alpha = out['mask']
137
+ else:
138
+ frame = model.forward_synthesizer(
139
+ planes,
140
+ render_cameras[i],
141
+ render_size=render_size,
142
+ )['images_rgb']
143
+ frames.append(frame)
144
+ albedos.append(albedo)
145
+ pbr_spec_lights.append(pbr_spec_light)
146
+ pbr_diffuse_lights.append(pbr_diffuse_light)
147
+ normals.append(normal)
148
+ alphas.append(alpha)
149
+
150
+ frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
151
+ alphas = torch.cat(alphas, dim=1)[0]
152
+ albedos = torch.cat(albedos, dim=1)[0]
153
+ pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
154
+ pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
155
+ normals = torch.cat(normals, dim=0).permute(0,3,1,2)[:,:3]
156
+ return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
157
+
158
+
159
+
160
+ def images_to_video(images, output_path, fps=30):
161
+ # images: (N, C, H, W)
162
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
163
+ frames = []
164
+ for i in range(images.shape[0]):
165
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
166
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
167
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
168
+ assert frame.min() >= 0 and frame.max() <= 255, \
169
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
170
+ frames.append(frame)
171
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
172
+
173
+
174
+ ###############################################################################
175
+ # Configuration.
176
+ ###############################################################################
177
+
178
+ seed_everything(0)
179
+
180
+ config_path = 'configs/PRM_inference.yaml'
181
+ config = OmegaConf.load(config_path)
182
+ config_name = os.path.basename(config_path).replace('.yaml', '')
183
+ model_config = config.model_config
184
+ infer_config = config.infer_config
185
+
186
+ IS_FLEXICUBES = True
187
+
188
+ device = torch.device('cuda')
189
+
190
+ # load diffusion model
191
+ print('Loading diffusion model ...')
192
+ pipeline = DiffusionPipeline.from_pretrained(
193
+ "sudo-ai/zero123plus-v1.2",
194
+ custom_pipeline="zero123plus",
195
+ torch_dtype=torch.float16,
196
+ cache_dir=model_cache_dir
197
+ )
198
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
199
+ pipeline.scheduler.config, timestep_spacing='trailing'
200
+ )
201
+
202
+ # load custom white-background UNet
203
+ print('Loading custom white-background unet ...')
204
+ if os.path.exists(infer_config.unet_path):
205
+ unet_ckpt_path = infer_config.unet_path
206
+ else:
207
+ unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
208
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
209
+ pipeline.unet.load_state_dict(state_dict, strict=True)
210
+
211
+ pipeline = pipeline.to(device)
212
+
213
+ # load reconstruction model
214
+ print('Loading reconstruction model ...')
215
+ model = instantiate_from_config(model_config)
216
+ if os.path.exists(infer_config.model_path):
217
+ model_ckpt_path = infer_config.model_path
218
+ else:
219
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
220
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
221
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
222
+ model.load_state_dict(state_dict, strict=True)
223
+
224
+ model = model.to(device1)
225
+ if IS_FLEXICUBES:
226
+ model.init_flexicubes_geometry(device1, fovy=30.0)
227
+ model = model.eval()
228
+
229
+ print('Loading Finished!')
230
+
231
+
232
+ def check_input_image(input_image):
233
+ if input_image is None:
234
+ raise gr.Error("No image uploaded!")
235
+
236
+
237
+ def preprocess(input_image, do_remove_background):
238
+
239
+ rembg_session = rembg.new_session() if do_remove_background else None
240
+ if do_remove_background:
241
+ input_image = remove_background(input_image, rembg_session)
242
+ input_image = resize_foreground(input_image, 0.85)
243
+
244
+ return input_image
245
+
246
+
247
+ def generate_mvs(input_image, sample_steps, sample_seed):
248
+
249
+ seed_everything(sample_seed)
250
+
251
+ # sampling
252
+ generator = torch.Generator(device=device0)
253
+ z123_image = pipeline(
254
+ input_image,
255
+ num_inference_steps=sample_steps,
256
+ generator=generator,
257
+ ).images[0]
258
+
259
+ show_image = np.asarray(z123_image, dtype=np.uint8)
260
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
261
+ show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
262
+ show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
263
+ show_image = Image.fromarray(show_image.numpy())
264
+
265
+ return z123_image, show_image
266
+
267
+
268
+ def make_mesh(mesh_fpath, planes):
269
+
270
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
271
+ mesh_dirname = os.path.dirname(mesh_fpath)
272
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
273
+
274
+ with torch.no_grad():
275
+ # get mesh
276
+
277
+ mesh_out = model.extract_mesh(
278
+ planes,
279
+ use_texture_map=False,
280
+ **infer_config,
281
+ )
282
+
283
+ vertices, faces, vertex_colors = mesh_out
284
+ vertices = vertices[:, [1, 2, 0]]
285
+
286
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
287
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
288
+
289
+ print(f"Mesh saved to {mesh_fpath}")
290
+
291
+ return mesh_fpath, mesh_glb_fpath
292
+
293
+
294
+ def make3d(images):
295
+
296
+ images = np.asarray(images, dtype=np.float32) / 255.0
297
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
298
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
299
+
300
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=3.2, fov=30).to(device).to(device1)
301
+ all_mv, all_mvp, all_campos = get_render_cameras(
302
+ batch_size=1,
303
+ M=240,
304
+ radius=4.5,
305
+ elevation=(90, 60.0),
306
+ is_flexicubes=IS_FLEXICUBES,
307
+ fov=30
308
+ )
309
+
310
+ images = images.unsqueeze(0).to(device1)
311
+ images = v2.functional.resize(images, (512, 512), interpolation=3, antialias=True).clamp(0, 1)
312
+
313
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
314
+ print(mesh_fpath)
315
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
316
+ mesh_dirname = os.path.dirname(mesh_fpath)
317
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
318
+ ENV = load_mipmap("env_mipmap/6")
319
+ materials = (0.0,0.9)
320
+ with torch.no_grad():
321
+ # get triplane
322
+ planes = model.forward_planes(images, input_cameras)
323
+
324
+ # get video
325
+ chunk_size = 20 if IS_FLEXICUBES else 1
326
+ render_size = 512
327
+
328
+ frames = []
329
+ frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
330
+ model,
331
+ planes,
332
+ render_cameras=all_mvp,
333
+ camera_pos=all_campos,
334
+ env=ENV,
335
+ materials=materials,
336
+ render_size=render_size,
337
+ chunk_size=chunk_size,
338
+ is_flexicubes=IS_FLEXICUBES,
339
+ )
340
+ normals = (torch.nn.functional.normalize(normals) + 1) / 2
341
+ normals = normals * alphas + (1-alphas)
342
+ all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
343
+
344
+ images_to_video(
345
+ all_frames,
346
+ video_fpath,
347
+ fps=30,
348
+ )
349
+
350
+ print(f"Video saved to {video_fpath}")
351
+
352
+ mesh_fpath, mesh_glb_fpath = make_mesh(mesh_fpath, planes)
353
+
354
+ return video_fpath, mesh_fpath, mesh_glb_fpath
355
+
356
+
357
+ import gradio as gr
358
+
359
+ _HEADER_ = '''
360
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/g3956/PRM' target='_blank'><b>PRM: Photometric Stereo based Large Reconstruction Model</b></a></h2>
361
+
362
+ **PRM** is a feed-forward framework for high-quality 3D mesh generation with fine-grained local details from a single image.
363
+
364
+ Code: <a href='https://github.com/g3956/PRM' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
365
+ '''
366
+
367
+ _CITE_ = r"""
368
+ If PRM is helpful, please help to ⭐ the <a href='https://github.com/g3956/PRM' target='_blank'>Github Repo</a>. Thanks!
369
+ ---
370
+ 📝 **Citation**
371
+
372
+ If you find our work useful for your research or applications, please cite using this bibtex:
373
+ ```bibtex
374
+ @article{xu2024instantmesh,
375
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
376
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
377
+ journal={arXiv preprint arXiv:2404.07191},
378
+ year={2024}
379
+ }
380
+ ```
381
+
382
+ 📋 **License**
383
+
384
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
385
+
386
+ 📧 **Contact**
387
+
388
+ If you have any questions, feel free to open a discussion or contact us at <b>jlin695@connect.hkust-gz.edu.cn</b>.
389
+ """
390
+
391
+ with gr.Blocks() as demo:
392
+ gr.Markdown(_HEADER_)
393
+ with gr.Row(variant="panel"):
394
+ with gr.Column():
395
+ with gr.Row():
396
+ input_image = gr.Image(
397
+ label="Input Image",
398
+ image_mode="RGBA",
399
+ sources="upload",
400
+ width=256,
401
+ height=256,
402
+ type="pil",
403
+ elem_id="content_image",
404
+ )
405
+ processed_image = gr.Image(
406
+ label="Processed Image",
407
+ image_mode="RGBA",
408
+ width=256,
409
+ height=256,
410
+ type="pil",
411
+ interactive=False
412
+ )
413
+ with gr.Row():
414
+ with gr.Group():
415
+ do_remove_background = gr.Checkbox(
416
+ label="Remove Background", value=True
417
+ )
418
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
419
+
420
+ sample_steps = gr.Slider(
421
+ label="Sample Steps",
422
+ minimum=30,
423
+ maximum=100,
424
+ value=75,
425
+ step=5
426
+ )
427
+
428
+ with gr.Row():
429
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
430
+
431
+ with gr.Row(variant="panel"):
432
+ gr.Examples(
433
+ examples=[
434
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
435
+ ],
436
+ inputs=[input_image],
437
+ label="Examples",
438
+ examples_per_page=20
439
+ )
440
+
441
+ with gr.Column():
442
+
443
+ with gr.Row():
444
+
445
+ with gr.Column():
446
+ mv_show_images = gr.Image(
447
+ label="Generated Multi-views",
448
+ type="pil",
449
+ width=379,
450
+ interactive=False
451
+ )
452
+
453
+ with gr.Column():
454
+ with gr.Column():
455
+ output_video = gr.Video(
456
+ label="video", format="mp4",
457
+ width=768,
458
+ autoplay=True,
459
+ interactive=False
460
+ )
461
+
462
+ with gr.Row():
463
+ with gr.Tab("OBJ"):
464
+ output_model_obj = gr.Model3D(
465
+ label="Output Model (OBJ Format)",
466
+ #width=768,
467
+ interactive=False,
468
+ )
469
+ gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
470
+ with gr.Tab("GLB"):
471
+ output_model_glb = gr.Model3D(
472
+ label="Output Model (GLB Format)",
473
+ #width=768,
474
+ interactive=False,
475
+ )
476
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
477
+
478
+ with gr.Row():
479
+ gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
480
+
481
+ gr.Markdown(_CITE_)
482
+ mv_images = gr.State()
483
+
484
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
485
+ fn=preprocess,
486
+ inputs=[input_image, do_remove_background],
487
+ outputs=[processed_image],
488
+ ).success(
489
+ fn=generate_mvs,
490
+ inputs=[processed_image, sample_steps, sample_seed],
491
+ outputs=[mv_images, mv_show_images],
492
+ ).success(
493
+ fn=make3d,
494
+ inputs=[mv_images],
495
+ outputs=[output_video, output_model_obj, output_model_glb]
496
+ )
497
+
498
+ demo.queue(max_size=10)
499
+ demo.launch(server_port=1211)
configs/PRM.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.0e-06
3
+ target: src.model_mesh.MVRecon
4
+ params:
5
+ mesh_save_root: Objaverse
6
+ init_ckpt: nerf_base.ckpt
7
+ input_size: 512
8
+ render_size: 512
9
+ use_tv_loss: true
10
+ sample_points: null
11
+ use_gt_albedo: false
12
+
13
+ lrm_generator_config:
14
+ target: src.models.lrm_mesh.PRM
15
+ params:
16
+ encoder_feat_dim: 768
17
+ encoder_freeze: false
18
+ encoder_model_name: facebook/dino-vitb16
19
+ transformer_dim: 1024
20
+ transformer_layers: 16
21
+ transformer_heads: 16
22
+ triplane_low_res: 32
23
+ triplane_high_res: 64
24
+ triplane_dim: 80
25
+ rendering_samples_per_ray: 128
26
+ grid_res: 128
27
+ grid_scale: 2.1
28
+
29
+
30
+ data:
31
+ target: src.data.objaverse.DataModuleFromConfig
32
+ params:
33
+ batch_size: 1
34
+ num_workers: 8
35
+ train:
36
+ target: src.data.objaverse.ObjaverseData
37
+ params:
38
+ root_dir: Objaverse
39
+ light_dir: env_mipmap
40
+ input_view_num: [6]
41
+ target_view_num: 6
42
+ total_view_n: 18
43
+ distance: 5.0
44
+ fov: 30
45
+ camera_random: true
46
+ validation: false
47
+ validation:
48
+ target: src.data.objaverse.ValidationData
49
+ params:
50
+ root_dir: Objaverse
51
+ input_view_num: 6
52
+ input_image_size: 320
53
+ fov: 30
54
+
55
+
56
+ lightning:
57
+ modelcheckpoint:
58
+ params:
59
+ every_n_train_steps: 100
60
+ save_top_k: -1
61
+ save_last: true
62
+ callbacks: {}
63
+
64
+ trainer:
65
+ benchmark: true
66
+ max_epochs: -1
67
+ val_check_interval: 2000000000
68
+ num_sanity_val_steps: 0
69
+ accumulate_grad_batches: 8
70
+ log_every_n_steps: 1
71
+ check_val_every_n_epoch: null # if not set this, validation does not run
configs/PRM_inference.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.PRM
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/final_ckpt.ckpt
21
+ texture_resolution: 2048
22
+ render_resolution: 512
light2map.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from src.models.geometry.render import renderutils as ru
3
+ import torch
4
+ from src.models.geometry.render import util
5
+ import nvdiffrast.torch as dr
6
+ import os
7
+
8
+ from PIL import Image
9
+ import torchvision.transforms.functional as TF
10
+ import torchvision.utils as vutils
11
+ import imageio
12
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
13
+ LIGHT_MIN_RES = 16
14
+
15
+ MIN_ROUGHNESS = 0.04
16
+ MAX_ROUGHNESS = 1.00
17
+
18
+ class cubemap_mip(torch.autograd.Function):
19
+ @staticmethod
20
+ def forward(ctx, cubemap):
21
+ return util.avg_pool_nhwc(cubemap, (2,2))
22
+
23
+ @staticmethod
24
+ def backward(ctx, dout):
25
+ res = dout.shape[1] * 2
26
+ out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
27
+ for s in range(6):
28
+ gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
29
+ torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
30
+ indexing='ij')
31
+ v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
32
+ out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
33
+ return out
34
+
35
+ def build_mips(base, cutoff=0.99):
36
+ specular = [base]
37
+ while specular[-1].shape[1] > LIGHT_MIN_RES:
38
+ specular.append(cubemap_mip.apply(specular[-1]))
39
+ #specular.append(util.avg_pool_nhwc(specular[-1], (2,2)))
40
+
41
+ diffuse = ru.diffuse_cubemap(specular[-1])
42
+
43
+ for idx in range(len(specular) - 1):
44
+ roughness = (idx / (len(specular) - 2)) * (MAX_ROUGHNESS - MIN_ROUGHNESS) + MIN_ROUGHNESS
45
+ specular[idx] = ru.specular_cubemap(specular[idx], roughness, cutoff)
46
+ specular[-1] = ru.specular_cubemap(specular[-1], 1.0, cutoff)
47
+
48
+ return specular, diffuse
49
+
50
+
51
+ # Load from latlong .HDR file
52
+ def _load_env_hdr(fn, scale=1.0):
53
+ latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
54
+ cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
55
+
56
+ specular, diffuse = build_mips(cubemap)
57
+
58
+ return specular, diffuse
59
+
60
+ def main(path_hdr, save_path_map):
61
+ all_envs = os.listdir(path_hdr)
62
+
63
+ for env in all_envs:
64
+ env_path = os.path.join(path_hdr, env)
65
+ base_n = os.path.basename(env_path).split('.')[0]
66
+
67
+ try:
68
+ if not os.path.exists(os.path.join(save_path_map, base_n)):
69
+ os.makedirs(os.path.join(save_path_map, base_n))
70
+ specular, diffuse = _load_env_hdr(env_path)
71
+ for i in range(len(specular)):
72
+ tensor = specular[i]
73
+ torch.save(tensor, os.path.join(save_path_map, base_n, f'specular_{i}.pth'))
74
+
75
+ torch.save(diffuse, os.path.join(save_path_map, base_n, 'diffuse.pth'))
76
+ except Exception as e:
77
+ print(f"Error processing {env}: {e}")
78
+ continue
79
+
80
+ if __name__ == "__main__":
81
+ if len(sys.argv) != 3:
82
+ print("Usage: python script.py <path_hdr> <save_path_map>")
83
+ sys.exit(1)
84
+
85
+ path_hdr = sys.argv[1]
86
+ save_path_map = sys.argv[2]
87
+
88
+ if not os.path.exists(path_hdr):
89
+ print(f"Error: path_hdr '{path_hdr}' does not exist.")
90
+ sys.exit(1)
91
+
92
+ if not os.path.exists(save_path_map):
93
+ os.makedirs(save_path_map)
94
+
95
+ main(path_hdr, save_path_map)
obj2mesh.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import torch
4
+ import psutil
5
+ import gc
6
+ from tqdm import tqdm
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from src.data.objaverse import load_obj
9
+ from src.utils import mesh
10
+ from src.utils.material import Material
11
+ import argparse
12
+
13
+
14
+ def bytes_to_megabytes(bytes):
15
+ return bytes / (1024 * 1024)
16
+
17
+
18
+ def bytes_to_gigabytes(bytes):
19
+ return bytes / (1024 * 1024 * 1024)
20
+
21
+
22
+ def print_memory_usage(stage):
23
+ process = psutil.Process(os.getpid())
24
+ memory_info = process.memory_info()
25
+ allocated = torch.cuda.memory_allocated() / 1024**2
26
+ cached = torch.cuda.memory_reserved() / 1024**2
27
+ print(
28
+ f"[{stage}] Process memory: {memory_info.rss / 1024**2:.2f} MB, "
29
+ f"Allocated CUDA memory: {allocated:.2f} MB, Cached CUDA memory: {cached:.2f} MB"
30
+ )
31
+
32
+
33
+ def process_obj(index, root_dir, final_save_dir, paths):
34
+ obj_path = os.path.join(root_dir, paths[index], paths[index] + '.obj')
35
+ mtl_path = os.path.join(root_dir, paths[index], paths[index] + '.mtl')
36
+
37
+ if os.path.exists(os.path.join(final_save_dir, f"{paths[index]}.pth")):
38
+ return None
39
+
40
+ try:
41
+ with torch.no_grad():
42
+ ref_mesh, vertices, faces, normals, nfaces, texcoords, tfaces, uber_material = load_obj(
43
+ obj_path, return_attributes=True
44
+ )
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ ref_mesh = mesh.compute_tangents(ref_mesh)
47
+
48
+ with open(mtl_path, 'r') as file:
49
+ lines = file.readlines()
50
+
51
+ if len(lines) >= 250:
52
+ return None
53
+
54
+ final_mesh_attributes = {
55
+ "v_pos": ref_mesh.v_pos.detach().cpu(),
56
+ "v_nrm": ref_mesh.v_nrm.detach().cpu(),
57
+ "v_tex": ref_mesh.v_tex.detach().cpu(),
58
+ "v_tng": ref_mesh.v_tng.detach().cpu(),
59
+ "t_pos_idx": ref_mesh.t_pos_idx.detach().cpu(),
60
+ "t_nrm_idx": ref_mesh.t_nrm_idx.detach().cpu(),
61
+ "t_tex_idx": ref_mesh.t_tex_idx.detach().cpu(),
62
+ "t_tng_idx": ref_mesh.t_tng_idx.detach().cpu(),
63
+ "mat_dict": {key: ref_mesh.material[key] for key in ref_mesh.material.mat_keys},
64
+ }
65
+
66
+ torch.save(final_mesh_attributes, f"{final_save_dir}/{paths[index]}.pth")
67
+ print(f"==> Saved to {final_save_dir}/{paths[index]}.pth")
68
+
69
+ del ref_mesh
70
+ torch.cuda.empty_cache()
71
+ return paths[index]
72
+
73
+ except Exception as e:
74
+ print(f"Failed to process {paths[index]}: {e}")
75
+ return None
76
+
77
+ finally:
78
+ gc.collect()
79
+ torch.cuda.empty_cache()
80
+
81
+
82
+ def main(root_dir, save_dir):
83
+ os.makedirs(save_dir, exist_ok=True)
84
+ finish_lists = os.listdir(save_dir)
85
+ paths = os.listdir(root_dir)
86
+
87
+ valid_uid = []
88
+
89
+ print_memory_usage("Start")
90
+
91
+ batch_size = 100
92
+ num_batches = (len(paths) + batch_size - 1) // batch_size
93
+
94
+ for batch in tqdm(range(num_batches)):
95
+ start_index = batch * batch_size
96
+ end_index = min(start_index + batch_size, len(paths))
97
+
98
+ with ThreadPoolExecutor(max_workers=8) as executor:
99
+ futures = [
100
+ executor.submit(process_obj, index, root_dir, save_dir, paths)
101
+ for index in range(start_index, end_index)
102
+ ]
103
+ for future in as_completed(futures):
104
+ result = future.result()
105
+ if result is not None:
106
+ valid_uid.append(result)
107
+
108
+ print_memory_usage(f"=====> After processing batch {batch + 1}")
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+
112
+ print_memory_usage("End")
113
+
114
+
115
+ if __name__ == "__main__":
116
+ parser = argparse.ArgumentParser(description="Process OBJ files and save final results.")
117
+ parser.add_argument("root_dir", type=str, help="Directory containing the root OBJ files.")
118
+ parser.add_argument("save_dir", type=str, help="Directory to save the processed results.")
119
+ args = parser.parse_args()
120
+
121
+ main(args.root_dir, args.save_dir)
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pytorch-lightning==2.1.2
2
+ gradio==3.41.2
3
+ huggingface-hub
4
+ einops
5
+ omegaconf
6
+ torchmetrics
7
+ webdataset
8
+ accelerate
9
+ tensorboard
10
+ PyMCubes
11
+ trimesh
12
+ rembg
13
+ transformers==4.34.1
14
+ diffusers==0.20.2
15
+ bitsandbytes
16
+ imageio[ffmpeg]
17
+ xatlas
18
+ plyfile
19
+ git+https://github.com/NVlabs/nvdiffrast/
20
+ PyGLM==2.7.0
21
+ open3d
run.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import glm
4
+ import numpy as np
5
+ import torch
6
+ import rembg
7
+ from PIL import Image
8
+ from torchvision.transforms import v2
9
+ import torchvision
10
+ from pytorch_lightning import seed_everything
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange, repeat
13
+ from tqdm import tqdm
14
+ from huggingface_hub import hf_hub_download
15
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
16
+
17
+ from src.data.objaverse import load_mipmap
18
+ from src.utils import render_utils
19
+ from src.utils.train_util import instantiate_from_config
20
+ from src.utils.camera_util import (
21
+ FOV_to_intrinsics,
22
+ center_looking_at_camera_pose,
23
+ get_zero123plus_input_cameras,
24
+ get_circular_camera_poses,
25
+ )
26
+ from src.utils.mesh_util import save_obj, save_obj_with_mtl
27
+ from src.utils.infer_util import remove_background, resize_foreground, save_video
28
+
29
+ def str_to_tuple(arg_str):
30
+ try:
31
+ return eval(arg_str)
32
+ except:
33
+ raise argparse.ArgumentTypeError("Tuple argument must be in the format (x, y)")
34
+
35
+
36
+ def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
37
+ """
38
+ Get the rendering camera parameters.
39
+ """
40
+ train_res = [512, 512]
41
+ cam_near_far = [0.1, 1000.0]
42
+ fovy = np.deg2rad(fov)
43
+ proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
44
+ all_mv = []
45
+ all_mvp = []
46
+ all_campos = []
47
+ if isinstance(elevation, tuple):
48
+ elevation_0 = np.deg2rad(elevation[0])
49
+ elevation_1 = np.deg2rad(elevation[1])
50
+ for i in range(M//2):
51
+ azimuth = 2 * np.pi * i / (M // 2)
52
+ z = radius * np.cos(azimuth) * np.sin(elevation_0)
53
+ x = radius * np.sin(azimuth) * np.sin(elevation_0)
54
+ y = radius * np.cos(elevation_0)
55
+
56
+ eye = glm.vec3(x, y, z)
57
+ at = glm.vec3(0.0, 0.0, 0.0)
58
+ up = glm.vec3(0.0, 1.0, 0.0)
59
+ view_matrix = glm.lookAt(eye, at, up)
60
+ mv = torch.from_numpy(np.array(view_matrix))
61
+ mvp = proj_mtx @ (mv) #w2c
62
+ campos = torch.linalg.inv(mv)[:3, 3]
63
+ all_mv.append(mv[None, ...].cuda())
64
+ all_mvp.append(mvp[None, ...].cuda())
65
+ all_campos.append(campos[None, ...].cuda())
66
+ for i in range(M//2):
67
+ azimuth = 2 * np.pi * i / (M // 2)
68
+ z = radius * np.cos(azimuth) * np.sin(elevation_1)
69
+ x = radius * np.sin(azimuth) * np.sin(elevation_1)
70
+ y = radius * np.cos(elevation_1)
71
+
72
+ eye = glm.vec3(x, y, z)
73
+ at = glm.vec3(0.0, 0.0, 0.0)
74
+ up = glm.vec3(0.0, 1.0, 0.0)
75
+ view_matrix = glm.lookAt(eye, at, up)
76
+ mv = torch.from_numpy(np.array(view_matrix))
77
+ mvp = proj_mtx @ (mv) #w2c
78
+ campos = torch.linalg.inv(mv)[:3, 3]
79
+ all_mv.append(mv[None, ...].cuda())
80
+ all_mvp.append(mvp[None, ...].cuda())
81
+ all_campos.append(campos[None, ...].cuda())
82
+ else:
83
+ # elevation = 90 - elevation
84
+ for i in range(M):
85
+ azimuth = 2 * np.pi * i / M
86
+ z = radius * np.cos(azimuth) * np.sin(elevation)
87
+ x = radius * np.sin(azimuth) * np.sin(elevation)
88
+ y = radius * np.cos(elevation)
89
+
90
+ eye = glm.vec3(x, y, z)
91
+ at = glm.vec3(0.0, 0.0, 0.0)
92
+ up = glm.vec3(0.0, 1.0, 0.0)
93
+ view_matrix = glm.lookAt(eye, at, up)
94
+ mv = torch.from_numpy(np.array(view_matrix))
95
+ mvp = proj_mtx @ (mv) #w2c
96
+ campos = torch.linalg.inv(mv)[:3, 3]
97
+ all_mv.append(mv[None, ...].cuda())
98
+ all_mvp.append(mvp[None, ...].cuda())
99
+ all_campos.append(campos[None, ...].cuda())
100
+ all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
101
+ all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
102
+ all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
103
+ return all_mv, all_mvp, all_campos
104
+
105
+ def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1, is_flexicubes=False):
106
+ """
107
+ Render frames from triplanes.
108
+ """
109
+ frames = []
110
+ albedos = []
111
+ pbr_spec_lights = []
112
+ pbr_diffuse_lights = []
113
+ normals = []
114
+ alphas = []
115
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
116
+ if is_flexicubes:
117
+ out = model.forward_geometry(
118
+ planes,
119
+ render_cameras[:, i:i+chunk_size],
120
+ camera_pos[:, i:i+chunk_size],
121
+ [[env]*chunk_size],
122
+ [[materials]*chunk_size],
123
+ render_size=render_size,
124
+ )
125
+ frame = out['pbr_img']
126
+ albedo = out['albedo']
127
+ pbr_spec_light = out['pbr_spec_light']
128
+ pbr_diffuse_light = out['pbr_diffuse_light']
129
+ normal = out['normal']
130
+ alpha = out['mask']
131
+ else:
132
+ frame = model.forward_synthesizer(
133
+ planes,
134
+ render_cameras[i],
135
+ render_size=render_size,
136
+ )['images_rgb']
137
+ frames.append(frame)
138
+ albedos.append(albedo)
139
+ pbr_spec_lights.append(pbr_spec_light)
140
+ pbr_diffuse_lights.append(pbr_diffuse_light)
141
+ normals.append(normal)
142
+ alphas.append(alpha)
143
+
144
+ frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
145
+ alphas = torch.cat(alphas, dim=1)[0]
146
+ albedos = torch.cat(albedos, dim=1)[0]
147
+ pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
148
+ pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
149
+ normals = torch.cat(normals, dim=0).permute(0,3,1,2)[:,:3]
150
+ return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
151
+
152
+
153
+ ###############################################################################
154
+ # Arguments.
155
+ ###############################################################################
156
+
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument('config', type=str, help='Path to config file.')
159
+ parser.add_argument('input_path', type=str, help='Path to input image or directory.')
160
+ parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
161
+ parser.add_argument('--model_ckpt_path', type=str, default="", help='Output directory.')
162
+ parser.add_argument('--diffusion_steps', type=int, default=100, help='Denoising Sampling steps.')
163
+ parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
164
+ parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
165
+ parser.add_argument('--materials', type=str_to_tuple, default=(1.0, 0.1), help=' metallic and roughness')
166
+ parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
167
+ parser.add_argument('--fov', type=float, default=30, help='Render distance.')
168
+ parser.add_argument('--env_path', type=str, default='data/env_mipmap/2', help='environment map')
169
+ parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
170
+ parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
171
+ parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
172
+ parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
173
+ args = parser.parse_args()
174
+ seed_everything(args.seed)
175
+
176
+ ###############################################################################
177
+ # Stage 0: Configuration.
178
+ ###############################################################################
179
+
180
+ config = OmegaConf.load(args.config)
181
+ config_name = os.path.basename(args.config).replace('.yaml', '')
182
+ model_config = config.model_config
183
+ infer_config = config.infer_config
184
+
185
+ IS_FLEXICUBES = True
186
+
187
+ device = torch.device('cuda')
188
+
189
+ # load diffusion model
190
+ print('Loading diffusion model ...')
191
+ pipeline = DiffusionPipeline.from_pretrained(
192
+ "sudo-ai/zero123plus-v1.2",
193
+ custom_pipeline="zero123plus",
194
+ torch_dtype=torch.float16,
195
+ )
196
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
197
+ pipeline.scheduler.config, timestep_spacing='trailing'
198
+ )
199
+
200
+ # load custom white-background UNet
201
+ print('Loading custom white-background unet ...')
202
+ if os.path.exists(infer_config.unet_path):
203
+ unet_ckpt_path = infer_config.unet_path
204
+ else:
205
+ unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
206
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
207
+ pipeline.unet.load_state_dict(state_dict, strict=True)
208
+
209
+ pipeline = pipeline.to(device)
210
+
211
+ # load reconstruction model
212
+ print('Loading reconstruction model ...')
213
+ model = instantiate_from_config(model_config)
214
+ if os.path.exists(infer_config.model_path):
215
+ model_ckpt_path = infer_config.model_path
216
+ else:
217
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
218
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
219
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
220
+ model.load_state_dict(state_dict, strict=True)
221
+
222
+ model = model.to(device)
223
+ if IS_FLEXICUBES:
224
+ model.init_flexicubes_geometry(device, fovy=50.0)
225
+ model = model.eval()
226
+
227
+ # make output directories
228
+ image_path = os.path.join(args.output_path, config_name, 'images')
229
+ mesh_path = os.path.join(args.output_path, config_name, 'meshes')
230
+ video_path = os.path.join(args.output_path, config_name, 'videos')
231
+ os.makedirs(image_path, exist_ok=True)
232
+ os.makedirs(mesh_path, exist_ok=True)
233
+ os.makedirs(video_path, exist_ok=True)
234
+
235
+ # process input files
236
+ if os.path.isdir(args.input_path):
237
+ input_files = [
238
+ os.path.join(args.input_path, file)
239
+ for file in os.listdir(args.input_path)
240
+ if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
241
+ ]
242
+ else:
243
+ input_files = [args.input_path]
244
+ print(f'Total number of input images: {len(input_files)}')
245
+
246
+ ###############################################################################
247
+ # Stage 1: Multiview generation.
248
+ ###############################################################################
249
+
250
+ rembg_session = None if args.no_rembg else rembg.new_session()
251
+
252
+ outputs = []
253
+ for idx, image_file in enumerate(input_files):
254
+ name = os.path.basename(image_file).split('.')[0]
255
+ print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')
256
+
257
+ # remove background optionally
258
+ input_image = Image.open(image_file)
259
+ if not args.no_rembg:
260
+ input_image = remove_background(input_image, rembg_session)
261
+ input_image = resize_foreground(input_image, 0.85)
262
+ # sampling
263
+ output_image = pipeline(
264
+ input_image,
265
+ num_inference_steps=args.diffusion_steps,
266
+ ).images[0]
267
+ print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")
268
+
269
+ images = np.asarray(output_image, dtype=np.float32) / 255.0
270
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
271
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
272
+ torchvision.utils.save_image(images, os.path.join(image_path, f'{name}.png'))
273
+ sample = {'name': name, 'images': images}
274
+
275
+ # delete pipeline to save memory
276
+ # del pipeline
277
+
278
+ ###############################################################################
279
+ # Stage 2: Reconstruction.
280
+ ###############################################################################
281
+
282
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=3.2*args.scale, fov=30).to(device)
283
+ chunk_size = 20 if IS_FLEXICUBES else 1
284
+
285
+ # for idx, sample in enumerate(outputs):
286
+ name = sample['name']
287
+ print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')
288
+
289
+ images = sample['images'].unsqueeze(0).to(device)
290
+ images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
291
+
292
+ with torch.no_grad():
293
+ # get triplane
294
+ planes = model.forward_planes(images, input_cameras)
295
+
296
+ mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')
297
+
298
+ mesh_out = model.extract_mesh(
299
+ planes,
300
+ use_texture_map=args.export_texmap,
301
+ **infer_config,
302
+ )
303
+ if args.export_texmap:
304
+ vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
305
+ save_obj_with_mtl(
306
+ vertices.data.cpu().numpy(),
307
+ uvs.data.cpu().numpy(),
308
+ faces.data.cpu().numpy(),
309
+ mesh_tex_idx.data.cpu().numpy(),
310
+ tex_map.permute(1, 2, 0).data.cpu().numpy(),
311
+ mesh_path_idx,
312
+ )
313
+ else:
314
+ vertices, faces, vertex_colors = mesh_out
315
+ save_obj(vertices, faces, vertex_colors, mesh_path_idx)
316
+ print(f"Mesh saved to {mesh_path_idx}")
317
+
318
+ render_size = 512
319
+ if args.save_video:
320
+ video_path_idx = os.path.join(video_path, f'{name}.mp4')
321
+ render_size = infer_config.render_resolution
322
+ ENV = load_mipmap(args.env_path)
323
+ materials = args.materials
324
+
325
+ all_mv, all_mvp, all_campos = get_render_cameras(
326
+ batch_size=1,
327
+ M=240,
328
+ radius=args.distance,
329
+ elevation=(90, 60.0),
330
+ is_flexicubes=IS_FLEXICUBES,
331
+ fov=args.fov
332
+ )
333
+
334
+ frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
335
+ model,
336
+ planes,
337
+ render_cameras=all_mvp,
338
+ camera_pos=all_campos,
339
+ env=ENV,
340
+ materials=materials,
341
+ render_size=render_size,
342
+ chunk_size=chunk_size,
343
+ is_flexicubes=IS_FLEXICUBES,
344
+ )
345
+ normals = (torch.nn.functional.normalize(normals) + 1) / 2
346
+ normals = normals * alphas + (1-alphas)
347
+ all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
348
+
349
+ # breakpoint()
350
+ save_video(
351
+ all_frames,
352
+ video_path_idx,
353
+ fps=30,
354
+ )
355
+ print(f"Video saved to {video_path_idx}")
run.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ python run.py configs/PRM_inference.yaml examples/ \
2
+ --seed 10 \
3
+ --materials "(0.0, 0.9)" \
4
+ --env_path "./env_mipmap/6" \
5
+ --output_path "output/" \
6
+ --save_video \
7
+ --export_texmap \
run_hpc.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source /hpc2ssd/softwares/anaconda3/bin/activate instantmesh
2
+ module load cuda/12.1 compilers/gcc-11.1.0 compilers/icc-2023.1.0 cmake/3.27.0
3
+ export CXX=$(which g++)
4
+ export CC=$(which gcc)
5
+ export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
6
+ export CUDA_LAUNCH_BLOCKING=1
7
+ export NCCL_TIMEOUT=3600
8
+ export CUDA_VISIBLE_DEVICES="0"
9
+ # python app.py
10
+ python run.py configs/PRM_inference.yaml examples/恐龙套装.webp \
11
+ --seed 10 \
12
+ --materials "(0.0, 0.9)" \
13
+ --env_path "./env_mipmap/6" \
14
+ --output_path "output/" \
15
+ --save_video \
16
+ --export_texmap \
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (139 Bytes). View file
 
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (144 Bytes). View file
 
src/data/__pycache__/objaverse.cpython-310.pyc ADDED
Binary file (14.9 kB). View file
 
src/data/bsdf_256_256.bin ADDED
Binary file (524 kB). View file
 
src/data/objaverse.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import math
3
+ import json
4
+ import glm
5
+ from pathlib import Path
6
+
7
+ import random
8
+ import numpy as np
9
+ from PIL import Image
10
+ import webdataset as wds
11
+ import pytorch_lightning as pl
12
+ import sys
13
+ from src.utils import obj, render_utils
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ import random
19
+ import itertools
20
+ from src.utils.train_util import instantiate_from_config
21
+ from src.utils.camera_util import (
22
+ FOV_to_intrinsics,
23
+ center_looking_at_camera_pose,
24
+ get_circular_camera_poses,
25
+ )
26
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
27
+ import re
28
+
29
+ def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
30
+ azimuths = np.deg2rad(azimuths)
31
+ elevations = np.deg2rad(elevations)
32
+
33
+ xs = radius * np.cos(elevations) * np.cos(azimuths)
34
+ ys = radius * np.cos(elevations) * np.sin(azimuths)
35
+ zs = radius * np.sin(elevations)
36
+
37
+ cam_locations = np.stack([xs, ys, zs], axis=-1)
38
+ cam_locations = torch.from_numpy(cam_locations).float()
39
+
40
+ c2ws = center_looking_at_camera_pose(cam_locations)
41
+ return c2ws
42
+
43
+ def find_matching_files(base_path, idx):
44
+ formatted_idx = '%03d' % idx
45
+ pattern = re.compile(r'^%s_\d+\.png$' % formatted_idx)
46
+ matching_files = []
47
+
48
+ if os.path.exists(base_path):
49
+ for filename in os.listdir(base_path):
50
+ if pattern.match(filename):
51
+ matching_files.append(filename)
52
+
53
+ return os.path.join(base_path, matching_files[0])
54
+
55
+ def load_mipmap(env_path):
56
+ diffuse_path = os.path.join(env_path, "diffuse.pth")
57
+ diffuse = torch.load(diffuse_path, map_location=torch.device('cpu'))
58
+
59
+ specular = []
60
+ for i in range(6):
61
+ specular_path = os.path.join(env_path, f"specular_{i}.pth")
62
+ specular_tensor = torch.load(specular_path, map_location=torch.device('cpu'))
63
+ specular.append(specular_tensor)
64
+ return [specular, diffuse]
65
+
66
+ def convert_to_white_bg(image, write_bg=True):
67
+ alpha = image[:, :, 3:]
68
+ if write_bg:
69
+ return image[:, :, :3] * alpha + 1. * (1 - alpha)
70
+ else:
71
+ return image[:, :, :3] * alpha
72
+
73
+ def load_obj(path, return_attributes=False, scale_factor=1.0):
74
+ return obj.load_obj(path, clear_ks=True, mtl_override=None, return_attributes=return_attributes, scale_factor=scale_factor)
75
+
76
+ def custom_collate_fn(batch):
77
+ return batch
78
+
79
+
80
+ def collate_fn_wrapper(batch):
81
+ return custom_collate_fn(batch)
82
+
83
+ class DataModuleFromConfig(pl.LightningDataModule):
84
+ def __init__(
85
+ self,
86
+ batch_size=8,
87
+ num_workers=4,
88
+ train=None,
89
+ validation=None,
90
+ test=None,
91
+ **kwargs,
92
+ ):
93
+ super().__init__()
94
+
95
+ self.batch_size = batch_size
96
+ self.num_workers = num_workers
97
+
98
+ self.dataset_configs = dict()
99
+ if train is not None:
100
+ self.dataset_configs['train'] = train
101
+ if validation is not None:
102
+ self.dataset_configs['validation'] = validation
103
+ if test is not None:
104
+ self.dataset_configs['test'] = test
105
+
106
+ def setup(self, stage):
107
+
108
+ if stage in ['fit']:
109
+ self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
110
+ else:
111
+ raise NotImplementedError
112
+
113
+ def custom_collate_fn(self, batch):
114
+ collated_batch = {}
115
+ for key in batch[0].keys():
116
+ if key == 'input_env' or key == 'target_env':
117
+ collated_batch[key] = [d[key] for d in batch]
118
+ else:
119
+ collated_batch[key] = torch.stack([d[key] for d in batch], dim=0)
120
+ return collated_batch
121
+
122
+ def convert_to_white_bg(self, image):
123
+ alpha = image[:, :, 3:]
124
+ return image[:, :, :3] * alpha + 1. * (1 - alpha)
125
+
126
+ def load_obj(self, path):
127
+ return obj.load_obj(path, clear_ks=True, mtl_override=None)
128
+
129
+ def train_dataloader(self):
130
+
131
+ sampler = DistributedSampler(self.datasets['train'])
132
+ return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)
133
+
134
+ def val_dataloader(self):
135
+
136
+ sampler = DistributedSampler(self.datasets['validation'])
137
+ return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)
138
+
139
+ def test_dataloader(self):
140
+
141
+ return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
142
+
143
+
144
+ class ObjaverseData(Dataset):
145
+ def __init__(self,
146
+ root_dir='Objaverse_highQuality',
147
+ light_dir= 'env_mipmap',
148
+ input_view_num=6,
149
+ target_view_num=4,
150
+ total_view_n=18,
151
+ distance=3.5,
152
+ fov=50,
153
+ camera_random=False,
154
+ validation=False,
155
+ ):
156
+ self.root_dir = Path(root_dir)
157
+ self.light_dir = light_dir
158
+ self.all_env_name = []
159
+ for temp_dir in os.listdir(light_dir):
160
+ if os.listdir(os.path.join(self.light_dir, temp_dir)):
161
+ self.all_env_name.append(temp_dir)
162
+
163
+ self.input_view_num = input_view_num
164
+ self.target_view_num = target_view_num
165
+ self.total_view_n = total_view_n
166
+ self.fov = fov
167
+ self.camera_random = camera_random
168
+
169
+ self.train_res = [512, 512]
170
+ self.cam_near_far = [0.1, 1000.0]
171
+ self.fov_rad = np.deg2rad(fov)
172
+ self.fov_deg = fov
173
+ self.spp = 1
174
+ self.cam_radius = distance
175
+ self.layers = 1
176
+
177
+ numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
178
+ self.combinations = list(itertools.product(numbers, repeat=2))
179
+
180
+ self.paths = os.listdir(self.root_dir)
181
+
182
+ # with open("BJ_Mesh_list.json", 'r') as file:
183
+ # self.paths = json.load(file)
184
+
185
+ print('total training object num:', len(self.paths))
186
+
187
+ self.depth_scale = 6.0
188
+
189
+ total_objects = len(self.paths)
190
+ print('============= length of dataset %d =============' % total_objects)
191
+
192
+ def __len__(self):
193
+ return len(self.paths)
194
+
195
+ def load_obj(self, path):
196
+ return obj.load_obj(path, clear_ks=True, mtl_override=None)
197
+
198
+ def sample_spherical(self, phi, theta, cam_radius):
199
+ theta = np.deg2rad(theta)
200
+ phi = np.deg2rad(phi)
201
+
202
+ z = cam_radius * np.cos(phi) * np.sin(theta)
203
+ x = cam_radius * np.sin(phi) * np.sin(theta)
204
+ y = cam_radius * np.cos(theta)
205
+
206
+ return x, y, z
207
+
208
+ def _random_scene(self, cam_radius, fov_rad):
209
+ iter_res = self.train_res
210
+ proj_mtx = render_utils.perspective(fov_rad, iter_res[1] / iter_res[0], self.cam_near_far[0], self.cam_near_far[1])
211
+
212
+ azimuths = random.uniform(0, 360)
213
+ elevations = random.uniform(30, 150)
214
+ mv_embedding = spherical_camera_pose(azimuths, 90-elevations, cam_radius)
215
+ x, y, z = self.sample_spherical(azimuths, elevations, cam_radius)
216
+ eye = glm.vec3(x, y, z)
217
+ at = glm.vec3(0.0, 0.0, 0.0)
218
+ up = glm.vec3(0.0, 1.0, 0.0)
219
+ view_matrix = glm.lookAt(eye, at, up)
220
+ mv = torch.from_numpy(np.array(view_matrix))
221
+ mvp = proj_mtx @ (mv) #w2c
222
+ campos = torch.linalg.inv(mv)[:3, 3]
223
+ return mv[None, ...], mvp[None, ...], campos[None, ...], mv_embedding[None, ...], iter_res, self.spp # Add batch dimension
224
+
225
+ def load_im(self, path, color):
226
+ '''
227
+ replace background pixel with random color in rendering
228
+ '''
229
+ pil_img = Image.open(path)
230
+
231
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
232
+ alpha = image[:, :, 3:]
233
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
234
+
235
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
236
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
237
+ return image, alpha
238
+
239
+ def load_albedo(self, path, color, mask):
240
+ '''
241
+ replace background pixel with random color in rendering
242
+ '''
243
+ pil_img = Image.open(path)
244
+
245
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
246
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
247
+
248
+ color = torch.ones_like(image)
249
+ image = image * mask + color * (1 - mask)
250
+ return image
251
+
252
+ def convert_to_white_bg(self, image):
253
+ alpha = image[:, :, 3:]
254
+ return image[:, :, :3] * alpha + 1. * (1 - alpha)
255
+
256
+ def calculate_fov(self, initial_distance, initial_fov, new_distance):
257
+ initial_fov_rad = math.radians(initial_fov)
258
+
259
+ height = 2 * initial_distance * math.tan(initial_fov_rad / 2)
260
+
261
+ new_fov_rad = 2 * math.atan(height / (2 * new_distance))
262
+
263
+ new_fov = math.degrees(new_fov_rad)
264
+
265
+ return new_fov
266
+
267
+ def __getitem__(self, index):
268
+ obj_path = os.path.join(self.root_dir, self.paths[index])
269
+ mesh_attributes = torch.load(obj_path, map_location=torch.device('cpu'))
270
+ pose_list = []
271
+ env_list = []
272
+ material_list = []
273
+ camera_pos = []
274
+ c2w_list = []
275
+ camera_embedding_list = []
276
+ random_env = False
277
+ random_mr = False
278
+ if random.random() > 0.5:
279
+ random_env = True
280
+ if random.random() > 0.5:
281
+ random_mr = True
282
+ selected_env = random.randint(0, len(self.all_env_name)-1)
283
+ materials = random.choice(self.combinations)
284
+ if self.camera_random:
285
+ random_perturbation = random.uniform(-1.5, 1.5)
286
+ cam_radius = self.cam_radius + random_perturbation
287
+ fov_deg = self.calculate_fov(initial_distance=self.cam_radius, initial_fov=self.fov_deg, new_distance=cam_radius)
288
+ fov_rad = np.deg2rad(fov_deg)
289
+ else:
290
+ cam_radius = self.cam_radius
291
+ fov_rad = self.fov_rad
292
+ fov_deg = self.fov_deg
293
+
294
+ if len(self.input_view_num) >= 1:
295
+ input_view_num = random.choice(self.input_view_num)
296
+ else:
297
+ input_view_num = self.input_view_num
298
+ for _ in range(input_view_num + self.target_view_num):
299
+ mv, mvp, campos, mv_mebedding, iter_res, iter_spp = self._random_scene(cam_radius, fov_rad)
300
+ if random_env:
301
+ selected_env = random.randint(0, len(self.all_env_name)-1)
302
+ env_path = os.path.join(self.light_dir, self.all_env_name[selected_env])
303
+ env = load_mipmap(env_path)
304
+ if random_mr:
305
+ materials = random.choice(self.combinations)
306
+ pose_list.append(mvp)
307
+ camera_pos.append(campos)
308
+ c2w_list.append(mv)
309
+ env_list.append(env)
310
+ material_list.append(materials)
311
+ camera_embedding_list.append(mv_mebedding)
312
+ data = {
313
+ 'mesh_attributes': mesh_attributes,
314
+ 'input_view_num': input_view_num,
315
+ 'target_view_num': self.target_view_num,
316
+ 'obj_path': obj_path,
317
+ 'pose_list': pose_list,
318
+ 'camera_pos': camera_pos,
319
+ 'c2w_list': c2w_list,
320
+ 'env_list': env_list,
321
+ 'material_list': material_list,
322
+ 'camera_embedding_list': camera_embedding_list,
323
+ 'fov_deg':fov_deg,
324
+ 'raduis': cam_radius
325
+ }
326
+
327
+ return data
328
+
329
+ class ValidationData(Dataset):
330
+ def __init__(self,
331
+ root_dir='objaverse/',
332
+ input_view_num=6,
333
+ input_image_size=320,
334
+ fov=30,
335
+ ):
336
+ self.root_dir = Path(root_dir)
337
+ self.input_view_num = input_view_num
338
+ self.input_image_size = input_image_size
339
+ self.fov = fov
340
+ self.light_dir = 'env_mipmap'
341
+
342
+ # with open('Mesh_list.json') as f:
343
+ # filtered_dict = json.load(f)
344
+
345
+ self.paths = os.listdir(self.root_dir)
346
+
347
+ # self.paths = filtered_dict
348
+ print('============= length of dataset %d =============' % len(self.paths))
349
+
350
+ cam_distance = 4.0
351
+ azimuths = np.array([30, 90, 150, 210, 270, 330])
352
+ elevations = np.array([20, -10, 20, -10, 20, -10])
353
+ azimuths = np.deg2rad(azimuths)
354
+ elevations = np.deg2rad(elevations)
355
+
356
+ x = cam_distance * np.cos(elevations) * np.cos(azimuths)
357
+ y = cam_distance * np.cos(elevations) * np.sin(azimuths)
358
+ z = cam_distance * np.sin(elevations)
359
+
360
+ cam_locations = np.stack([x, y, z], axis=-1)
361
+ cam_locations = torch.from_numpy(cam_locations).float()
362
+ c2ws = center_looking_at_camera_pose(cam_locations)
363
+ self.c2ws = c2ws.float()
364
+ self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
365
+
366
+ render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
367
+ render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
368
+ self.render_c2ws = render_c2ws.float()
369
+ self.render_Ks = render_Ks.float()
370
+
371
+ def __len__(self):
372
+ return len(self.paths)
373
+
374
+ def load_im(self, path, color):
375
+ '''
376
+ replace background pixel with random color in rendering
377
+ '''
378
+ pil_img = Image.open(path)
379
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
380
+
381
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
382
+ if image.shape[-1] == 4:
383
+ alpha = image[:, :, 3:]
384
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
385
+ else:
386
+ alpha = np.ones_like(image[:, :, :1])
387
+
388
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
389
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
390
+ return image, alpha
391
+
392
+ def load_mat(self, path, color):
393
+ '''
394
+ replace background pixel with random color in rendering
395
+ '''
396
+ pil_img = Image.open(path)
397
+ pil_img = pil_img.resize((384,384), resample=Image.BICUBIC)
398
+
399
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
400
+ if image.shape[-1] == 4:
401
+ alpha = image[:, :, 3:]
402
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
403
+ else:
404
+ alpha = np.ones_like(image[:, :, :1])
405
+
406
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
407
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
408
+ return image, alpha
409
+
410
+ def load_albedo(self, path, color, mask):
411
+ '''
412
+ replace background pixel with random color in rendering
413
+ '''
414
+ pil_img = Image.open(path)
415
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
416
+
417
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
418
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
419
+
420
+ color = torch.ones_like(image)
421
+ image = image * mask + color * (1 - mask)
422
+ return image
423
+
424
+ def __getitem__(self, index):
425
+
426
+ # load data
427
+ input_image_path = os.path.join(self.root_dir, self.paths[index])
428
+
429
+ '''background color, default: white'''
430
+ bkg_color = [1.0, 1.0, 1.0]
431
+
432
+ image_list = []
433
+ albedo_list = []
434
+ alpha_list = []
435
+ specular_list = []
436
+ diffuse_list = []
437
+ metallic_list = []
438
+ roughness_list = []
439
+
440
+ exist_comb_list = []
441
+ for subfolder in os.listdir(input_image_path):
442
+ found_numeric_subfolder=False
443
+ subfolder_path = os.path.join(input_image_path, subfolder)
444
+ if os.path.isdir(subfolder_path) and '_' in subfolder and 'specular' not in subfolder and 'diffuse' not in subfolder:
445
+ try:
446
+ parts = subfolder.split('_')
447
+ float(parts[0]) # 尝试将分隔符前后的字符串转换为浮点数
448
+ float(parts[1])
449
+ found_numeric_subfolder = True
450
+ except ValueError:
451
+ continue
452
+ if found_numeric_subfolder:
453
+ exist_comb_list.append(subfolder)
454
+
455
+ selected_one_comb = random.choice(exist_comb_list)
456
+
457
+
458
+ for idx in range(self.input_view_num):
459
+ img_path = find_matching_files(os.path.join(input_image_path, selected_one_comb, 'rgb'), idx)
460
+ albedo_path = img_path.replace('rgb', 'albedo')
461
+ metallic_path = img_path.replace('rgb', 'metallic')
462
+ roughness_path = img_path.replace('rgb', 'roughness')
463
+
464
+ image, alpha = self.load_im(img_path, bkg_color)
465
+ albedo = self.load_albedo(albedo_path, bkg_color, alpha)
466
+ metallic,_ = self.load_mat(metallic_path, bkg_color)
467
+ roughness,_ = self.load_mat(roughness_path, bkg_color)
468
+
469
+ light_num = os.path.basename(img_path).split('_')[1].split('.')[0]
470
+ light_path = os.path.join(self.light_dir, str(int(light_num)+1))
471
+
472
+ specular, diffuse = load_mipmap(light_path)
473
+
474
+ image_list.append(image)
475
+ alpha_list.append(alpha)
476
+ albedo_list.append(albedo)
477
+ metallic_list.append(metallic)
478
+ roughness_list.append(roughness)
479
+ specular_list.append(specular)
480
+ diffuse_list.append(diffuse)
481
+
482
+ images = torch.stack(image_list, dim=0).float()
483
+ alphas = torch.stack(alpha_list, dim=0).float()
484
+ albedo = torch.stack(albedo_list, dim=0).float()
485
+ metallic = torch.stack(metallic_list, dim=0).float()
486
+ roughness = torch.stack(roughness_list, dim=0).float()
487
+
488
+ data = {
489
+ 'input_images': images,
490
+ 'input_alphas': alphas,
491
+ 'input_c2ws': self.c2ws,
492
+ 'input_Ks': self.Ks,
493
+
494
+ 'input_albedos': albedo[:self.input_view_num],
495
+ 'input_metallics': metallic[:self.input_view_num],
496
+ 'input_roughness': roughness[:self.input_view_num],
497
+
498
+ 'specular': specular_list[:self.input_view_num],
499
+ 'diffuse': diffuse_list[:self.input_view_num],
500
+
501
+ 'render_c2ws': self.render_c2ws,
502
+ 'render_Ks': self.render_Ks,
503
+ }
504
+ return data
505
+
506
+
507
+ if __name__ == '__main__':
508
+ dataset = ObjaverseData()
509
+ dataset.new(1)
src/model_mesh.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import gc
7
+ from torchvision.transforms import v2
8
+ from torchvision.utils import make_grid, save_image
9
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
10
+ import pytorch_lightning as pl
11
+ from einops import rearrange, repeat
12
+ from src.utils.camera_util import FOV_to_intrinsics
13
+ from src.utils.material import Material
14
+ from src.utils.train_util import instantiate_from_config
15
+ import nvdiffrast.torch as dr
16
+ from src.utils import render
17
+ from src.utils.mesh import Mesh, compute_tangents
18
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
19
+
20
+ # from pytorch3d.transforms import quaternion_to_matrix, euler_angles_to_matrix
21
+ GLCTX = [None] * torch.cuda.device_count()
22
+
23
+ def initialize_extension(gpu_id):
24
+ global GLCTX
25
+ if GLCTX[gpu_id] is None:
26
+ print(f"Initializing extension module renderutils_plugin on GPU {gpu_id}...")
27
+ torch.cuda.set_device(gpu_id)
28
+ GLCTX[gpu_id] = dr.RasterizeCudaContext()
29
+ return GLCTX[gpu_id]
30
+
31
+ # Regulrarization loss for FlexiCubes
32
+ def sdf_reg_loss_batch(sdf, all_edges):
33
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
34
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
35
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
36
+ sdf_diff = F.binary_cross_entropy_with_logits(
37
+ sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
38
+ F.binary_cross_entropy_with_logits(
39
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
40
+ return sdf_diff
41
+
42
+ def rotate_x(a, device=None):
43
+ s, c = np.sin(a), np.cos(a)
44
+ return torch.tensor([[1, 0, 0, 0],
45
+ [0, c,-s, 0],
46
+ [0, s, c, 0],
47
+ [0, 0, 0, 1]], dtype=torch.float32, device=device)
48
+
49
+
50
+ def convert_to_white_bg(image, write_bg=True):
51
+ alpha = image[:, :, 3:]
52
+ if write_bg:
53
+ return image[:, :, :3] * alpha + 1. * (1 - alpha)
54
+ else:
55
+ return image[:, :, :3] * alpha
56
+
57
+
58
+ class MVRecon(pl.LightningModule):
59
+ def __init__(
60
+ self,
61
+ lrm_generator_config,
62
+ input_size=256,
63
+ render_size=512,
64
+ init_ckpt=None,
65
+ use_tv_loss=True,
66
+ mesh_save_root="Objaverse_highQuality",
67
+ sample_points=None,
68
+ use_gt_albedo=False,
69
+ ):
70
+ super(MVRecon, self).__init__()
71
+
72
+ self.use_gt_albedo = use_gt_albedo
73
+ self.use_tv_loss = use_tv_loss
74
+ self.input_size = input_size
75
+ self.render_size = render_size
76
+ self.mesh_save_root = mesh_save_root
77
+ self.sample_points = sample_points
78
+
79
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
80
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
81
+
82
+ if init_ckpt is not None:
83
+ sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
84
+ sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
85
+ sd_fc = {}
86
+ for k, v in sd.items():
87
+ if k.startswith('lrm_generator.synthesizer.decoder.net.'):
88
+ if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
89
+ # Here we assume the density filed's isosurface threshold is t,
90
+ # we reverse the sign of density filed to initialize SDF field.
91
+ # -(w*x + b - t) = (-w)*x + (t - b)
92
+ if 'weight' in k:
93
+ sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
94
+ else:
95
+ sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
96
+ sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
97
+ else:
98
+ sd_fc[k.replace('net.', 'net_sdf.')] = v
99
+ sd_fc[k.replace('net.', 'net_rgb.')] = v
100
+ else:
101
+ sd_fc[k] = v
102
+ sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
103
+ # missing `net_deformation` and `net_weight` parameters
104
+ self.lrm_generator.load_state_dict(sd_fc, strict=False)
105
+ print(f'Loaded weights from {init_ckpt}')
106
+
107
+ self.validation_step_outputs = []
108
+
109
+ def on_fit_start(self):
110
+ device = torch.device(f'cuda:{self.local_rank}')
111
+ self.lrm_generator.init_flexicubes_geometry(device)
112
+ if self.global_rank == 0:
113
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
114
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
115
+
116
+ def collate_fn(self, batch):
117
+ gpu_id = torch.cuda.current_device() # 获取当前线程的 GPU ID
118
+ glctx = initialize_extension(gpu_id)
119
+ batch_size = len(batch)
120
+ input_view_num = batch[0]["input_view_num"]
121
+ target_view_num = batch[0]["target_view_num"]
122
+ iter_res = [512, 512]
123
+ iter_spp = 1
124
+ layers = 1
125
+
126
+ # Initialize lists for input and target data
127
+ input_images, input_alphas, input_depths, input_normals, input_albedos = [], [], [], [], []
128
+ input_spec_light, input_diff_light, input_spec_albedo,input_diff_albedo = [], [], [], []
129
+ input_w2cs, input_Ks, input_camera_pos, input_c2ws = [], [], [], []
130
+ input_env, input_materials = [], []
131
+ input_camera_embeddings = [] # camera_embedding_list
132
+
133
+ target_images, target_alphas, target_depths, target_normals, target_albedos = [], [], [], [], []
134
+ target_spec_light, target_diff_light, target_spec_albedo, target_diff_albedo = [], [], [], []
135
+ target_w2cs, target_Ks, target_camera_pos = [], [], []
136
+ target_env, target_materials = [], []
137
+
138
+ for sample in batch:
139
+ obj_path = sample['obj_path']
140
+
141
+ with torch.no_grad():
142
+ mesh_attributes = sample['mesh_attributes']
143
+ v_pos = mesh_attributes["v_pos"].to(self.device)
144
+ v_nrm = mesh_attributes["v_nrm"].to(self.device)
145
+ v_tex = mesh_attributes["v_tex"].to(self.device)
146
+ v_tng = mesh_attributes["v_tng"].to(self.device)
147
+ t_pos_idx = mesh_attributes["t_pos_idx"].to(self.device)
148
+ t_nrm_idx = mesh_attributes["t_nrm_idx"].to(self.device)
149
+ t_tex_idx = mesh_attributes["t_tex_idx"].to(self.device)
150
+ t_tng_idx = mesh_attributes["t_tng_idx"].to(self.device)
151
+ material = Material(mesh_attributes["mat_dict"])
152
+ material = material.to(self.device)
153
+ ref_mesh = Mesh(v_pos=v_pos, v_nrm=v_nrm, v_tex=v_tex, v_tng=v_tng,
154
+ t_pos_idx=t_pos_idx, t_nrm_idx=t_nrm_idx,
155
+ t_tex_idx=t_tex_idx, t_tng_idx=t_tng_idx, material=material)
156
+
157
+ pose_list_sample = sample['pose_list'] # mvp
158
+ camera_pos_sample = sample['camera_pos'] # campos, mv.inverse
159
+ c2w_list_sample = sample['c2w_list'] # mv
160
+ env_list_sample = sample['env_list']
161
+ material_list_sample = sample['material_list']
162
+ camera_embeddings = sample["camera_embedding_list"]
163
+ fov_deg = sample['fov_deg']
164
+ raduis = sample['raduis']
165
+ # print(f"fov_deg:{fov_deg}, raduis:{raduis}")
166
+
167
+ sample_input_images, sample_input_alphas, sample_input_depths, sample_input_normals, sample_input_albedos = [], [], [], [], []
168
+ sample_input_w2cs, sample_input_Ks, sample_input_camera_pos, sample_input_c2ws = [], [], [], []
169
+ sample_input_camera_embeddings = []
170
+ sample_input_spec_light, sample_input_diff_light = [], []
171
+
172
+ sample_target_images, sample_target_alphas, sample_target_depths, sample_target_normals, sample_target_albedos = [], [], [], [], []
173
+ sample_target_w2cs, sample_target_Ks, sample_target_camera_pos = [], [], []
174
+ sample_target_spec_light, sample_target_diff_light = [], []
175
+
176
+ sample_input_env = []
177
+ sample_input_materials = []
178
+ sample_target_env = []
179
+ sample_target_materials = []
180
+
181
+ for i in range(len(pose_list_sample)):
182
+ mvp = pose_list_sample[i]
183
+ campos = camera_pos_sample[i]
184
+ env = env_list_sample[i]
185
+ materials = material_list_sample[i]
186
+ camera_embedding = camera_embeddings[i]
187
+
188
+ with torch.no_grad():
189
+ buffer_dict = render.render_mesh(glctx, ref_mesh, mvp.to(self.device), campos.to(self.device), [env], None, None,
190
+ materials, iter_res, spp=iter_spp, num_layers=layers, msaa=True,
191
+ background=None, gt_render=True)
192
+
193
+ image = convert_to_white_bg(buffer_dict['shaded'][0])
194
+ albedo = convert_to_white_bg(buffer_dict['albedo'][0]).clamp(0., 1.)
195
+ alpha = buffer_dict['mask'][0][:, :, 3:]
196
+ depth = convert_to_white_bg(buffer_dict['depth'][0])
197
+ normal = convert_to_white_bg(buffer_dict['gb_normal'][0], write_bg=False)
198
+ spec_light = convert_to_white_bg(buffer_dict['spec_light'][0])
199
+ diff_light = convert_to_white_bg(buffer_dict['diff_light'][0])
200
+ if i < input_view_num:
201
+ sample_input_images.append(image)
202
+ sample_input_albedos.append(albedo)
203
+ sample_input_alphas.append(alpha)
204
+ sample_input_depths.append(depth)
205
+ sample_input_normals.append(normal)
206
+ sample_input_spec_light.append(spec_light)
207
+ sample_input_diff_light.append(diff_light)
208
+ sample_input_w2cs.append(mvp)
209
+ sample_input_camera_pos.append(campos)
210
+ sample_input_c2ws.append(c2w_list_sample[i])
211
+ sample_input_Ks.append(FOV_to_intrinsics(fov_deg))
212
+ sample_input_env.append(env)
213
+ sample_input_materials.append(materials)
214
+ sample_input_camera_embeddings.append(camera_embedding)
215
+ else:
216
+ sample_target_images.append(image)
217
+ sample_target_albedos.append(albedo)
218
+ sample_target_alphas.append(alpha)
219
+ sample_target_depths.append(depth)
220
+ sample_target_normals.append(normal)
221
+ sample_target_spec_light.append(spec_light)
222
+ sample_target_diff_light.append(diff_light)
223
+ sample_target_w2cs.append(mvp)
224
+ sample_target_camera_pos.append(campos)
225
+ sample_target_Ks.append(FOV_to_intrinsics(fov_deg))
226
+ sample_target_env.append(env)
227
+ sample_target_materials.append(materials)
228
+
229
+ input_images.append(torch.stack(sample_input_images, dim=0).permute(0, 3, 1, 2))
230
+ input_albedos.append(torch.stack(sample_input_albedos, dim=0).permute(0, 3, 1, 2))
231
+ input_alphas.append(torch.stack(sample_input_alphas, dim=0).permute(0, 3, 1, 2))
232
+ input_depths.append(torch.stack(sample_input_depths, dim=0).permute(0, 3, 1, 2))
233
+ input_normals.append(torch.stack(sample_input_normals, dim=0).permute(0, 3, 1, 2))
234
+ input_spec_light.append(torch.stack(sample_input_spec_light, dim=0).permute(0, 3, 1, 2))
235
+ input_diff_light.append(torch.stack(sample_input_diff_light, dim=0).permute(0, 3, 1, 2))
236
+ input_w2cs.append(torch.stack(sample_input_w2cs, dim=0))
237
+ input_camera_pos.append(torch.stack(sample_input_camera_pos, dim=0))
238
+ input_c2ws.append(torch.stack(sample_input_c2ws, dim=0))
239
+ input_camera_embeddings.append(torch.stack(sample_input_camera_embeddings, dim=0))
240
+ input_Ks.append(torch.stack(sample_input_Ks, dim=0))
241
+ input_env.append(sample_input_env)
242
+ input_materials.append(sample_input_materials)
243
+
244
+ target_images.append(torch.stack(sample_target_images, dim=0).permute(0, 3, 1, 2))
245
+ target_albedos.append(torch.stack(sample_target_albedos, dim=0).permute(0, 3, 1, 2))
246
+ target_alphas.append(torch.stack(sample_target_alphas, dim=0).permute(0, 3, 1, 2))
247
+ target_depths.append(torch.stack(sample_target_depths, dim=0).permute(0, 3, 1, 2))
248
+ target_normals.append(torch.stack(sample_target_normals, dim=0).permute(0, 3, 1, 2))
249
+ target_spec_light.append(torch.stack(sample_target_spec_light, dim=0).permute(0, 3, 1, 2))
250
+ target_diff_light.append(torch.stack(sample_target_diff_light, dim=0).permute(0, 3, 1, 2))
251
+ target_w2cs.append(torch.stack(sample_target_w2cs, dim=0))
252
+ target_camera_pos.append(torch.stack(sample_target_camera_pos, dim=0))
253
+ target_Ks.append(torch.stack(sample_target_Ks, dim=0))
254
+ target_env.append(sample_target_env)
255
+ target_materials.append(sample_target_materials)
256
+
257
+ del ref_mesh
258
+ del material
259
+ del mesh_attributes
260
+ torch.cuda.empty_cache()
261
+ gc.collect()
262
+
263
+ data = {
264
+ 'input_images': torch.stack(input_images, dim=0).detach().cpu(), # (batch_size, input_view_num, 3, H, W)
265
+ 'input_alphas': torch.stack(input_alphas, dim=0).detach().cpu(), # (batch_size, input_view_num, 1, H, W)
266
+ 'input_depths': torch.stack(input_depths, dim=0).detach().cpu(),
267
+ 'input_normals': torch.stack(input_normals, dim=0).detach().cpu(),
268
+ 'input_albedos': torch.stack(input_albedos, dim=0).detach().cpu(),
269
+ 'input_spec_light': torch.stack(input_spec_light, dim=0).detach().cpu(),
270
+ 'input_diff_light': torch.stack(input_diff_light, dim=0).detach().cpu(),
271
+ 'input_materials': input_materials,
272
+ 'input_w2cs': torch.stack(input_w2cs, dim=0).squeeze(2), # (batch_size, input_view_num, 4, 4)
273
+ 'input_Ks': torch.stack(input_Ks, dim=0).float(), # (batch_size, input_view_num, 3, 3)
274
+ 'input_env': input_env,
275
+ 'input_camera_pos': torch.stack(input_camera_pos, dim=0).squeeze(2), # (batch_size, input_view_num, 3)
276
+ 'input_c2ws': torch.stack(input_c2ws, dim=0).squeeze(2), # (batch_size, input_view_num, 4, 4)
277
+ 'input_camera_embedding': torch.stack(input_camera_embeddings, dim=0).squeeze(2),
278
+
279
+ 'target_sample_points': None,
280
+ 'target_images': torch.stack(target_images, dim=0).detach().cpu(), # (batch_size, target_view_num, 3, H, W)
281
+ 'target_alphas': torch.stack(target_alphas, dim=0).detach().cpu(), # (batch_size, target_view_num, 1, H, W)
282
+ 'target_depths': torch.stack(target_depths, dim=0).detach().cpu(),
283
+ 'target_normals': torch.stack(target_normals, dim=0).detach().cpu(),
284
+ 'target_albedos': torch.stack(target_albedos, dim=0).detach().cpu(),
285
+ 'target_spec_light': torch.stack(target_spec_light, dim=0).detach().cpu(),
286
+ 'target_diff_light': torch.stack(target_diff_light, dim=0).detach().cpu(),
287
+ 'target_materials': target_materials,
288
+ 'target_w2cs': torch.stack(target_w2cs, dim=0).squeeze(2), # (batch_size, target_view_num, 4, 4)
289
+ 'target_Ks': torch.stack(target_Ks, dim=0).float(), # (batch_size, target_view_num, 3, 3)
290
+ 'target_env': target_env,
291
+ 'target_camera_pos': torch.stack(target_camera_pos, dim=0).squeeze(2) # (batch_size, target_view_num, 3)
292
+ }
293
+
294
+ return data
295
+
296
+ def prepare_batch_data(self, batch):
297
+ # breakpoint()
298
+ lrm_generator_input = {}
299
+ render_gt = {}
300
+
301
+ # input images
302
+ images = batch['input_images']
303
+ images = v2.functional.resize(images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
304
+ batch_size = images.shape[0]
305
+ # breakpoint()
306
+ lrm_generator_input['images'] = images.to(self.device)
307
+
308
+ # input cameras and render cameras
309
+ # input_c2ws = batch['input_c2ws']
310
+ input_Ks = batch['input_Ks']
311
+ # target_c2ws = batch['target_c2ws']
312
+ input_camera_embedding = batch["input_camera_embedding"].to(self.device)
313
+
314
+ input_w2cs = batch['input_w2cs']
315
+ target_w2cs = batch['target_w2cs']
316
+ render_w2cs = torch.cat([input_w2cs, target_w2cs], dim=1)
317
+
318
+ input_camera_pos = batch['input_camera_pos']
319
+ target_camera_pos = batch['target_camera_pos']
320
+ render_camera_pos = torch.cat([input_camera_pos, target_camera_pos], dim=1)
321
+
322
+ input_extrinsics = input_camera_embedding.flatten(-2)
323
+ input_extrinsics = input_extrinsics[:, :, :12]
324
+ input_intrinsics = input_Ks.flatten(-2).to(self.device)
325
+ input_intrinsics = torch.stack([
326
+ input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
327
+ input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
328
+ ], dim=-1)
329
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
330
+
331
+ # add noise to input_cameras
332
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
333
+
334
+ lrm_generator_input['cameras'] = cameras.to(self.device)
335
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
336
+ lrm_generator_input['cameras_pos'] = render_camera_pos.to(self.device)
337
+ lrm_generator_input['env'] = []
338
+ lrm_generator_input['materials'] = []
339
+ for i in range(batch_size):
340
+ lrm_generator_input['env'].append( batch['input_env'][i] + batch['target_env'][i])
341
+ lrm_generator_input['materials'].append( batch['input_materials'][i] + batch['target_materials'][i])
342
+ lrm_generator_input['albedo'] = torch.cat([batch['input_albedos'],batch['target_albedos']],dim=1)
343
+
344
+ # target images
345
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
346
+ target_albedos = torch.cat([batch['input_albedos'], batch['target_albedos']], dim=1)
347
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
348
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
349
+ target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
350
+ target_spec_lights = torch.cat([batch['input_spec_light'], batch['target_spec_light']], dim=1)
351
+ target_diff_lights = torch.cat([batch['input_diff_light'], batch['target_diff_light']], dim=1)
352
+
353
+ render_size = self.render_size
354
+ target_images = v2.functional.resize(
355
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
356
+ target_depths = v2.functional.resize(
357
+ target_depths, render_size, interpolation=0, antialias=True)
358
+ target_alphas = v2.functional.resize(
359
+ target_alphas, render_size, interpolation=0, antialias=True)
360
+ target_normals = v2.functional.resize(
361
+ target_normals, render_size, interpolation=3, antialias=True)
362
+
363
+ lrm_generator_input['render_size'] = render_size
364
+
365
+ render_gt['target_sample_points'] = batch['target_sample_points']
366
+ render_gt['target_images'] = target_images.to(self.device)
367
+ render_gt['target_albedos'] = target_albedos.to(self.device)
368
+ render_gt['target_depths'] = target_depths.to(self.device)
369
+ render_gt['target_alphas'] = target_alphas.to(self.device)
370
+ render_gt['target_normals'] = target_normals.to(self.device)
371
+ render_gt['target_spec_lights'] = target_spec_lights.to(self.device)
372
+ render_gt['target_diff_lights'] = target_diff_lights.to(self.device)
373
+ # render_gt['target_spec_albedos'] = target_spec_albedos.to(self.device)
374
+ # render_gt['target_diff_albedos'] = target_diff_albedos.to(self.device)
375
+ return lrm_generator_input, render_gt
376
+
377
+ def prepare_validation_batch_data(self, batch):
378
+ lrm_generator_input = {}
379
+
380
+ # input images
381
+ images = batch['input_images']
382
+ images = v2.functional.resize(
383
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
384
+
385
+ lrm_generator_input['images'] = images.to(self.device)
386
+ lrm_generator_input['specular_light'] = batch['specular']
387
+ lrm_generator_input['diffuse_light'] = batch['diffuse']
388
+
389
+ lrm_generator_input['metallic'] = batch['input_metallics']
390
+ lrm_generator_input['roughness'] = batch['input_roughness']
391
+
392
+ proj = self.perspective(0.449, 1, 0.1, 1000., self.device)
393
+
394
+ # input cameras
395
+ input_c2ws = batch['input_c2ws'].flatten(-2)
396
+ input_Ks = batch['input_Ks'].flatten(-2)
397
+
398
+ input_extrinsics = input_c2ws[:, :, :12]
399
+ input_intrinsics = torch.stack([
400
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
401
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
402
+ ], dim=-1)
403
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
404
+
405
+ lrm_generator_input['cameras'] = cameras.to(self.device)
406
+
407
+ # render cameras
408
+ render_c2ws = batch['render_c2ws']
409
+
410
+ lrm_generator_input['camera_pos'] = torch.linalg.inv(render_w2cs.to(self.device) @ rotate_x(np.pi / 2, self.device))[..., :3, 3]
411
+ render_w2cs = ( render_w2cs @ rotate_x(np.pi / 2) )
412
+
413
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
414
+ lrm_generator_input['render_size'] = 384
415
+
416
+ return lrm_generator_input
417
+
418
+ def forward_lrm_generator(self, images, cameras, camera_pos,env, materials, albedo_map, render_cameras, render_size=512, sample_points=None, gt_albedo_map=None):
419
+ planes = torch.utils.checkpoint.checkpoint(
420
+ self.lrm_generator.forward_planes,
421
+ images,
422
+ cameras,
423
+ use_reentrant=False,
424
+ )
425
+ out = self.lrm_generator.forward_geometry(
426
+ planes,
427
+ render_cameras,
428
+ camera_pos,
429
+ env,
430
+ materials,
431
+ albedo_map,
432
+ render_size,
433
+ sample_points,
434
+ gt_albedo_map
435
+ )
436
+ return out
437
+
438
+ def forward(self, lrm_generator_input, gt_albedo_map=None):
439
+ images = lrm_generator_input['images']
440
+ cameras = lrm_generator_input['cameras']
441
+ render_cameras = lrm_generator_input['render_cameras']
442
+ render_size = lrm_generator_input['render_size']
443
+ env = lrm_generator_input['env']
444
+ materials = lrm_generator_input['materials']
445
+ albedo_map = lrm_generator_input['albedo']
446
+ camera_pos = lrm_generator_input['cameras_pos']
447
+
448
+ out = self.forward_lrm_generator(
449
+ images, cameras, camera_pos, env, materials, albedo_map, render_cameras, render_size=render_size, sample_points=self.sample_points, gt_albedo_map=gt_albedo_map)
450
+
451
+ return out
452
+
453
+ def training_step(self, batch, batch_idx):
454
+ batch = self.collate_fn(batch)
455
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
456
+ if self.use_gt_albedo:
457
+ gt_albedo_map = render_gt['target_albedos']
458
+ else:
459
+ gt_albedo_map = None
460
+ render_out = self.forward(lrm_generator_input, gt_albedo_map=gt_albedo_map)
461
+
462
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
463
+
464
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, batch_size=len(batch['input_images']), sync_dist=True)
465
+
466
+ if self.global_step % 20 == 0 and self.global_rank == 0 :
467
+ B, N, C, H, W = render_gt['target_images'].shape
468
+ N_in = lrm_generator_input['images'].shape[1]
469
+
470
+ target_images = rearrange(render_gt['target_images'], 'b n c h w -> b c h (n w)')
471
+ render_images = rearrange(render_out['pbr_img'], 'b n c h w -> b c h (n w)')
472
+ target_alphas = rearrange(repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
473
+ target_spec_light = rearrange(render_gt['target_spec_lights'], 'b n c h w -> b c h (n w)')
474
+ target_diff_light = rearrange(render_gt['target_diff_lights'], 'b n c h w -> b c h (n w)')
475
+
476
+ render_alphas = rearrange(render_out['mask'], 'b n c h w -> b c h (n w)')
477
+ render_albodos = rearrange(render_out['albedo'], 'b n c h w -> b c h (n w)')
478
+ target_albedos = rearrange(render_gt['target_albedos'], 'b n c h w -> b c h (n w)')
479
+
480
+ render_spec_light = rearrange(render_out['pbr_spec_light'], 'b n c h w -> b c h (n w)')
481
+ render_diffuse_light = rearrange(render_out['pbr_diffuse_light'], 'b n c h w -> b c h (n w)')
482
+ render_normal = rearrange(render_out['normal_img'], 'b n c h w -> b c h (n w)')
483
+ target_depths = rearrange(render_gt['target_depths'], 'b n c h w -> b c h (n w)')
484
+ render_depths = rearrange(render_out['depth'], 'b n c h w -> b c h (n w)')
485
+ target_normals = rearrange(render_gt['target_normals'], 'b n c h w -> b c h (n w)')
486
+
487
+ MAX_DEPTH = torch.max(target_depths)
488
+ target_depths = target_depths / MAX_DEPTH * target_alphas
489
+ render_depths = render_depths / MAX_DEPTH * render_alphas
490
+
491
+ grid = torch.cat([
492
+ target_images, render_images,
493
+ target_alphas, render_alphas,
494
+ target_albedos, render_albodos,
495
+ target_spec_light, render_spec_light,
496
+ target_diff_light, render_diffuse_light,
497
+ (target_normals+1)/2, (render_normal+1)/2,
498
+ target_depths, render_depths
499
+ ], dim=-2).detach().cpu()
500
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
501
+
502
+ image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
503
+ save_image(grid, image_path)
504
+ print(f"Saved image to {image_path}")
505
+ return loss
506
+
507
+
508
+ def total_variation_loss(self, img, beta=2.0):
509
+ bs_img, n_view, c_img, h_img, w_img = img.size()
510
+ tv_h = torch.pow(img[...,1:,:]-img[...,:-1,:], beta).sum()
511
+ tv_w = torch.pow(img[...,:,1:]-img[...,:,:-1], beta).sum()
512
+ return (tv_h+tv_w)/(bs_img*n_view*c_img*h_img*w_img)
513
+
514
+
515
+ def compute_loss(self, render_out, render_gt):
516
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
517
+ render_albedo_image = render_out['albedo']
518
+ render_pbr_image = render_out['pbr_img']
519
+ render_spec_light = render_out['pbr_spec_light']
520
+ render_diff_light = render_out['pbr_diffuse_light']
521
+
522
+ target_images = render_gt['target_images'].to(render_albedo_image)
523
+ target_albedos = render_gt['target_albedos'].to(render_albedo_image)
524
+ target_spec_light = render_gt['target_spec_lights'].to(render_albedo_image)
525
+ target_diff_light = render_gt['target_diff_lights'].to(render_albedo_image)
526
+
527
+ render_images = rearrange(render_pbr_image, 'b n ... -> (b n) ...') * 2.0 - 1.0
528
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
529
+
530
+ render_albedos = rearrange(render_albedo_image, 'b n ... -> (b n) ...') * 2.0 - 1.0
531
+ target_albedos = rearrange(target_albedos, 'b n ... -> (b n) ...') * 2.0 - 1.0
532
+
533
+ render_spec_light = rearrange(render_spec_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
534
+ target_spec_light = rearrange(target_spec_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
535
+
536
+ render_diff_light = rearrange(render_diff_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
537
+ target_diff_light = rearrange(target_diff_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
538
+
539
+
540
+ loss_mse = F.mse_loss(render_images, target_images)
541
+ loss_mse_albedo = F.mse_loss(render_albedos, target_albedos)
542
+ loss_rgb_lpips = 2.0 * self.lpips(render_images, target_images)
543
+ loss_albedo_lpips = 2.0 * self.lpips(render_albedos, target_albedos)
544
+
545
+ loss_spec_light = F.mse_loss(render_spec_light, target_spec_light)
546
+ loss_diff_light = F.mse_loss(render_diff_light, target_diff_light)
547
+ loss_spec_light_lpips = 2.0 * self.lpips(render_spec_light.clamp(-1., 1.), target_spec_light.clamp(-1., 1.))
548
+ loss_diff_light_lpips = 2.0 * self.lpips(render_diff_light.clamp(-1., 1.), target_diff_light.clamp(-1., 1.))
549
+
550
+ render_alphas = render_out['mask'][:,:,:1,:,:]
551
+ target_alphas = render_gt['target_alphas']
552
+
553
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
554
+ render_depths = torch.mean(render_out['depth'], dim=2, keepdim=True)
555
+ target_depths = torch.mean(render_gt['target_depths'], dim=2, keepdim=True)
556
+ loss_depth = 0.5 * F.l1_loss(render_depths[(target_alphas>0)], target_depths[target_alphas>0])
557
+
558
+ render_normals = render_out['normal'][...,:3].permute(0,3,1,2).unsqueeze(0)
559
+ target_normals = render_gt['target_normals']
560
+ similarity = (render_normals * target_normals).sum(dim=-3).abs()
561
+ normal_mask = target_alphas.squeeze(-3)
562
+ loss_normal = 1 - similarity[normal_mask>0].mean()
563
+ loss_normal = 0.2 * loss_normal * 1.0
564
+
565
+ # tv loss
566
+ if self.use_tv_loss:
567
+ triplane = render_out['triplane']
568
+ tv_loss = self.total_variation_loss(triplane, beta=2.0)
569
+
570
+ # flexicubes regularization loss
571
+ sdf = render_out['sdf']
572
+ sdf_reg_loss = render_out['sdf_reg_loss']
573
+ sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
574
+ _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
575
+ flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
576
+ flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
577
+
578
+ loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
579
+ loss_reg = loss_reg
580
+ loss = loss_mse + loss_rgb_lpips + loss_albedo_lpips + loss_mask + loss_reg + loss_mse_albedo + loss_depth + \
581
+ loss_normal + loss_spec_light + loss_diff_light + loss_spec_light_lpips + loss_diff_light_lpips
582
+ if self.use_tv_loss:
583
+ loss += tv_loss * 2e-4
584
+
585
+ prefix = 'train'
586
+ loss_dict = {}
587
+
588
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse.item()})
589
+ loss_dict.update({f'{prefix}/loss_mse_albedo': loss_mse_albedo.item()})
590
+ loss_dict.update({f'{prefix}/loss_rgb_lpips': loss_rgb_lpips.item()})
591
+ loss_dict.update({f'{prefix}/loss_albedo_lpips': loss_albedo_lpips.item()})
592
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask.item()})
593
+ loss_dict.update({f'{prefix}/loss_normal': loss_normal.item()})
594
+ loss_dict.update({f'{prefix}/loss_depth': loss_depth.item()})
595
+ loss_dict.update({f'{prefix}/loss_spec_light': loss_spec_light.item()})
596
+ loss_dict.update({f'{prefix}/loss_diff_light': loss_diff_light.item()})
597
+ loss_dict.update({f'{prefix}/loss_spec_light_lpips': loss_spec_light_lpips.item()})
598
+ loss_dict.update({f'{prefix}/loss_diff_light_lpips': loss_diff_light_lpips.item()})
599
+ loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy.item()})
600
+ loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg.item()})
601
+ loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg.item()})
602
+ if self.use_tv_loss:
603
+ loss_dict.update({f'{prefix}/loss_tv': tv_loss.item()})
604
+ loss_dict.update({f'{prefix}/loss': loss.item()})
605
+
606
+ return loss, loss_dict
607
+
608
+ @torch.no_grad()
609
+ def validation_step(self, batch, batch_idx):
610
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
611
+
612
+ render_out = self.forward(lrm_generator_input)
613
+ render_images = rearrange(render_out['pbr_img'], 'b n c h w -> b c h (n w)')
614
+ render_albodos = rearrange(render_out['img'], 'b n c h w -> b c h (n w)')
615
+
616
+ self.validation_step_outputs.append(render_images)
617
+ self.validation_step_outputs.append(render_albodos)
618
+
619
+ def on_validation_epoch_end(self):
620
+ images = torch.cat(self.validation_step_outputs, dim=0)
621
+
622
+ all_images = self.all_gather(images)
623
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
624
+
625
+ if self.global_rank == 0:
626
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
627
+
628
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
629
+
630
+ save_image(grid, image_path)
631
+ print(f"Saved image to {image_path}")
632
+
633
+ self.validation_step_outputs.clear()
634
+
635
+ def configure_optimizers(self):
636
+ lr = self.learning_rate
637
+
638
+ optimizer = torch.optim.AdamW(
639
+ self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
640
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
641
+
642
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/models/__init__.py ADDED
File without changes
src/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
src/models/__pycache__/lrm_mesh.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
src/models/decoder/__init__.py ADDED
File without changes
src/models/decoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes). View file
 
src/models/decoder/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (3.45 kB). View file
 
src/models/decoder/transformer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class BasicTransformerBlock(nn.Module):
21
+ """
22
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
23
+ """
24
+ # use attention from torch.nn.MultiHeadAttention
25
+ # Block contains a cross-attention layer, a self-attention layer, and a MLP
26
+ def __init__(
27
+ self,
28
+ inner_dim: int,
29
+ cond_dim: int,
30
+ num_heads: int,
31
+ eps: float,
32
+ attn_drop: float = 0.,
33
+ attn_bias: bool = False,
34
+ mlp_ratio: float = 4.,
35
+ mlp_drop: float = 0.,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.norm1 = nn.LayerNorm(inner_dim)
40
+ self.cross_attn = nn.MultiheadAttention(
41
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
42
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
43
+ self.norm2 = nn.LayerNorm(inner_dim)
44
+ self.self_attn = nn.MultiheadAttention(
45
+ embed_dim=inner_dim, num_heads=num_heads,
46
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
47
+ self.norm3 = nn.LayerNorm(inner_dim)
48
+ self.mlp = nn.Sequential(
49
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
50
+ nn.GELU(),
51
+ nn.Dropout(mlp_drop),
52
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
53
+ nn.Dropout(mlp_drop),
54
+ )
55
+
56
+ def forward(self, x, cond):
57
+ # x: [N, L, D]
58
+ # cond: [N, L_cond, D_cond]
59
+ x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
60
+ before_sa = self.norm2(x)
61
+ x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
62
+ x = x + self.mlp(self.norm3(x))
63
+ return x
64
+
65
+
66
+ class TriplaneTransformer(nn.Module):
67
+ """
68
+ Transformer with condition that generates a triplane representation.
69
+
70
+ Reference:
71
+ Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
72
+ """
73
+ def __init__(
74
+ self,
75
+ inner_dim: int,
76
+ image_feat_dim: int,
77
+ triplane_low_res: int,
78
+ triplane_high_res: int,
79
+ triplane_dim: int,
80
+ num_layers: int,
81
+ num_heads: int,
82
+ eps: float = 1e-6,
83
+ ):
84
+ super().__init__()
85
+
86
+ # attributes
87
+ self.triplane_low_res = triplane_low_res
88
+ self.triplane_high_res = triplane_high_res
89
+ self.triplane_dim = triplane_dim
90
+
91
+ # modules
92
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
93
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
94
+ self.layers = nn.ModuleList([
95
+ BasicTransformerBlock(
96
+ inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
97
+ for _ in range(num_layers)
98
+ ])
99
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
100
+ self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
101
+
102
+ def forward(self, image_feats):
103
+ # image_feats: [N, L_cond, D_cond]
104
+
105
+ N = image_feats.shape[0]
106
+ H = W = self.triplane_low_res
107
+ L = 3 * H * W
108
+
109
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
110
+ for layer in self.layers:
111
+ x = layer(x, image_feats)
112
+ x = self.norm(x)
113
+
114
+ # separate each plane and apply deconv
115
+ x = x.view(N, 3, H, W, -1)
116
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
117
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
118
+ x = self.deconv(x) # [3*N, D', H', W']
119
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
120
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
121
+ x = x.contiguous()
122
+
123
+ return x
src/models/encoder/__init__.py ADDED
File without changes
src/models/encoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes). View file
 
src/models/encoder/__pycache__/dino.cpython-310.pyc ADDED
Binary file (17.2 kB). View file
 
src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc ADDED
Binary file (2.54 kB). View file
 
src/models/encoder/dino.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch ViT model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPooling,
29
+ )
30
+ from transformers import PreTrainedModel, ViTConfig
31
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
+
33
+
34
+ class ViTEmbeddings(nn.Module):
35
+ """
36
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
37
+ """
38
+
39
+ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
40
+ super().__init__()
41
+
42
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
43
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
44
+ self.patch_embeddings = ViTPatchEmbeddings(config)
45
+ num_patches = self.patch_embeddings.num_patches
46
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
47
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
48
+ self.config = config
49
+
50
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
51
+ """
52
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
53
+ resolution images.
54
+
55
+ Source:
56
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
57
+ """
58
+
59
+ num_patches = embeddings.shape[1] - 1
60
+ num_positions = self.position_embeddings.shape[1] - 1
61
+ if num_patches == num_positions and height == width:
62
+ return self.position_embeddings
63
+ class_pos_embed = self.position_embeddings[:, 0]
64
+ patch_pos_embed = self.position_embeddings[:, 1:]
65
+ dim = embeddings.shape[-1]
66
+ h0 = height // self.config.patch_size
67
+ w0 = width // self.config.patch_size
68
+ # we add a small number to avoid floating point error in the interpolation
69
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
70
+ h0, w0 = h0 + 0.1, w0 + 0.1
71
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
72
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
73
+ patch_pos_embed = nn.functional.interpolate(
74
+ patch_pos_embed,
75
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
76
+ mode="bicubic",
77
+ align_corners=False,
78
+ )
79
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
80
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
+
83
+ def forward(
84
+ self,
85
+ pixel_values: torch.Tensor,
86
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
87
+ interpolate_pos_encoding: bool = False,
88
+ ) -> torch.Tensor:
89
+ batch_size, num_channels, height, width = pixel_values.shape
90
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
91
+
92
+ if bool_masked_pos is not None:
93
+ seq_length = embeddings.shape[1]
94
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
95
+ # replace the masked visual tokens by mask_tokens
96
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
97
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
98
+
99
+ # add the [CLS] token to the embedded patch tokens
100
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
101
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
102
+
103
+ # add positional encoding to each token
104
+ if interpolate_pos_encoding:
105
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
106
+ else:
107
+ embeddings = embeddings + self.position_embeddings
108
+
109
+ embeddings = self.dropout(embeddings)
110
+
111
+ return embeddings
112
+
113
+
114
+ class ViTPatchEmbeddings(nn.Module):
115
+ """
116
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
117
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
118
+ Transformer.
119
+ """
120
+
121
+ def __init__(self, config):
122
+ super().__init__()
123
+ image_size, patch_size = config.image_size, config.patch_size
124
+ num_channels, hidden_size = config.num_channels, config.hidden_size
125
+
126
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.num_patches = num_patches
133
+
134
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
135
+
136
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
137
+ batch_size, num_channels, height, width = pixel_values.shape
138
+ if num_channels != self.num_channels:
139
+ raise ValueError(
140
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
+ f" Expected {self.num_channels} but got {num_channels}."
142
+ )
143
+ if not interpolate_pos_encoding:
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model"
147
+ f" ({self.image_size[0]}*{self.image_size[1]})."
148
+ )
149
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
150
+ return embeddings
151
+
152
+
153
+ class ViTSelfAttention(nn.Module):
154
+ def __init__(self, config: ViTConfig) -> None:
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
159
+ f"heads {config.num_attention_heads}."
160
+ )
161
+
162
+ self.num_attention_heads = config.num_attention_heads
163
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
165
+
166
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
167
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
168
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
169
+
170
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
+
172
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
173
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
174
+ x = x.view(new_x_shape)
175
+ return x.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
179
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
180
+ mixed_query_layer = self.query(hidden_states)
181
+
182
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
183
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ # Take the dot product between "query" and "key" to get the raw attention scores.
187
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
188
+
189
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
193
+
194
+ # This is actually dropping out entire tokens to attend to, which might
195
+ # seem a bit unusual, but is taken from the original Transformer paper.
196
+ attention_probs = self.dropout(attention_probs)
197
+
198
+ # Mask heads if we want to
199
+ if head_mask is not None:
200
+ attention_probs = attention_probs * head_mask
201
+
202
+ context_layer = torch.matmul(attention_probs, value_layer)
203
+
204
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
+ context_layer = context_layer.view(new_context_layer_shape)
207
+
208
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
+
210
+ return outputs
211
+
212
+
213
+ class ViTSelfOutput(nn.Module):
214
+ """
215
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
+ layernorm applied before each block.
217
+ """
218
+
219
+ def __init__(self, config: ViTConfig) -> None:
220
+ super().__init__()
221
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+
224
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.dropout(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+
231
+ class ViTAttention(nn.Module):
232
+ def __init__(self, config: ViTConfig) -> None:
233
+ super().__init__()
234
+ self.attention = ViTSelfAttention(config)
235
+ self.output = ViTSelfOutput(config)
236
+ self.pruned_heads = set()
237
+
238
+ def prune_heads(self, heads: Set[int]) -> None:
239
+ if len(heads) == 0:
240
+ return
241
+ heads, index = find_pruneable_heads_and_indices(
242
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
+ )
244
+
245
+ # Prune linear layers
246
+ self.attention.query = prune_linear_layer(self.attention.query, index)
247
+ self.attention.key = prune_linear_layer(self.attention.key, index)
248
+ self.attention.value = prune_linear_layer(self.attention.value, index)
249
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
+
251
+ # Update hyper params and store pruned heads
252
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
+ self.pruned_heads = self.pruned_heads.union(heads)
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ head_mask: Optional[torch.Tensor] = None,
260
+ output_attentions: bool = False,
261
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
+
264
+ attention_output = self.output(self_outputs[0], hidden_states)
265
+
266
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
+ return outputs
268
+
269
+
270
+ class ViTIntermediate(nn.Module):
271
+ def __init__(self, config: ViTConfig) -> None:
272
+ super().__init__()
273
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
+ if isinstance(config.hidden_act, str):
275
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
+ else:
277
+ self.intermediate_act_fn = config.hidden_act
278
+
279
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
+ hidden_states = self.dense(hidden_states)
281
+ hidden_states = self.intermediate_act_fn(hidden_states)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class ViTOutput(nn.Module):
287
+ def __init__(self, config: ViTConfig) -> None:
288
+ super().__init__()
289
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
+
292
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+
296
+ hidden_states = hidden_states + input_tensor
297
+
298
+ return hidden_states
299
+
300
+
301
+ def modulate(x, shift, scale):
302
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
+
304
+
305
+ class ViTLayer(nn.Module):
306
+ """This corresponds to the Block class in the timm implementation."""
307
+
308
+ def __init__(self, config: ViTConfig) -> None:
309
+ super().__init__()
310
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
+ self.seq_len_dim = 1
312
+ self.attention = ViTAttention(config)
313
+ self.intermediate = ViTIntermediate(config)
314
+ self.output = ViTOutput(config)
315
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+
318
+ self.adaLN_modulation = nn.Sequential(
319
+ nn.SiLU(),
320
+ nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
+ )
322
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states: torch.Tensor,
328
+ adaln_input: torch.Tensor = None,
329
+ head_mask: Optional[torch.Tensor] = None,
330
+ output_attentions: bool = False,
331
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
+ shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
+
334
+ self_attention_outputs = self.attention(
335
+ modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
+ head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+ attention_output = self_attention_outputs[0]
340
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
+
342
+ # first residual connection
343
+ hidden_states = attention_output + hidden_states
344
+
345
+ # in ViT, layernorm is also applied after self-attention
346
+ layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
+ layer_output = self.intermediate(layer_output)
348
+
349
+ # second residual connection is done here
350
+ layer_output = self.output(layer_output, hidden_states)
351
+
352
+ outputs = (layer_output,) + outputs
353
+
354
+ return outputs
355
+
356
+
357
+ class ViTEncoder(nn.Module):
358
+ def __init__(self, config: ViTConfig) -> None:
359
+ super().__init__()
360
+ self.config = config
361
+ self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
+ self.gradient_checkpointing = False
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ adaln_input: torch.Tensor = None,
368
+ head_mask: Optional[torch.Tensor] = None,
369
+ output_attentions: bool = False,
370
+ output_hidden_states: bool = False,
371
+ return_dict: bool = True,
372
+ ) -> Union[tuple, BaseModelOutput]:
373
+ all_hidden_states = () if output_hidden_states else None
374
+ all_self_attentions = () if output_attentions else None
375
+
376
+ for i, layer_module in enumerate(self.layer):
377
+ if output_hidden_states:
378
+ all_hidden_states = all_hidden_states + (hidden_states,)
379
+
380
+ layer_head_mask = head_mask[i] if head_mask is not None else None
381
+
382
+ if self.gradient_checkpointing and self.training:
383
+ layer_outputs = self._gradient_checkpointing_func(
384
+ layer_module.__call__,
385
+ hidden_states,
386
+ adaln_input,
387
+ layer_head_mask,
388
+ output_attentions,
389
+ )
390
+ else:
391
+ layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
+
393
+ hidden_states = layer_outputs[0]
394
+
395
+ if output_attentions:
396
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
+
398
+ if output_hidden_states:
399
+ all_hidden_states = all_hidden_states + (hidden_states,)
400
+
401
+ if not return_dict:
402
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
+ return BaseModelOutput(
404
+ last_hidden_state=hidden_states,
405
+ hidden_states=all_hidden_states,
406
+ attentions=all_self_attentions,
407
+ )
408
+
409
+
410
+ class ViTPreTrainedModel(PreTrainedModel):
411
+ """
412
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
+ models.
414
+ """
415
+
416
+ config_class = ViTConfig
417
+ base_model_prefix = "vit"
418
+ main_input_name = "pixel_values"
419
+ supports_gradient_checkpointing = True
420
+ _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
+
422
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
+ """Initialize the weights"""
424
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
425
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
+ # `trunc_normal_cpu` not implemented in `half` issues
427
+ module.weight.data = nn.init.trunc_normal_(
428
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
+ ).to(module.weight.dtype)
430
+ if module.bias is not None:
431
+ module.bias.data.zero_()
432
+ elif isinstance(module, nn.LayerNorm):
433
+ module.bias.data.zero_()
434
+ module.weight.data.fill_(1.0)
435
+ elif isinstance(module, ViTEmbeddings):
436
+ module.position_embeddings.data = nn.init.trunc_normal_(
437
+ module.position_embeddings.data.to(torch.float32),
438
+ mean=0.0,
439
+ std=self.config.initializer_range,
440
+ ).to(module.position_embeddings.dtype)
441
+
442
+ module.cls_token.data = nn.init.trunc_normal_(
443
+ module.cls_token.data.to(torch.float32),
444
+ mean=0.0,
445
+ std=self.config.initializer_range,
446
+ ).to(module.cls_token.dtype)
447
+
448
+
449
+ class ViTModel(ViTPreTrainedModel):
450
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
+ super().__init__(config)
452
+ self.config = config
453
+
454
+ self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
+ self.encoder = ViTEncoder(config)
456
+
457
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
+ self.pooler = ViTPooler(config) if add_pooling_layer else None
459
+
460
+ # Initialize weights and apply final processing
461
+ self.post_init()
462
+
463
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
+ return self.embeddings.patch_embeddings
465
+
466
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
+ """
468
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
+ class PreTrainedModel
470
+ """
471
+ for layer, heads in heads_to_prune.items():
472
+ self.encoder.layer[layer].attention.prune_heads(heads)
473
+
474
+ def forward(
475
+ self,
476
+ pixel_values: Optional[torch.Tensor] = None,
477
+ adaln_input: Optional[torch.Tensor] = None,
478
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
479
+ head_mask: Optional[torch.Tensor] = None,
480
+ output_attentions: Optional[bool] = None,
481
+ output_hidden_states: Optional[bool] = None,
482
+ interpolate_pos_encoding: Optional[bool] = None,
483
+ return_dict: Optional[bool] = None,
484
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
+ r"""
486
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
+ """
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
+
495
+ if pixel_values is None:
496
+ raise ValueError("You have to specify pixel_values")
497
+
498
+ # Prepare head mask if needed
499
+ # 1.0 in head_mask indicate we keep the head
500
+ # attention_probs has shape bsz x n_heads x N x N
501
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
+
505
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
+ if pixel_values.dtype != expected_dtype:
508
+ pixel_values = pixel_values.to(expected_dtype)
509
+
510
+ embedding_output = self.embeddings(
511
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
+ )
513
+
514
+ encoder_outputs = self.encoder(
515
+ embedding_output,
516
+ adaln_input=adaln_input,
517
+ head_mask=head_mask,
518
+ output_attentions=output_attentions,
519
+ output_hidden_states=output_hidden_states,
520
+ return_dict=return_dict,
521
+ )
522
+ sequence_output = encoder_outputs[0]
523
+ sequence_output = self.layernorm(sequence_output)
524
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
+
526
+ if not return_dict:
527
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
+ return head_outputs + encoder_outputs[1:]
529
+
530
+ return BaseModelOutputWithPooling(
531
+ last_hidden_state=sequence_output,
532
+ pooler_output=pooled_output,
533
+ hidden_states=encoder_outputs.hidden_states,
534
+ attentions=encoder_outputs.attentions,
535
+ )
536
+
537
+
538
+ class ViTPooler(nn.Module):
539
+ def __init__(self, config: ViTConfig):
540
+ super().__init__()
541
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
+ self.activation = nn.Tanh()
543
+
544
+ def forward(self, hidden_states):
545
+ # We "pool" the model by simply taking the hidden state corresponding
546
+ # to the first token.
547
+ first_token_tensor = hidden_states[:, 0]
548
+ pooled_output = self.dense(first_token_tensor)
549
+ pooled_output = self.activation(pooled_output)
550
+ return pooled_output
src/models/encoder/dino_wrapper.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch.nn as nn
17
+ from transformers import ViTImageProcessor
18
+ from einops import rearrange, repeat
19
+ from .dino import ViTModel
20
+
21
+
22
+ class DinoWrapper(nn.Module):
23
+ """
24
+ Dino v1 wrapper using huggingface transformer implementation.
25
+ """
26
+ def __init__(self, model_name: str, freeze: bool = True):
27
+ super().__init__()
28
+ self.model, self.processor = self._build_dino(model_name)
29
+ self.camera_embedder = nn.Sequential(
30
+ nn.Linear(16, self.model.config.hidden_size, bias=True),
31
+ nn.SiLU(),
32
+ nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
33
+ )
34
+ if freeze:
35
+ self._freeze()
36
+
37
+ def forward(self, image, camera):
38
+ # image: [B, N, C, H, W]
39
+ # camera: [B, N, D]
40
+ # RGB image with [0,1] scale and properly sized
41
+ if image.ndim == 5:
42
+ image = rearrange(image, 'b n c h w -> (b n) c h w')
43
+ dtype = image.dtype
44
+ inputs = self.processor(
45
+ images=image.float(),
46
+ return_tensors="pt",
47
+ do_rescale=False,
48
+ do_resize=False,
49
+ ).to(self.model.device).to(dtype)
50
+ # embed camera
51
+ N = camera.shape[1]
52
+ camera_embeddings = self.camera_embedder(camera)
53
+ camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
54
+ embeddings = camera_embeddings
55
+ # This resampling of positional embedding uses bicubic interpolation
56
+ outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
57
+ last_hidden_states = outputs.last_hidden_state
58
+ return last_hidden_states
59
+
60
+ def _freeze(self):
61
+ print(f"======== Freezing DinoWrapper ========")
62
+ self.model.eval()
63
+ for name, param in self.model.named_parameters():
64
+ param.requires_grad = False
65
+
66
+ @staticmethod
67
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
68
+ import requests
69
+ try:
70
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
71
+ processor = ViTImageProcessor.from_pretrained(model_name)
72
+ return model, processor
73
+ except requests.exceptions.ProxyError as err:
74
+ if proxy_error_retries > 0:
75
+ print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
76
+ import time
77
+ time.sleep(proxy_error_cooldown)
78
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
79
+ else:
80
+ raise err
src/models/geometry/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
src/models/geometry/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes). View file
 
src/models/geometry/camera/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ class Camera(nn.Module):
14
+ def __init__(self):
15
+ super(Camera, self).__init__()
16
+ pass
src/models/geometry/camera/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (547 Bytes). View file
 
src/models/geometry/camera/__pycache__/perspective_camera.cpython-310.pyc ADDED
Binary file (1.43 kB). View file
 
src/models/geometry/camera/perspective_camera.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from . import Camera
11
+ import numpy as np
12
+
13
+
14
+ def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
15
+ if near_plane is None:
16
+ near_plane = n
17
+ return np.array(
18
+ [[n / x, 0, 0, 0],
19
+ [0, n / -x, 0, 0],
20
+ [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
21
+ [0, 0, -1, 0]]).astype(np.float32)
22
+
23
+
24
+ class PerspectiveCamera(Camera):
25
+ def __init__(self, fovy=49.0, device='cuda'):
26
+ super(PerspectiveCamera, self).__init__()
27
+ self.device = device
28
+ focal = np.tan(fovy / 180.0 * np.pi * 0.5)
29
+ self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
30
+
31
+ def project(self, points_bxnx4):
32
+ out = torch.matmul(
33
+ points_bxnx4,
34
+ torch.transpose(self.proj_mtx, 1, 2))
35
+ return out
src/models/geometry/render/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Renderer():
4
+ def __init__(self):
5
+ pass
6
+
7
+ def forward(self):
8
+ pass
src/models/geometry/render/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (565 Bytes). View file
 
src/models/geometry/render/__pycache__/neural_render.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
src/models/geometry/render/__pycache__/util.cpython-310.pyc ADDED
Binary file (15.1 kB). View file
 
src/models/geometry/render/neural_render.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import nvdiffrast.torch as dr
12
+ from . import Renderer
13
+ from . import util
14
+ from . import renderutils as ru
15
+ _FG_LUT = None
16
+
17
+
18
+ def interpolate(attr, rast, attr_idx, rast_db=None):
19
+ return dr.interpolate(
20
+ attr.contiguous(), rast, attr_idx, rast_db=rast_db,
21
+ diff_attrs=None if rast_db is None else 'all')
22
+
23
+
24
+ def xfm_points(points, matrix, use_python=True):
25
+ '''Transform points.
26
+ Args:
27
+ points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
28
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
29
+ use_python: Use PyTorch's torch.matmul (for validation)
30
+ Returns:
31
+ Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
32
+ '''
33
+ out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
34
+ if torch.is_anomaly_enabled():
35
+ assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
36
+ return out
37
+
38
+
39
+ def dot(x, y):
40
+ return torch.sum(x * y, -1, keepdim=True)
41
+
42
+
43
+ def compute_vertex_normal(v_pos, t_pos_idx):
44
+ i0 = t_pos_idx[:, 0]
45
+ i1 = t_pos_idx[:, 1]
46
+ i2 = t_pos_idx[:, 2]
47
+
48
+ v0 = v_pos[i0, :]
49
+ v1 = v_pos[i1, :]
50
+ v2 = v_pos[i2, :]
51
+
52
+ face_normals = torch.cross(v1 - v0, v2 - v0)
53
+
54
+ # Splat face normals to vertices
55
+ v_nrm = torch.zeros_like(v_pos)
56
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
57
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
58
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
59
+
60
+ # Normalize, replace zero (degenerated) normals with some default value
61
+ v_nrm = torch.where(
62
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
63
+ )
64
+ v_nrm = F.normalize(v_nrm, dim=1)
65
+ assert torch.all(torch.isfinite(v_nrm))
66
+
67
+ return v_nrm
68
+
69
+
70
+ class NeuralRender(Renderer):
71
+ def __init__(self, device='cuda', camera_model=None):
72
+ super(NeuralRender, self).__init__()
73
+ self.device = device
74
+ self.ctx = dr.RasterizeCudaContext(device=device)
75
+ self.projection_mtx = None
76
+ self.camera = camera_model
77
+
78
+ # ==============================================================================================
79
+ # pixel shader
80
+ # ==============================================================================================
81
+ # def shade(
82
+ # self,
83
+ # gb_pos,
84
+ # gb_geometric_normal,
85
+ # gb_normal,
86
+ # gb_tangent,
87
+ # gb_texc,
88
+ # gb_texc_deriv,
89
+ # view_pos,
90
+ # ):
91
+
92
+ # ################################################################################
93
+ # # Texture lookups
94
+ # ################################################################################
95
+ # breakpoint()
96
+ # # Separate kd into alpha and color, default alpha = 1
97
+ # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
98
+ # kd = kd[..., 0:3]
99
+
100
+ # ################################################################################
101
+ # # Normal perturbation & normal bend
102
+ # ################################################################################
103
+
104
+ # perturbed_nrm = None
105
+
106
+ # gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
107
+
108
+ # ################################################################################
109
+ # # Evaluate BSDF
110
+ # ################################################################################
111
+
112
+ # assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
113
+ # bsdf = material['bsdf'] if bsdf is None else bsdf
114
+ # if bsdf == 'pbr':
115
+ # if isinstance(lgt, light.EnvironmentLight):
116
+ # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
117
+ # else:
118
+ # assert False, "Invalid light type"
119
+ # elif bsdf == 'diffuse':
120
+ # if isinstance(lgt, light.EnvironmentLight):
121
+ # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
122
+ # else:
123
+ # assert False, "Invalid light type"
124
+ # elif bsdf == 'normal':
125
+ # shaded_col = (gb_normal + 1.0)*0.5
126
+ # elif bsdf == 'tangent':
127
+ # shaded_col = (gb_tangent + 1.0)*0.5
128
+ # elif bsdf == 'kd':
129
+ # shaded_col = kd
130
+ # elif bsdf == 'ks':
131
+ # shaded_col = ks
132
+ # else:
133
+ # assert False, "Invalid BSDF '%s'" % bsdf
134
+
135
+ # # Return multiple buffers
136
+ # buffers = {
137
+ # 'shaded' : torch.cat((shaded_col, alpha), dim=-1),
138
+ # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
139
+ # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
140
+ # }
141
+ # return buffers
142
+
143
+ # ==============================================================================================
144
+ # Render a depth slice of the mesh (scene), some limitations:
145
+ # - Single mesh
146
+ # - Single light
147
+ # - Single material
148
+ # ==============================================================================================
149
+ def render_layer(
150
+ self,
151
+ rast,
152
+ rast_deriv,
153
+ mesh,
154
+ view_pos,
155
+ resolution,
156
+ spp,
157
+ msaa
158
+ ):
159
+
160
+ # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
161
+ rast_out_s = rast
162
+ rast_out_deriv_s = rast_deriv
163
+
164
+ ################################################################################
165
+ # Interpolate attributes
166
+ ################################################################################
167
+
168
+ # Interpolate world space position
169
+ gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
170
+
171
+ # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
172
+ v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
173
+ v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
174
+ v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
175
+ face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
176
+ face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
177
+ gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
178
+
179
+ # Compute tangent space
180
+ assert mesh.v_nrm is not None and mesh.v_tng is not None
181
+ gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
182
+ gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
183
+
184
+ # Texture coordinate
185
+ # assert mesh.v_tex is not None
186
+ # gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
187
+ perturbed_nrm = None
188
+ gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
189
+
190
+ return gb_pos, gb_normal
191
+
192
+ def render_mesh(
193
+ self,
194
+ mesh_v_pos_bxnx3,
195
+ mesh_t_pos_idx_fx3,
196
+ mesh,
197
+ camera_mv_bx4x4,
198
+ camera_pos,
199
+ mesh_v_feat_bxnxd,
200
+ resolution=256,
201
+ spp=1,
202
+ device='cuda',
203
+ hierarchical_mask=False
204
+ ):
205
+ assert not hierarchical_mask
206
+
207
+ mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
208
+ v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
209
+ v_pos_clip = self.camera.project(v_pos) # Projection in the camera
210
+
211
+ # view_pos = torch.linalg.inv(mtx_in)[:, :3, 3]
212
+ view_pos = camera_pos
213
+ v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
214
+
215
+ # Render the image,
216
+ # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
217
+ num_layers = 1
218
+ mask_pyramid = None
219
+ assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
220
+
221
+ mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos [org_pos, clip space pose for rasterization]
222
+
223
+ layers = []
224
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler:
225
+ for _ in range(num_layers):
226
+ rast, db = peeler.rasterize_next_layer()
227
+ gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False)
228
+
229
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
230
+ for _ in range(num_layers):
231
+ rast, db = peeler.rasterize_next_layer()
232
+ gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
233
+
234
+ hard_mask = torch.clamp(rast[..., -1:], 0, 1)
235
+ antialias_mask = dr.antialias(
236
+ hard_mask.clone().contiguous(), rast, v_pos_clip,
237
+ mesh_t_pos_idx_fx3)
238
+
239
+ depth = gb_feat[..., -2:-1]
240
+ ori_mesh_feature = gb_feat[..., :-4]
241
+
242
+ normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
243
+ normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
244
+ # normal = F.normalize(normal, dim=-1)
245
+ # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
246
+ return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal
247
+
248
+ def render_mesh_light(
249
+ self,
250
+ mesh_v_pos_bxnx3,
251
+ mesh_t_pos_idx_fx3,
252
+ mesh,
253
+ camera_mv_bx4x4,
254
+ mesh_v_feat_bxnxd,
255
+ resolution=256,
256
+ spp=1,
257
+ device='cuda',
258
+ hierarchical_mask=False
259
+ ):
260
+ assert not hierarchical_mask
261
+
262
+ mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
263
+ v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
264
+ v_pos_clip = self.camera.project(v_pos) # Projection in the camera
265
+
266
+ v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
267
+
268
+ # Render the image,
269
+ # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
270
+ num_layers = 1
271
+ mask_pyramid = None
272
+ assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
273
+ mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos
274
+
275
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
276
+ for _ in range(num_layers):
277
+ rast, db = peeler.rasterize_next_layer()
278
+ gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
279
+
280
+ hard_mask = torch.clamp(rast[..., -1:], 0, 1)
281
+ antialias_mask = dr.antialias(
282
+ hard_mask.clone().contiguous(), rast, v_pos_clip,
283
+ mesh_t_pos_idx_fx3)
284
+
285
+ depth = gb_feat[..., -2:-1]
286
+ ori_mesh_feature = gb_feat[..., :-4]
287
+
288
+ normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
289
+ normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
290
+ normal = F.normalize(normal, dim=-1)
291
+ normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
292
+
293
+ return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
src/models/geometry/render/renderutils/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
11
+ __all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
src/models/geometry/render/renderutils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (612 Bytes). View file
 
src/models/geometry/render/renderutils/__pycache__/bsdf.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
src/models/geometry/render/renderutils/__pycache__/loss.cpython-310.pyc ADDED
Binary file (1.22 kB). View file
 
src/models/geometry/render/renderutils/__pycache__/ops.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
src/models/geometry/render/renderutils/bsdf.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import math
11
+ import torch
12
+
13
+ NORMAL_THRESHOLD = 0.1
14
+
15
+ ################################################################################
16
+ # Vector utility functions
17
+ ################################################################################
18
+
19
+ def _dot(x, y):
20
+ return torch.sum(x*y, -1, keepdim=True)
21
+
22
+ def _reflect(x, n):
23
+ return 2*_dot(x, n)*n - x
24
+
25
+ def _safe_normalize(x):
26
+ return torch.nn.functional.normalize(x, dim = -1)
27
+
28
+ def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
29
+ # Swap normal direction for backfacing surfaces
30
+ if two_sided_shading:
31
+ smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
32
+ geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
33
+
34
+ t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
35
+ return torch.lerp(geom_nrm, smooth_nrm, t)
36
+
37
+
38
+ def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
39
+ smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
40
+ if opengl:
41
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
42
+ else:
43
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
44
+ return _safe_normalize(shading_nrm)
45
+
46
+ def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
47
+ smooth_nrm = _safe_normalize(smooth_nrm)
48
+ smooth_tng = _safe_normalize(smooth_tng)
49
+ view_vec = _safe_normalize(view_pos - pos)
50
+ shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
51
+ return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
52
+
53
+ ################################################################################
54
+ # Simple lambertian diffuse BSDF
55
+ ################################################################################
56
+
57
+ def bsdf_lambert(nrm, wi):
58
+ return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
59
+
60
+ ################################################################################
61
+ # Frostbite diffuse
62
+ ################################################################################
63
+
64
+ def bsdf_frostbite(nrm, wi, wo, linearRoughness):
65
+ wiDotN = _dot(wi, nrm)
66
+ woDotN = _dot(wo, nrm)
67
+
68
+ h = _safe_normalize(wo + wi)
69
+ wiDotH = _dot(wi, h)
70
+
71
+ energyBias = 0.5 * linearRoughness
72
+ energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
73
+ f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
74
+ f0 = 1.0
75
+
76
+ wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
77
+ woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
78
+ res = wiScatter * woScatter * energyFactor
79
+ return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
80
+
81
+ ################################################################################
82
+ # Phong specular, loosely based on mitsuba implementation
83
+ ################################################################################
84
+
85
+ def bsdf_phong(nrm, wo, wi, N):
86
+ dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
87
+ dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
88
+ return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
89
+
90
+ ################################################################################
91
+ # PBR's implementation of GGX specular
92
+ ################################################################################
93
+
94
+ specular_epsilon = 1e-4
95
+
96
+ def bsdf_fresnel_shlick(f0, f90, cosTheta):
97
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
98
+ return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
99
+
100
+ def bsdf_ndf_ggx(alphaSqr, cosTheta):
101
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
102
+ d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
103
+ return alphaSqr / (d * d * math.pi)
104
+
105
+ def bsdf_lambda_ggx(alphaSqr, cosTheta):
106
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
107
+ cosThetaSqr = _cosTheta * _cosTheta
108
+ tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
109
+ res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
110
+ return res
111
+
112
+ def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
113
+ lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
114
+ lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
115
+ return 1 / (1 + lambdaI + lambdaO)
116
+
117
+ def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
118
+ _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
119
+ alphaSqr = _alpha * _alpha
120
+
121
+ h = _safe_normalize(wo + wi)
122
+ woDotN = _dot(wo, nrm)
123
+ wiDotN = _dot(wi, nrm)
124
+ woDotH = _dot(wo, h)
125
+ nDotH = _dot(nrm, h)
126
+
127
+ D = bsdf_ndf_ggx(alphaSqr, nDotH)
128
+ G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
129
+ F = bsdf_fresnel_shlick(col, 1, woDotH)
130
+
131
+ w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
132
+
133
+ frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
134
+ return torch.where(frontfacing, w, torch.zeros_like(w))
135
+
136
+ def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
137
+ wo = _safe_normalize(view_pos - pos)
138
+ wi = _safe_normalize(light_pos - pos)
139
+
140
+ spec_str = arm[..., 0:1] # x component
141
+ roughness = arm[..., 1:2] # y component
142
+ metallic = arm[..., 2:3] # z component
143
+ ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
144
+ kd = kd * (1.0 - metallic)
145
+
146
+ if BSDF == 0:
147
+ diffuse = kd * bsdf_lambert(nrm, wi)
148
+ else:
149
+ diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
150
+ specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
151
+ return diffuse + specular
src/models/geometry/render/renderutils/c_src/bsdf.cu ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "bsdf.h"
14
+
15
+ #define SPECULAR_EPSILON 1e-4f
16
+
17
+ //------------------------------------------------------------------------
18
+ // Lambert functions
19
+
20
+ __device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
21
+ {
22
+ return max(dot(nrm, wi) / M_PI, 0.0f);
23
+ }
24
+
25
+ __device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
26
+ {
27
+ if (dot(nrm, wi) > 0.0f)
28
+ bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
29
+ }
30
+
31
+ //------------------------------------------------------------------------
32
+ // Fresnel Schlick
33
+
34
+ __device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
35
+ {
36
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
37
+ float scale = powf(1.0f - _cosTheta, 5.0f);
38
+ return f0 * (1.0f - scale) + f90 * scale;
39
+ }
40
+
41
+ __device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
42
+ {
43
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
44
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
45
+ d_f0 += d_out * (1.0 - scale);
46
+ d_f90 += d_out * scale;
47
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
48
+ {
49
+ d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
50
+ }
51
+ }
52
+
53
+ __device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
54
+ {
55
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
56
+ float scale = powf(1.0f - _cosTheta, 5.0f);
57
+ return f0 * (1.0f - scale) + f90 * scale;
58
+ }
59
+
60
+ __device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
61
+ {
62
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
63
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
64
+ d_f0 += d_out * (1.0 - scale);
65
+ d_f90 += d_out * scale;
66
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
67
+ {
68
+ d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
69
+ }
70
+ }
71
+
72
+ //------------------------------------------------------------------------
73
+ // Frostbite diffuse
74
+
75
+ __device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
76
+ {
77
+ float wiDotN = dot(wi, nrm);
78
+ float woDotN = dot(wo, nrm);
79
+ if (wiDotN > 0.0f && woDotN > 0.0f)
80
+ {
81
+ vec3f h = safeNormalize(wo + wi);
82
+ float wiDotH = dot(wi, h);
83
+
84
+ float energyBias = 0.5f * linearRoughness;
85
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
86
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
87
+ float f0 = 1.f;
88
+
89
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
90
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
91
+
92
+ return wiScatter * woScatter * energyFactor;
93
+ }
94
+ else return 0.0f;
95
+ }
96
+
97
+ __device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
98
+ {
99
+ float wiDotN = dot(wi, nrm);
100
+ float woDotN = dot(wo, nrm);
101
+
102
+ if (wiDotN > 0.0f && woDotN > 0.0f)
103
+ {
104
+ vec3f h = safeNormalize(wo + wi);
105
+ float wiDotH = dot(wi, h);
106
+
107
+ float energyBias = 0.5f * linearRoughness;
108
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
109
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
110
+ float f0 = 1.f;
111
+
112
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
113
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
114
+
115
+ // -------------- BWD --------------
116
+ // Backprop: return wiScatter * woScatter * energyFactor;
117
+ float d_wiScatter = d_out * woScatter * energyFactor;
118
+ float d_woScatter = d_out * wiScatter * energyFactor;
119
+ float d_energyFactor = d_out * wiScatter * woScatter;
120
+
121
+ // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
122
+ float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
123
+ bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
124
+
125
+ // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
126
+ float d_wiDotN = 0.0f;
127
+ bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
128
+
129
+ // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
130
+ float d_energyBias = d_f90;
131
+ float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
132
+ d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
133
+
134
+ // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
135
+ d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
136
+
137
+ // Backprop: float energyBias = 0.5f * linearRoughness;
138
+ d_linearRoughness += 0.5 * d_energyBias;
139
+
140
+ // Backprop: float wiDotH = dot(wi, h);
141
+ vec3f d_h(0);
142
+ bwdDot(wi, h, d_wi, d_h, d_wiDotH);
143
+
144
+ // Backprop: vec3f h = safeNormalize(wo + wi);
145
+ vec3f d_wo_wi(0);
146
+ bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
147
+ d_wi += d_wo_wi; d_wo += d_wo_wi;
148
+
149
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
150
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
151
+ }
152
+ }
153
+
154
+ //------------------------------------------------------------------------
155
+ // Ndf GGX
156
+
157
+ __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
158
+ {
159
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
160
+ float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
161
+ return alphaSqr / (d * d * M_PI);
162
+ }
163
+
164
+ __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
165
+ {
166
+ // Torch only back propagates if clamp doesn't trigger
167
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
168
+ float cosThetaSqr = _cosTheta * _cosTheta;
169
+ d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
170
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
171
+ {
172
+ d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
173
+ }
174
+ }
175
+
176
+ //------------------------------------------------------------------------
177
+ // Lambda GGX
178
+
179
+ __device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
180
+ {
181
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
182
+ float cosThetaSqr = _cosTheta * _cosTheta;
183
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
184
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
185
+ return res;
186
+ }
187
+
188
+ __device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
189
+ {
190
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
191
+ float cosThetaSqr = _cosTheta * _cosTheta;
192
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
193
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
194
+
195
+ d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
196
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
197
+ d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
198
+ }
199
+
200
+ //------------------------------------------------------------------------
201
+ // Masking GGX
202
+
203
+ __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
204
+ {
205
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
206
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
207
+ return 1.0f / (1.0f + lambdaI + lambdaO);
208
+ }
209
+
210
+ __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
211
+ {
212
+ // FWD eval
213
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
214
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
215
+
216
+ // BWD eval
217
+ float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
218
+ bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
219
+ bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
220
+ }
221
+
222
+ //------------------------------------------------------------------------
223
+ // GGX specular
224
+
225
+ __device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
226
+ {
227
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
228
+ float alphaSqr = _alpha * _alpha;
229
+
230
+ vec3f h = safeNormalize(wo + wi);
231
+ float woDotN = dot(wo, nrm);
232
+ float wiDotN = dot(wi, nrm);
233
+ float woDotH = dot(wo, h);
234
+ float nDotH = dot(nrm, h);
235
+
236
+ float D = fwdNdfGGX(alphaSqr, nDotH);
237
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
238
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
239
+ vec3f w = F * D * G * 0.25 / woDotN;
240
+
241
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
242
+ return frontfacing ? w : 0.0f;
243
+ }
244
+
245
+ __device__ void bwdPbrSpecular(
246
+ const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
247
+ vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
248
+ {
249
+ ///////////////////////////////////////////////////////////////////////
250
+ // FWD eval
251
+
252
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
253
+ float alphaSqr = _alpha * _alpha;
254
+
255
+ vec3f h = safeNormalize(wo + wi);
256
+ float woDotN = dot(wo, nrm);
257
+ float wiDotN = dot(wi, nrm);
258
+ float woDotH = dot(wo, h);
259
+ float nDotH = dot(nrm, h);
260
+
261
+ float D = fwdNdfGGX(alphaSqr, nDotH);
262
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
263
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
264
+ vec3f w = F * D * G * 0.25 / woDotN;
265
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
266
+
267
+ if (frontfacing)
268
+ {
269
+ ///////////////////////////////////////////////////////////////////////
270
+ // BWD eval
271
+
272
+ vec3f d_F = d_out * D * G * 0.25f / woDotN;
273
+ float d_D = sum(d_out * F * G * 0.25f / woDotN);
274
+ float d_G = sum(d_out * F * D * 0.25f / woDotN);
275
+
276
+ float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
277
+
278
+ vec3f d_f90(0);
279
+ float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
280
+ bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
281
+ bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
282
+ bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
283
+
284
+ vec3f d_h(0);
285
+ bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
286
+ bwdDot(wo, h, d_wo, d_h, d_woDotH);
287
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
288
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
289
+
290
+ vec3f d_h_unnorm(0);
291
+ bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
292
+ d_wo += d_h_unnorm;
293
+ d_wi += d_h_unnorm;
294
+
295
+ if (alpha > min_roughness * min_roughness)
296
+ d_alpha += d_alphaSqr * 2 * alpha;
297
+ }
298
+ }
299
+
300
+ //------------------------------------------------------------------------
301
+ // Full PBR BSDF
302
+
303
+ __device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
304
+ {
305
+ vec3f wo = safeNormalize(view_pos - pos);
306
+ vec3f wi = safeNormalize(light_pos - pos);
307
+
308
+ float alpha = arm.y * arm.y;
309
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
310
+ vec3f diff_col = kd * (1.0f - arm.z);
311
+
312
+ float diff = 0.0f;
313
+ if (BSDF == 0)
314
+ diff = fwdLambert(nrm, wi);
315
+ else
316
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
317
+ vec3f diffuse = diff_col * diff;
318
+ vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
319
+
320
+ return diffuse + specular;
321
+ }
322
+
323
+ __device__ void bwdPbrBSDF(
324
+ const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
325
+ vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
326
+ {
327
+ ////////////////////////////////////////////////////////////////////////
328
+ // FWD
329
+ vec3f _wi = light_pos - pos;
330
+ vec3f _wo = view_pos - pos;
331
+ vec3f wi = safeNormalize(_wi);
332
+ vec3f wo = safeNormalize(_wo);
333
+
334
+ float alpha = arm.y * arm.y;
335
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
336
+ vec3f diff_col = kd * (1.0f - arm.z);
337
+ float diff = 0.0f;
338
+ if (BSDF == 0)
339
+ diff = fwdLambert(nrm, wi);
340
+ else
341
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
342
+
343
+ ////////////////////////////////////////////////////////////////////////
344
+ // BWD
345
+
346
+ float d_alpha(0);
347
+ vec3f d_spec_col(0), d_wi(0), d_wo(0);
348
+ bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
349
+
350
+ float d_diff = sum(diff_col * d_out);
351
+ if (BSDF == 0)
352
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
353
+ else
354
+ bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
355
+
356
+ // Backprop: diff_col = kd * (1.0f - arm.z)
357
+ vec3f d_diff_col = d_out * diff;
358
+ d_kd += d_diff_col * (1.0f - arm.z);
359
+ d_arm.z -= sum(d_diff_col * kd);
360
+
361
+ // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
362
+ d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
363
+ d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
364
+ d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
365
+
366
+ // Backprop: alpha = arm.y * arm.y
367
+ d_arm.y += d_alpha * 2 * arm.y;
368
+
369
+ // Backprop: vec3f wi = safeNormalize(light_pos - pos);
370
+ vec3f d__wi(0);
371
+ bwdSafeNormalize(_wi, d__wi, d_wi);
372
+ d_light_pos += d__wi;
373
+ d_pos -= d__wi;
374
+
375
+ // Backprop: vec3f wo = safeNormalize(view_pos - pos);
376
+ vec3f d__wo(0);
377
+ bwdSafeNormalize(_wo, d__wo, d_wo);
378
+ d_view_pos += d__wo;
379
+ d_pos -= d__wo;
380
+ }
381
+
382
+ //------------------------------------------------------------------------
383
+ // Kernels
384
+
385
+ __global__ void LambertFwdKernel(LambertKernelParams p)
386
+ {
387
+ // Calculate pixel position.
388
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
389
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
390
+ unsigned int pz = blockIdx.z;
391
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
392
+ return;
393
+
394
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
395
+ vec3f wi = p.wi.fetch3(px, py, pz);
396
+
397
+ float res = fwdLambert(nrm, wi);
398
+
399
+ p.out.store(px, py, pz, res);
400
+ }
401
+
402
+ __global__ void LambertBwdKernel(LambertKernelParams p)
403
+ {
404
+ // Calculate pixel position.
405
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
406
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
407
+ unsigned int pz = blockIdx.z;
408
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
409
+ return;
410
+
411
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
412
+ vec3f wi = p.wi.fetch3(px, py, pz);
413
+ float d_out = p.out.fetch1(px, py, pz);
414
+
415
+ vec3f d_nrm(0), d_wi(0);
416
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
417
+
418
+ p.nrm.store_grad(px, py, pz, d_nrm);
419
+ p.wi.store_grad(px, py, pz, d_wi);
420
+ }
421
+
422
+ __global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
423
+ {
424
+ // Calculate pixel position.
425
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
426
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
427
+ unsigned int pz = blockIdx.z;
428
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
429
+ return;
430
+
431
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
432
+ vec3f wi = p.wi.fetch3(px, py, pz);
433
+ vec3f wo = p.wo.fetch3(px, py, pz);
434
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
435
+
436
+ float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
437
+
438
+ p.out.store(px, py, pz, res);
439
+ }
440
+
441
+ __global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
442
+ {
443
+ // Calculate pixel position.
444
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
445
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
446
+ unsigned int pz = blockIdx.z;
447
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
448
+ return;
449
+
450
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
451
+ vec3f wi = p.wi.fetch3(px, py, pz);
452
+ vec3f wo = p.wo.fetch3(px, py, pz);
453
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
454
+ float d_out = p.out.fetch1(px, py, pz);
455
+
456
+ float d_linearRoughness = 0.0f;
457
+ vec3f d_nrm(0), d_wi(0), d_wo(0);
458
+ bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
459
+
460
+ p.nrm.store_grad(px, py, pz, d_nrm);
461
+ p.wi.store_grad(px, py, pz, d_wi);
462
+ p.wo.store_grad(px, py, pz, d_wo);
463
+ p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
464
+ }
465
+
466
+ __global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
467
+ {
468
+ // Calculate pixel position.
469
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
470
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
471
+ unsigned int pz = blockIdx.z;
472
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
473
+ return;
474
+
475
+ vec3f f0 = p.f0.fetch3(px, py, pz);
476
+ vec3f f90 = p.f90.fetch3(px, py, pz);
477
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
478
+
479
+ vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
480
+ p.out.store(px, py, pz, res);
481
+ }
482
+
483
+ __global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
484
+ {
485
+ // Calculate pixel position.
486
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
487
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
488
+ unsigned int pz = blockIdx.z;
489
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
490
+ return;
491
+
492
+ vec3f f0 = p.f0.fetch3(px, py, pz);
493
+ vec3f f90 = p.f90.fetch3(px, py, pz);
494
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
495
+ vec3f d_out = p.out.fetch3(px, py, pz);
496
+
497
+ vec3f d_f0(0), d_f90(0);
498
+ float d_cosTheta(0);
499
+ bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
500
+
501
+ p.f0.store_grad(px, py, pz, d_f0);
502
+ p.f90.store_grad(px, py, pz, d_f90);
503
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
504
+ }
505
+
506
+ __global__ void ndfGGXFwdKernel(NdfGGXParams p)
507
+ {
508
+ // Calculate pixel position.
509
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
510
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
511
+ unsigned int pz = blockIdx.z;
512
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
513
+ return;
514
+
515
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
516
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
517
+ float res = fwdNdfGGX(alphaSqr, cosTheta);
518
+
519
+ p.out.store(px, py, pz, res);
520
+ }
521
+
522
+ __global__ void ndfGGXBwdKernel(NdfGGXParams p)
523
+ {
524
+ // Calculate pixel position.
525
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
526
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
527
+ unsigned int pz = blockIdx.z;
528
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
529
+ return;
530
+
531
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
532
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
533
+ float d_out = p.out.fetch1(px, py, pz);
534
+
535
+ float d_alphaSqr(0), d_cosTheta(0);
536
+ bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
537
+
538
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
539
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
540
+ }
541
+
542
+ __global__ void lambdaGGXFwdKernel(NdfGGXParams p)
543
+ {
544
+ // Calculate pixel position.
545
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
546
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
547
+ unsigned int pz = blockIdx.z;
548
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
549
+ return;
550
+
551
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
552
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
553
+ float res = fwdLambdaGGX(alphaSqr, cosTheta);
554
+
555
+ p.out.store(px, py, pz, res);
556
+ }
557
+
558
+ __global__ void lambdaGGXBwdKernel(NdfGGXParams p)
559
+ {
560
+ // Calculate pixel position.
561
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
562
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
563
+ unsigned int pz = blockIdx.z;
564
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
565
+ return;
566
+
567
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
568
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
569
+ float d_out = p.out.fetch1(px, py, pz);
570
+
571
+ float d_alphaSqr(0), d_cosTheta(0);
572
+ bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
573
+
574
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
575
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
576
+ }
577
+
578
+ __global__ void maskingSmithFwdKernel(MaskingSmithParams p)
579
+ {
580
+ // Calculate pixel position.
581
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
582
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
583
+ unsigned int pz = blockIdx.z;
584
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
585
+ return;
586
+
587
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
588
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
589
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
590
+ float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
591
+
592
+ p.out.store(px, py, pz, res);
593
+ }
594
+
595
+ __global__ void maskingSmithBwdKernel(MaskingSmithParams p)
596
+ {
597
+ // Calculate pixel position.
598
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
599
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
600
+ unsigned int pz = blockIdx.z;
601
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
602
+ return;
603
+
604
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
605
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
606
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
607
+ float d_out = p.out.fetch1(px, py, pz);
608
+
609
+ float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
610
+ bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
611
+
612
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
613
+ p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
614
+ p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
615
+ }
616
+
617
+ __global__ void pbrSpecularFwdKernel(PbrSpecular p)
618
+ {
619
+ // Calculate pixel position.
620
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
621
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
622
+ unsigned int pz = blockIdx.z;
623
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
624
+ return;
625
+
626
+ vec3f col = p.col.fetch3(px, py, pz);
627
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
628
+ vec3f wo = p.wo.fetch3(px, py, pz);
629
+ vec3f wi = p.wi.fetch3(px, py, pz);
630
+ float alpha = p.alpha.fetch1(px, py, pz);
631
+
632
+ vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
633
+
634
+ p.out.store(px, py, pz, res);
635
+ }
636
+
637
+ __global__ void pbrSpecularBwdKernel(PbrSpecular p)
638
+ {
639
+ // Calculate pixel position.
640
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
641
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
642
+ unsigned int pz = blockIdx.z;
643
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
644
+ return;
645
+
646
+ vec3f col = p.col.fetch3(px, py, pz);
647
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
648
+ vec3f wo = p.wo.fetch3(px, py, pz);
649
+ vec3f wi = p.wi.fetch3(px, py, pz);
650
+ float alpha = p.alpha.fetch1(px, py, pz);
651
+ vec3f d_out = p.out.fetch3(px, py, pz);
652
+
653
+ float d_alpha(0);
654
+ vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
655
+ bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
656
+
657
+ p.col.store_grad(px, py, pz, d_col);
658
+ p.nrm.store_grad(px, py, pz, d_nrm);
659
+ p.wo.store_grad(px, py, pz, d_wo);
660
+ p.wi.store_grad(px, py, pz, d_wi);
661
+ p.alpha.store_grad(px, py, pz, d_alpha);
662
+ }
663
+
664
+ __global__ void pbrBSDFFwdKernel(PbrBSDF p)
665
+ {
666
+ // Calculate pixel position.
667
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
668
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
669
+ unsigned int pz = blockIdx.z;
670
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
671
+ return;
672
+
673
+ vec3f kd = p.kd.fetch3(px, py, pz);
674
+ vec3f arm = p.arm.fetch3(px, py, pz);
675
+ vec3f pos = p.pos.fetch3(px, py, pz);
676
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
677
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
678
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
679
+
680
+ vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
681
+
682
+ p.out.store(px, py, pz, res);
683
+ }
684
+ __global__ void pbrBSDFBwdKernel(PbrBSDF p)
685
+ {
686
+ // Calculate pixel position.
687
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
688
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
689
+ unsigned int pz = blockIdx.z;
690
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
691
+ return;
692
+
693
+ vec3f kd = p.kd.fetch3(px, py, pz);
694
+ vec3f arm = p.arm.fetch3(px, py, pz);
695
+ vec3f pos = p.pos.fetch3(px, py, pz);
696
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
697
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
698
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
699
+ vec3f d_out = p.out.fetch3(px, py, pz);
700
+
701
+ vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
702
+ bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
703
+
704
+ p.kd.store_grad(px, py, pz, d_kd);
705
+ p.arm.store_grad(px, py, pz, d_arm);
706
+ p.pos.store_grad(px, py, pz, d_pos);
707
+ p.nrm.store_grad(px, py, pz, d_nrm);
708
+ p.view_pos.store_grad(px, py, pz, d_view_pos);
709
+ p.light_pos.store_grad(px, py, pz, d_light_pos);
710
+ }