fffiloni committed on
Commit
2252f3d
1 Parent(s): e9c8d11

Migrated from GitHub

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. LICENSE.txt +21 -0
  3. ORIGINAL_README.md +79 -0
  4. assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 +0 -0
  5. assets/result_clr_scale4_pexels-zdmit-6780091.mp4 +0 -0
  6. blender/blender_render_human_ortho.py +837 -0
  7. blender/check_render.py +46 -0
  8. blender/count.py +44 -0
  9. blender/distribute.py +149 -0
  10. blender/rename_smpl_files.py +25 -0
  11. blender/render.sh +4 -0
  12. blender/render_human.py +88 -0
  13. blender/render_single.sh +7 -0
  14. blender/utils.py +128 -0
  15. configs/inference-768-6view.yaml +72 -0
  16. configs/remesh.yaml +18 -0
  17. configs/train-768-6view-onlyscan_face.yaml +145 -0
  18. configs/train-768-6view-onlyscan_face_smplx.yaml +154 -0
  19. core/opt.py +197 -0
  20. core/remesh.py +359 -0
  21. econdataset.py +370 -0
  22. examples/02986d0998ce01aa0aa67a99fbd1e09a.png +0 -0
  23. examples/16171.png +0 -0
  24. examples/26d2e846349647ff04c536816e0e8ca1.png +0 -0
  25. examples/30755.png +0 -0
  26. examples/3930.png +0 -0
  27. examples/4656716-3016170581.png +0 -0
  28. examples/663dcd6db19490de0b790da430bd5681.png +3 -0
  29. examples/7332.png +0 -0
  30. examples/85891251f52a2399e660a63c2a7fdf40.png +0 -0
  31. examples/a689a48d23d6b8d58d67ff5146c6e088.png +0 -0
  32. examples/b0d178743c7e3e09700aaee8d2b1ec47.png +0 -0
  33. examples/case5.png +0 -0
  34. examples/d40776a1e1582179d97907d36f84d776.png +0 -0
  35. examples/durant.png +0 -0
  36. examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png +0 -0
  37. examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png +0 -0
  38. examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png +0 -0
  39. examples/pexels-barbara-olsen-7869640.png +0 -0
  40. examples/pexels-julia-m-cameron-4145040.png +0 -0
  41. examples/pexels-marta-wave-6437749.png +0 -0
  42. examples/pexels-photo-6311555-removebg.png +0 -0
  43. examples/pexels-zdmit-6780091.png +0 -0
  44. inference.py +221 -0
  45. lib/__init__.py +0 -0
  46. lib/common/__init__.py +0 -0
  47. lib/common/cloth_extraction.py +182 -0
  48. lib/common/config.py +218 -0
  49. lib/common/imutils.py +364 -0
  50. lib/common/render.py +398 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/663dcd6db19490de0b790da430bd5681.png filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ORIGINAL_README.md ADDED
@@ -0,0 +1,79 @@
1
+ # PSHuman
2
+
3
+ This is the official implementation of *PSHuman: Photorealistic Single-image 3D Human Reconstruction using Cross-Scale Multiview Diffusion*.
4
+
5
+ ### [Project Page](https://penghtyx.github.io/PSHuman/) | [Arxiv](https://arxiv.org/pdf/2409.10141) | [Weights](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views)
6
+
7
+ https://github.com/user-attachments/assets/b62e3305-38a7-4b51-aed8-1fde967cca70
8
+
9
+ https://github.com/user-attachments/assets/76100d2e-4a1a-41ad-815c-816340ac6500
10
+
11
+
12
+ Given a single image of a clothed person, **PSHuman** facilitates detailed geometry and realistic 3D human appearance across various poses within one minute.
13
+
14
+ ### 📝 Update
15
+ - __[2024.11.30]__: Release the SMPL-free [version](https://huggingface.co/pengHTYX/PSHuman_Unclip_768_6views), which does not requires SMPL condition for multview generation and perfome well in general posed human.
16
+
17
+
18
+ ### Installation
19
+ ```
20
+ conda create -n pshuman python=3.10
21
+ conda activate pshuman
22
+
23
+ # torch
24
+ pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
25
+
26
+ # other depedency
27
+ pip install -r requirement.txt
28
+ ```
29
+
30
+ This project is also based on SMPLX. We borrowed the related models from [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU), and re-orginized them, which can be downloaded from [Onedrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD).
31
+
32
+
33
+
34
+ ### Inference
35
+ 1. Given a human image, we use [Clipdrop](https://github.com/xxlong0/Wonder3D?tab=readme-ov-file) or ```rembg``` to remove the background. For the latter, we provide a simple scrip.
36
+ ```
37
+ python utils/remove_bg.py --path $DATA_PATH$
38
+ ```
39
+ Then, put the RGBA images in the ```$DATA_PATH$```.
40
+
41
+ 2. By running [inference.py](inference.py), the textured mesh and rendered video will be saved in ```out```.
42
+ ```
43
+ CUDA_VISIBLE_DEVICES=$GPU python inference.py --config configs/inference-768-6view.yaml \
44
+ pretrained_model_name_or_path='pengHTYX/PSHuman_Unclip_768_6views' \
45
+ validation_dataset.crop_size=740 \
46
+ with_smpl=false \
47
+ validation_dataset.root_dir=$DATA_PATH$ \
48
+ seed=600 \
49
+ num_views=7 \
50
+ save_mode='rgb'
51
+
52
+ ```
53
+ You can adjust the ```crop_size``` (720 or 740) and ```seed``` (42 or 600) to obtain best results for some cases.
54
+
55
+ ### Training
56
+ For the data preparing and preprocessing, please refer to our [paper](https://arxiv.org/pdf/2409.10141). Once the data is ready, we begin the training by running
57
+ ```
58
+ bash scripts/train_768.sh
59
+ ```
60
+ You should modified some parameters, such as ```data_common.root_dir``` and ```data_common.object_list```.
61
+
62
+ ### Related projects
63
+ We collect code from following projects. We thanks for the contributions from the open-source community!
64
+
65
+ [ECON](https://github.com/YuliangXiu/ECON) and [SIFU](https://github.com/River-Zhang/SIFU) recover human mesh from single human image.
66
+ [Era3D](https://github.com/pengHTYX/Era3D) and [Unique3D](https://github.com/AiuniAI/Unique3D) generate consistent multiview images with single color image.
67
+ [Continuous-Remeshing](https://github.com/Profactor/continuous-remeshing) for Inverse Rendering.
68
+
69
+
70
+ ### Citation
71
+ If you find this codebase useful, please consider cite our work.
72
+ ```
73
+ @article{li2024pshuman,
74
+ title={PSHuman: Photorealistic Single-view Human Reconstruction using Cross-Scale Diffusion},
75
+ author={Li, Peng and Zheng, Wangguandong and Liu, Yuan and Yu, Tao and Li, Yangguang and Qi, Xingqun and Li, Mengfei and Chi, Xiaowei and Xia, Siyu and Xue, Wei and others},
76
+ journal={arXiv preprint arXiv:2409.10141},
77
+ year={2024}
78
+ }
79
+ ```
assets/result_clr_scale4_pexels-barbara-olsen-7869640.mp4 ADDED
Binary file (320 kB).
 
assets/result_clr_scale4_pexels-zdmit-6780091.mp4 ADDED
Binary file (629 kB).
 
blender/blender_render_human_ortho.py ADDED
@@ -0,0 +1,837 @@
1
+ """Blender script to render images of 3D models.
2
+
3
+ This script is used to render images of 3D models. It takes in a list of paths
4
+ to .glb files and renders images of each model. The images are rendered by rotating the
5
+ camera around the object at the origin. The images are saved to the output directory.
6
+
7
+ Example usage:
8
+ blender -b -P blender_render_human_ortho.py -- \
9
+ --object_path my_object.glb \
10
+ --output_dir ./views \
11
+ --engine CYCLES \
12
+ --scale 0.8 \
13
+ --num_images 12 \
14
+ --camera_dist 1.2
15
+
16
+ Here, --object_path points to a single .glb, .obj, .ply, or .fbx file to render.
17
+ """
18
+ import argparse
19
+ import json
20
+ import math
21
+ import os
22
+ import random
23
+ import sys
24
+ import time
25
+ import glob
26
+ import urllib.request
27
+ import uuid
28
+ from typing import Tuple
29
+ from mathutils import Vector, Matrix
30
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
31
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
32
+ import cv2
33
+ import numpy as np
34
+ from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
35
+
36
+ import bpy
37
+ from mathutils import Vector
38
+
39
+ import OpenEXR
40
+ import Imath
41
+ from PIL import Image
42
+
43
+ # import blenderproc as bproc
44
+
45
+ bpy.app.debug_value=256
46
+
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument(
49
+ "--object_path",
50
+ type=str,
51
+ required=True,
52
+ help="Path to the object file",
53
+ )
54
+ parser.add_argument("--smpl_path", type=str, required=True, help="Path to the object file")
55
+ parser.add_argument("--output_dir", type=str, default="/views_whole_sphere-test2")
56
+ parser.add_argument(
57
+ "--engine", type=str, default="BLENDER_EEVEE", choices=["CYCLES", "BLENDER_EEVEE"]
58
+ )
59
+ parser.add_argument("--scale", type=float, default=1.0)
60
+ parser.add_argument("--num_images", type=int, default=8)
61
+ parser.add_argument("--random_images", type=int, default=3)
62
+ parser.add_argument("--random_ortho", type=int, default=1)
63
+ parser.add_argument("--device", type=str, default="CUDA")
64
+ parser.add_argument("--resolution", type=int, default=512)
65
+
66
+
67
+ argv = sys.argv[sys.argv.index("--") + 1 :]
68
+ args = parser.parse_args(argv)
69
+
70
+
71
+
72
+ print('===================', args.engine, '===================')
73
+
74
+ context = bpy.context
75
+ scene = context.scene
76
+ render = scene.render
77
+
78
+ cam = scene.objects["Camera"]
79
+ cam.data.type = 'ORTHO'
80
+ cam.data.ortho_scale = 1.
81
+ cam.data.lens = 35
82
+ cam.data.sensor_height = 32
83
+ cam.data.sensor_width = 32
84
+
85
+ cam_constraint = cam.constraints.new(type="TRACK_TO")
86
+ cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
87
+ cam_constraint.up_axis = "UP_Y"
88
+
89
+ # setup lighting
90
+ # bpy.ops.object.light_add(type="AREA")
91
+ # light2 = bpy.data.lights["Area"]
92
+ # light2.energy = 3000
93
+ # bpy.data.objects["Area"].location[2] = 0.5
94
+ # bpy.data.objects["Area"].scale[0] = 100
95
+ # bpy.data.objects["Area"].scale[1] = 100
96
+ # bpy.data.objects["Area"].scale[2] = 100
97
+
98
+ render.engine = args.engine
99
+ render.image_settings.file_format = "PNG"
100
+ render.image_settings.color_mode = "RGBA"
101
+ render.resolution_x = args.resolution
102
+ render.resolution_y = args.resolution
103
+ render.resolution_percentage = 100
104
+ render.threads_mode = 'FIXED' # use a fixed number of render threads
105
+ render.threads = 32 # number of render threads
106
+
107
+ scene.cycles.device = "GPU"
108
+ scene.cycles.samples = 128 # 128
109
+ scene.cycles.diffuse_bounces = 1
110
+ scene.cycles.glossy_bounces = 1
111
+ scene.cycles.transparent_max_bounces = 3 # 3
112
+ scene.cycles.transmission_bounces = 3 # 3
113
+ # scene.cycles.filter_width = 0.01
114
+ bpy.context.scene.cycles.adaptive_threshold = 0
115
+ scene.cycles.use_denoising = True
116
+ scene.render.film_transparent = True
117
+
118
+ bpy.context.preferences.addons["cycles"].preferences.get_devices()
119
+ # Set the device_type
120
+ bpy.context.preferences.addons["cycles"].preferences.compute_device_type = 'CUDA' # or "OPENCL"
121
+ bpy.context.scene.cycles.tile_size = 8192
122
+
123
+
124
+ # eevee = scene.eevee
125
+ # eevee.use_soft_shadows = True
126
+ # eevee.use_ssr = True
127
+ # eevee.use_ssr_refraction = True
128
+ # eevee.taa_render_samples = 64
129
+ # eevee.use_gtao = True
130
+ # eevee.gtao_distance = 1
131
+ # eevee.use_volumetric_shadows = True
132
+ # eevee.volumetric_tile_size = '2'
133
+ # eevee.gi_diffuse_bounces = 1
134
+ # eevee.gi_cubemap_resolution = '128'
135
+ # eevee.gi_visibility_resolution = '16'
136
+ # eevee.gi_irradiance_smoothing = 0
137
+
138
+
139
+ # for depth & normal
140
+ context.view_layer.use_pass_normal = True
141
+ context.view_layer.use_pass_z = True
142
+ context.scene.use_nodes = True
143
+
144
+
145
+ tree = bpy.context.scene.node_tree
146
+ nodes = bpy.context.scene.node_tree.nodes
147
+ links = bpy.context.scene.node_tree.links
148
+
149
+ # Clear default nodes
150
+ for n in nodes:
151
+ nodes.remove(n)
152
+
153
+ # # Create input render layer node.
154
+ render_layers = nodes.new('CompositorNodeRLayers')
155
+
156
+ scale_normal = nodes.new(type="CompositorNodeMixRGB")
157
+ scale_normal.blend_type = 'MULTIPLY'
158
+ scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1)
159
+ links.new(render_layers.outputs['Normal'], scale_normal.inputs[1])
160
+ bias_normal = nodes.new(type="CompositorNodeMixRGB")
161
+ bias_normal.blend_type = 'ADD'
162
+ bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0)
163
+ links.new(scale_normal.outputs[0], bias_normal.inputs[1])
164
+ normal_file_output = nodes.new(type="CompositorNodeOutputFile")
165
+ normal_file_output.label = 'Normal Output'
166
+ links.new(bias_normal.outputs[0], normal_file_output.inputs[0])
167
+
168
+ normal_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
169
+ normal_file_output.format.color_mode = "RGB" # default is "BW"
170
+
171
+ depth_file_output = nodes.new(type="CompositorNodeOutputFile")
172
+ depth_file_output.label = 'Depth Output'
173
+ links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0])
174
+ depth_file_output.format.file_format = "OPEN_EXR" # default is "PNG"
175
+ depth_file_output.format.color_mode = "RGB" # default is "BW"
176
+
177
+ def prepare_depth_outputs():
178
+ tree = bpy.context.scene.node_tree
179
+ links = tree.links
180
+ render_node = tree.nodes['Render Layers']
181
+ depth_out_node = tree.nodes.new(type="CompositorNodeOutputFile")
182
+ depth_map_node = tree.nodes.new(type="CompositorNodeMapRange")
183
+ depth_out_node.base_path = ''
184
+ depth_out_node.format.file_format = 'OPEN_EXR'
185
+ depth_out_node.format.color_depth = '32'
186
+
187
+ depth_map_node.inputs[1].default_value = 0.54
188
+ depth_map_node.inputs[2].default_value = 1.96
189
+ depth_map_node.inputs[3].default_value = 0
190
+ depth_map_node.inputs[4].default_value = 1
191
+ depth_map_node.use_clamp = True
192
+ links.new(render_node.outputs[2],depth_map_node.inputs[0])
193
+ links.new(depth_map_node.outputs[0], depth_out_node.inputs[0])
194
+ return depth_out_node, depth_map_node
195
+
196
+ depth_file_output, depth_map_node = prepare_depth_outputs()
197
+
198
+
199
+ def exr_to_png(exr_path):
200
+ depth_path = exr_path.replace('.exr', '.png')
201
+ exr_image = OpenEXR.InputFile(exr_path)
202
+ dw = exr_image.header()['dataWindow']
203
+ (width, height) = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
204
+
205
+ def read_exr(s, width, height):
206
+ mat = np.frombuffer(s, dtype=np.float32) # np.fromstring is deprecated for binary data
207
+ mat = mat.reshape(height, width)
208
+ return mat
209
+
210
+ dmap, _, _ = [read_exr(s, width, height) for s in exr_image.channels('BGR', Imath.PixelType(Imath.PixelType.FLOAT))]
211
+ dmap = np.clip(np.asarray(dmap,np.float64),a_max=1.0, a_min=0.0) * 65535
212
+ dmap = Image.fromarray(dmap.astype(np.uint16))
213
+ dmap.save(depth_path)
214
+ exr_image.close()
215
+ # os.system('rm {}'.format(exr_path))
216
+
217
+ def extract_depth(directory):
218
+ fns = glob.glob(f'{directory}/*.exr')
219
+ for fn in fns: exr_to_png(fn)
220
+ os.system(f'rm {directory}/*.exr')
221
+
222
+ def sample_point_on_sphere(radius: float) -> Tuple[float, float, float]:
223
+ theta = random.random() * 2 * math.pi
224
+ phi = math.acos(2 * random.random() - 1)
225
+ return (
226
+ radius * math.sin(phi) * math.cos(theta),
227
+ radius * math.sin(phi) * math.sin(theta),
228
+ radius * math.cos(phi),
229
+ )
230
+
231
+ def sample_spherical(radius=3.0, maxz=3.0, minz=0.):
232
+ correct = False
233
+ while not correct:
234
+ vec = np.random.uniform(-1, 1, 3)
235
+ vec[2] = np.abs(vec[2])
236
+ vec = vec / np.linalg.norm(vec, axis=0) * radius
237
+ if maxz > vec[2] > minz:
238
+ correct = True
239
+ return vec
240
+
241
+ def sample_spherical(radius_min=1.5, radius_max=2.0, maxz=1.6, minz=-0.75):
242
+ correct = False
243
+ while not correct:
244
+ vec = np.random.uniform(-1, 1, 3)
245
+ # vec[2] = np.abs(vec[2])
246
+ radius = np.random.uniform(radius_min, radius_max, 1)
247
+ vec = vec / np.linalg.norm(vec, axis=0) * radius[0]
248
+ if maxz > vec[2] > minz:
249
+ correct = True
250
+ return vec
251
+
252
+ def randomize_camera():
253
+ elevation = random.uniform(0., 90.)
254
+ azimuth = random.uniform(0., 360)
255
+ distance = random.uniform(0.8, 1.6)
256
+ return set_camera_location(elevation, azimuth, distance)
257
+
258
+ def set_camera_location(elevation, azimuth, distance):
259
+ # from https://blender.stackexchange.com/questions/18530/
260
+ x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
261
+ camera = bpy.data.objects["Camera"]
262
+ camera.location = x, y, z
263
+
264
+ direction = - camera.location
265
+ rot_quat = direction.to_track_quat('-Z', 'Y')
266
+ camera.rotation_euler = rot_quat.to_euler()
267
+ return camera
268
+
269
+ def set_camera_mvdream(azimuth, elevation, distance):
270
+ # theta, phi = np.deg2rad(azimuth), np.deg2rad(elevation)
271
+ azimuth, elevation = np.deg2rad(azimuth), np.deg2rad(elevation)
272
+ point = (
273
+ distance * math.cos(azimuth) * math.cos(elevation),
274
+ distance * math.sin(azimuth) * math.cos(elevation),
275
+ distance * math.sin(elevation),
276
+ )
277
+ camera = bpy.data.objects["Camera"]
278
+ camera.location = point
279
+
280
+ direction = -camera.location
281
+ rot_quat = direction.to_track_quat('-Z', 'Y')
282
+ camera.rotation_euler = rot_quat.to_euler()
283
+ return camera
284
+
285
+ def reset_scene() -> None:
286
+ """Resets the scene to a clean state.
287
+
288
+ Returns:
289
+ None
290
+ """
291
+ # delete everything that isn't part of a camera or a light
292
+ for obj in bpy.data.objects:
293
+ if obj.type not in {"CAMERA", "LIGHT"}:
294
+ bpy.data.objects.remove(obj, do_unlink=True)
295
+
296
+ # delete all the materials
297
+ for material in bpy.data.materials:
298
+ bpy.data.materials.remove(material, do_unlink=True)
299
+
300
+ # delete all the textures
301
+ for texture in bpy.data.textures:
302
+ bpy.data.textures.remove(texture, do_unlink=True)
303
+
304
+ # delete all the images
305
+ for image in bpy.data.images:
306
+ bpy.data.images.remove(image, do_unlink=True)
307
+ def process_ply(obj):
308
+ # obj = bpy.context.selected_objects[0]
309
+
310
+ # create a new material
311
+ material = bpy.data.materials.new(name="VertexColors")
312
+ material.use_nodes = True
313
+ obj.data.materials.append(material)
314
+
315
+ # get the material's node tree
316
+ nodes = material.node_tree.nodes
317
+ links = material.node_tree.links
318
+
319
+ # remove the default 'Principled BSDF' node
320
+ principled_bsdf_node = nodes.get("Principled BSDF")
321
+ if principled_bsdf_node:
322
+ nodes.remove(principled_bsdf_node)
323
+
324
+ # create a new 'Emission' node
325
+ emission_node = nodes.new(type="ShaderNodeEmission")
326
+ emission_node.location = 0, 0
327
+
328
+ # create an 'Attribute' node
329
+ attribute_node = nodes.new(type="ShaderNodeAttribute")
330
+ attribute_node.location = -300, 0
331
+ attribute_node.attribute_name = "Col" # name of the vertex color attribute
332
+
333
+ # get the 'Material Output' node
334
+ output_node = nodes.get("Material Output")
335
+
336
+ # connect the nodes
337
+ links.new(attribute_node.outputs["Color"], emission_node.inputs["Color"])
338
+ links.new(emission_node.outputs["Emission"], output_node.inputs["Surface"])
339
+
340
+ # # load the glb model
341
+ # def load_object(object_path: str) -> None:
342
+
343
+ # if object_path.endswith(".glb"):
344
+ # bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
345
+ # elif object_path.endswith(".fbx"):
346
+ # bpy.ops.import_scene.fbx(filepath=object_path)
347
+ # elif object_path.endswith(".obj"):
348
+ # bpy.ops.import_scene.obj(filepath=object_path)
349
+ # elif object_path.endswith(".ply"):
350
+ # bpy.ops.import_mesh.ply(filepath=object_path)
351
+ # obj = bpy.context.selected_objects[0]
352
+ # obj.rotation_euler[0] = 1.5708
353
+ # # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
354
+ # process_ply(obj)
355
+ # else:
356
+ # raise ValueError(f"Unsupported file type: {object_path}")
357
+
358
+
359
+
360
+ def scene_bbox(
361
+ single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
362
+ ) -> Tuple[Vector, Vector]:
363
+ """Returns the bounding box of the scene.
364
+
365
+ Taken from Shap-E rendering script
366
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
367
+
368
+ Args:
369
+ single_obj (Optional[bpy.types.Object], optional): If not None, only computes
370
+ the bounding box for the given object. Defaults to None.
371
+ ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
372
+ to False.
373
+
374
+ Raises:
375
+ RuntimeError: If there are no objects in the scene.
376
+
377
+ Returns:
378
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
379
+ """
380
+ bbox_min = (math.inf,) * 3
381
+ bbox_max = (-math.inf,) * 3
382
+ found = False
383
+ for obj in get_scene_meshes() if single_obj is None else [single_obj]:
384
+ found = True
385
+ for coord in obj.bound_box:
386
+ coord = Vector(coord)
387
+ if not ignore_matrix:
388
+ coord = obj.matrix_world @ coord
389
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
390
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
391
+
392
+ if not found:
393
+ raise RuntimeError("no objects in scene to compute bounding box for")
394
+
395
+ return Vector(bbox_min), Vector(bbox_max)
396
+
397
+
398
+ def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
399
+ """Returns all root objects in the scene.
400
+
401
+ Yields:
402
+ Generator[bpy.types.Object, None, None]: Generator of all root objects in the
403
+ scene.
404
+ """
405
+ for obj in bpy.context.scene.objects.values():
406
+ if not obj.parent:
407
+ yield obj
408
+
409
+
410
+ def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
411
+ """Returns all meshes in the scene.
412
+
413
+ Yields:
414
+ Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
415
+ """
416
+ for obj in bpy.context.scene.objects.values():
417
+ if isinstance(obj.data, (bpy.types.Mesh)):
418
+ yield obj
419
+
420
+
421
+ # Build intrinsic camera parameters from Blender camera data
422
+ #
423
+ # See notes on this in
424
+ # blender.stackexchange.com/questions/15102/what-is-blenders-camera-projection-matrix-model
425
+ def get_calibration_matrix_K_from_blender(camd):
426
+ f_in_mm = camd.lens
427
+ scene = bpy.context.scene
428
+ resolution_x_in_px = scene.render.resolution_x
429
+ resolution_y_in_px = scene.render.resolution_y
430
+ scale = scene.render.resolution_percentage / 100
431
+ sensor_width_in_mm = camd.sensor_width
432
+ sensor_height_in_mm = camd.sensor_height
433
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
434
+ if (camd.sensor_fit == 'VERTICAL'):
435
+ # the sensor height is fixed (sensor fit is vertical),
436
+ # the sensor width is effectively changed with the pixel aspect ratio
437
+ s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio
438
+ s_v = resolution_y_in_px * scale / sensor_height_in_mm
439
+ else: # 'HORIZONTAL' and 'AUTO'
440
+ # the sensor width is fixed (sensor fit is horizontal),
441
+ # the sensor height is effectively changed with the pixel aspect ratio
442
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
443
+ s_u = resolution_x_in_px * scale / sensor_width_in_mm
444
+ s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm
445
+
446
+ # Parameters of intrinsic calibration matrix K
447
+ alpha_u = f_in_mm * s_u
448
+ alpha_v = f_in_mm * s_v
449
+ u_0 = resolution_x_in_px * scale / 2
450
+ v_0 = resolution_y_in_px * scale / 2
451
+ skew = 0 # only use rectangular pixels
452
+
453
+ K = Matrix(
454
+ ((alpha_u, skew, u_0),
455
+ ( 0 , alpha_v, v_0),
456
+ ( 0 , 0, 1 )))
457
+ return K
458
+
459
+
460
+ def get_calibration_matrix_K_from_blender_for_ortho(camd, ortho_scale):
461
+ scene = bpy.context.scene
462
+ resolution_x_in_px = scene.render.resolution_x
463
+ resolution_y_in_px = scene.render.resolution_y
464
+ scale = scene.render.resolution_percentage / 100
465
+ pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
466
+
467
+ fx = resolution_x_in_px / ortho_scale
468
+ fy = resolution_y_in_px / ortho_scale / pixel_aspect_ratio
469
+
470
+ cx = resolution_x_in_px / 2
471
+ cy = resolution_y_in_px / 2
472
+
473
+ K = Matrix(
474
+ ((fx, 0, cx),
475
+ (0, fy, cy),
476
+ (0 , 0, 1)))
477
+ return K
478
+
479
+
480
+ def get_3x4_RT_matrix_from_blender(cam):
481
+ bpy.context.view_layer.update()
482
+ location, rotation = cam.matrix_world.decompose()[0:2]
483
+ R = np.asarray(rotation.to_matrix())
484
+ t = np.asarray(location)
485
+
486
+ cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
487
+ R = R.T
488
+ t = -R @ t
489
+ R_world2cv = cam_rec @ R
490
+ t_world2cv = cam_rec @ t
491
+
492
+ RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
493
+ return RT
494
+
495
+ def delete_invisible_objects() -> None:
496
+ """Deletes all invisible objects in the scene.
497
+
498
+ Returns:
499
+ None
500
+ """
501
+ bpy.ops.object.select_all(action="DESELECT")
502
+ for obj in scene.objects:
503
+ if obj.hide_viewport or obj.hide_render:
504
+ obj.hide_viewport = False
505
+ obj.hide_render = False
506
+ obj.hide_select = False
507
+ obj.select_set(True)
508
+ bpy.ops.object.delete()
509
+
510
+ # Delete invisible collections
511
+ invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
512
+ for col in invisible_collections:
513
+ bpy.data.collections.remove(col)
514
+
515
+
516
+ def normalize_scene():
517
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
518
+ at the origin.
519
+
520
+ Mostly taken from the Point-E / Shap-E rendering script
521
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
522
+ but fix for multiple root objects: (see bug report here:
523
+ https://github.com/openai/shap-e/pull/60).
524
+
525
+ Returns:
526
+ None
527
+ """
528
+ if len(list(get_scene_root_objects())) > 1:
529
+ print('we have more than one root objects!!')
530
+ # create an empty object to be used as a parent for all root objects
531
+ parent_empty = bpy.data.objects.new("ParentEmpty", None)
532
+ bpy.context.scene.collection.objects.link(parent_empty)
533
+
534
+ # parent all root objects to the empty object
535
+ for obj in get_scene_root_objects():
536
+ if obj != parent_empty:
537
+ obj.parent = parent_empty
538
+
539
+ bbox_min, bbox_max = scene_bbox()
540
+ dxyz = bbox_max - bbox_min
541
+ dist = np.sqrt(dxyz[0]**2+ dxyz[1]**2+dxyz[2]**2)
542
+ scale = 1 / dist
543
+ for obj in get_scene_root_objects():
544
+ obj.scale = obj.scale * scale
545
+
546
+ # Apply scale to matrix_world.
547
+ bpy.context.view_layer.update()
548
+ bbox_min, bbox_max = scene_bbox()
549
+ offset = -(bbox_min + bbox_max) / 2
550
+ for obj in get_scene_root_objects():
551
+ obj.matrix_world.translation += offset
552
+ bpy.ops.object.select_all(action="DESELECT")
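For reference, background removal with ```rembg``` takes only a few lines; the sketch below illustrates the idea. It is *not* the repository's ```utils/remove_bg.py```, and the file patterns and output names are illustrative.
```
import argparse
from pathlib import Path

from PIL import Image
from rembg import remove  # pip install rembg

parser = argparse.ArgumentParser()
parser.add_argument("--path", required=True, help="folder containing input images")
args = parser.parse_args()

for img_path in sorted(Path(args.path).iterdir()):
    if img_path.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
        continue
    rgba = remove(Image.open(img_path))  # returns a PIL image with an alpha channel
    out_path = img_path.with_name(img_path.stem + "_rgba.png")
    rgba.save(out_path)
    print("saved", out_path)
```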
553
+
554
+ # unparent the camera
555
+ bpy.data.objects["Camera"].parent = None
556
+ return scale, offset
557
+
558
+ def download_object(object_url: str) -> str:
559
+ """Download the object and return the path."""
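If you prefer setting those overrides programmatically rather than editing the YAML by hand, a minimal sketch using OmegaConf is shown below. It assumes the training configs are OmegaConf-compatible (as the dotted overrides accepted by ```inference.py``` suggest); the paths are illustrative.
```
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train-768-6view-onlyscan_face.yaml")

# Illustrative values; point these at your rendered data and object list.
overrides = OmegaConf.from_dotlist([
    "data_common.root_dir=/data/human_8view",
    "data_common.object_list=object_lists/THuman2.1.json",
])
cfg = OmegaConf.merge(cfg, overrides)
OmegaConf.save(cfg, "configs/train-768-6view-custom.yaml")
```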
560
+ # uid = uuid.uuid4()
561
+ uid = object_url.split("/")[-1].split(".")[0]
562
+ tmp_local_path = os.path.join("tmp-objects", f"{uid}.glb" + ".tmp")
563
+ local_path = os.path.join("tmp-objects", f"{uid}.glb")
564
+ # wget the file and put it in local_path
565
+ os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True)
566
+ urllib.request.urlretrieve(object_url, tmp_local_path)
567
+ os.rename(tmp_local_path, local_path)
568
+ # get the absolute path
569
+ local_path = os.path.abspath(local_path)
570
+ return local_path
571
+
572
+
573
+ def render_and_save(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
574
+ # print(view_id)
575
+ # render the image
576
+ render_path = os.path.join(args.output_dir, 'image', f"{view_id:03d}.png")
577
+ scene.render.filepath = render_path
578
+
579
+ if not ortho:
580
+ cam.data.lens = len_val
581
+
582
+ depth_map_node.inputs[1].default_value = distance - 1
583
+ depth_map_node.inputs[2].default_value = distance + 1
584
+ depth_file_output.base_path = os.path.join(args.output_dir, object_uid, 'depth')
585
+
586
+ depth_file_output.file_slots[0].path = f"{view_id:03d}"
587
+ normal_file_output.file_slots[0].path = f"{view_id:03d}"
588
+
589
+ if not os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id+1:03d}.png")):
590
+ bpy.ops.render.render(write_still=True)
591
+
592
+
593
+ if os.path.exists(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr")):
594
+ os.rename(os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}0001.exr"),
595
+ os.path.join(args.output_dir, object_uid, 'depth', f"{view_id:03d}.exr"))
596
+
597
+ if os.path.exists(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr")):
598
+ normal = cv2.imread(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
599
+ normal_unit16 = (normal * 65535).astype(np.uint16)
600
+ cv2.imwrite(os.path.join(args.output_dir, 'normal', f"{view_id:03d}.png"), normal_unit16)
601
+ os.remove(os.path.join(args.output_dir, 'normal', f"{view_id:03d}0001.exr"))
602
+
603
+ # save camera KRT matrix
604
+ if ortho:
605
+ K = get_calibration_matrix_K_from_blender_for_ortho(cam.data, ortho_scale=cam.data.ortho_scale)
606
+ else:
607
+ K = get_calibration_matrix_K_from_blender(cam.data)
608
+
609
+ RT = get_3x4_RT_matrix_from_blender(cam)
610
+ para_path = os.path.join(args.output_dir, 'camera', f"{view_id:03d}.npy")
611
+ # np.save(RT_path, RT)
612
+ paras = {}
613
+ paras['intrinsic'] = np.array(K, np.float32)
614
+ paras['extrinsic'] = np.array(RT, np.float32)
615
+ paras['fov'] = cam.data.angle
616
+ paras['azimuth'] = azimuth
617
+ paras['elevation'] = elevation
618
+ paras['distance'] = distance
619
+ paras['focal'] = cam.data.lens
620
+ paras['sensor_width'] = cam.data.sensor_width
621
+ paras['near'] = distance - 1
622
+ paras['far'] = distance + 1
623
+ paras['camera'] = 'persp' if not ortho else 'ortho'
624
+ np.save(para_path, paras)
625
+
626
+ def render_and_save_smpl(view_id, object_uid, len_val, azimuth, elevation, distance, ortho=False):
627
+
628
+
629
+ if not ortho:
630
+ cam.data.lens = len_val
631
+
632
+ render_path = os.path.join(args.output_dir, 'smpl_image', f"{view_id:03d}.png")
633
+ scene.render.filepath = render_path
634
+
635
+ normal_file_output.file_slots[0].path = f"{view_id:03d}"
636
+ if not os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png")):
637
+ bpy.ops.render.render(write_still=True)
638
+
639
+ if os.path.exists(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr")):
640
+ normal = cv2.imread(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr"), cv2.IMREAD_UNCHANGED)
641
+ normal_unit16 = (normal * 65535).astype(np.uint16)
642
+ cv2.imwrite(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}.png"), normal_unit16)
643
+ os.remove(os.path.join(args.output_dir, 'smpl_normal', f"{view_id:03d}0001.exr"))
644
+
645
+
646
+
647
+ def scene_meshes():
648
+ for obj in bpy.context.scene.objects.values():
649
+ if isinstance(obj.data, (bpy.types.Mesh)):
650
+ yield obj
651
+
652
+ def load_object(object_path: str) -> None:
653
+ """Loads a glb model into the scene."""
654
+ if object_path.endswith(".glb"):
655
+ bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=False)
656
+ elif object_path.endswith(".fbx"):
657
+ bpy.ops.import_scene.fbx(filepath=object_path)
658
+ elif object_path.endswith(".obj"):
659
+ bpy.ops.import_scene.obj(filepath=object_path)
660
+ obj = bpy.context.selected_objects[0]
661
+ obj.rotation_euler[0] = 6.28319
662
+ # obj.rotation_euler[2] = 1.5708
663
+ elif object_path.endswith(".ply"):
664
+ bpy.ops.import_mesh.ply(filepath=object_path)
665
+ obj = bpy.context.selected_objects[0]
666
+ obj.rotation_euler[0] = 1.5708
667
+ obj.rotation_euler[2] = 1.5708
668
+ # bpy.ops.wm.ply_import(filepath=object_path, directory=os.path.dirname(object_path),forward_axis='X', up_axis='Y')
669
+ process_ply(obj)
670
+ else:
671
+ raise ValueError(f"Unsupported file type: {object_path}")
672
+
673
+ def save_images(object_file: str, smpl_file: str) -> None:
674
+ """Saves rendered images of the object in the scene."""
675
+ object_uid = '' # os.path.basename(object_file).split(".")[0]
676
+ # # if we already render this object, we skip it
677
+ if os.path.exists(os.path.join(args.output_dir, 'meta.npy')): return
678
+ os.makedirs(args.output_dir, exist_ok=True)
679
+ os.makedirs(os.path.join(args.output_dir, 'camera'), exist_ok=True)
680
+
681
+ reset_scene()
682
+ load_object(object_file)
683
+
684
+ lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
685
+ for light in lights:
686
+ bpy.data.objects.remove(light, do_unlink=True)
687
+
688
+ # bproc.init()
689
+
690
+ world_tree = bpy.context.scene.world.node_tree
691
+ back_node = world_tree.nodes['Background']
692
+ env_light = 0.5
693
+ back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0])
694
+ back_node.inputs['Strength'].default_value = 1.0
695
+
696
+ #Make light just directional, disable shadows.
697
+ light_data = bpy.data.lights.new(name=f'Light', type='SUN')
698
+ light = bpy.data.objects.new(name=f'Light', object_data=light_data)
699
+ bpy.context.collection.objects.link(light)
700
+ light = bpy.data.lights['Light']
701
+ light.use_shadow = False
702
+ # Possibly disable specular shading:
703
+ light.specular_factor = 1.0
704
+ light.energy = 5.0
705
+
706
+ #Add another light source so stuff facing away from light is not completely dark
707
+ light_data = bpy.data.lights.new(name=f'Light2', type='SUN')
708
+ light = bpy.data.objects.new(name=f'Light2', object_data=light_data)
709
+ bpy.context.collection.objects.link(light)
710
+ light2 = bpy.data.lights['Light2']
711
+ light2.use_shadow = False
712
+ light2.specular_factor = 1.0
713
+ light2.energy = 3 #0.015
714
+ bpy.data.objects['Light2'].rotation_euler = bpy.data.objects['Light2'].rotation_euler
715
+ bpy.data.objects['Light2'].rotation_euler[0] += 180
716
+
717
+ #Add another light source so stuff facing away from light is not completely dark
718
+ light_data = bpy.data.lights.new(name=f'Light3', type='SUN')
719
+ light = bpy.data.objects.new(name=f'Light3', object_data=light_data)
720
+ bpy.context.collection.objects.link(light)
721
+ light3 = bpy.data.lights['Light3']
722
+ light3.use_shadow = False
723
+ light3.specular_factor = 1.0
724
+ light3.energy = 3 #0.015
725
+ bpy.data.objects['Light3'].rotation_euler = bpy.data.objects['Light3'].rotation_euler
726
+ bpy.data.objects['Light3'].rotation_euler[0] += 90
727
+
728
+ #Add another light source so stuff facing away from light is not completely dark
729
+ light_data = bpy.data.lights.new(name=f'Light4', type='SUN')
730
+ light = bpy.data.objects.new(name=f'Light4', object_data=light_data)
731
+ bpy.context.collection.objects.link(light)
732
+ light4 = bpy.data.lights['Light4']
733
+ light4.use_shadow = False
734
+ light4.specular_factor = 1.0
735
+ light4.energy = 3 #0.015
736
+ bpy.data.objects['Light4'].rotation_euler = bpy.data.objects['Light4'].rotation_euler
737
+ bpy.data.objects['Light4'].rotation_euler[0] += -90
738
+
739
+ scale, offset = normalize_scene()
740
+
741
+
742
+ try:
743
+ # some objects' normals are affected by textures
744
+ mesh_objects = [obj for obj in scene_meshes()]
745
+ main_bsdf_name = 'BsdfPrincipled'
746
+ normal_name = 'Normal'
747
+ for obj in mesh_objects:
748
+ for mat in obj.data.materials:
749
+ for node in mat.node_tree.nodes:
750
+ if main_bsdf_name in node.bl_idname:
751
+ principled_bsdf = node
752
+ # remove links, we don't want add normal textures
753
+ if principled_bsdf.inputs[normal_name].links:
754
+ mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
755
+ except Exception as e:
756
+ print("don't know why")
757
+ # create an empty object to track
758
+ empty = bpy.data.objects.new("Empty", None)
759
+ scene.collection.objects.link(empty)
760
+ cam_constraint.target = empty
761
+
762
+ subject_width = 1.0
763
+
764
+ normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'normal')
765
+ for i in range(args.num_images):
766
+ # change the camera to orthogonal
767
+ cam.data.type = 'ORTHO'
768
+ cam.data.ortho_scale = subject_width
769
+ distance = 1.5
770
+ azimuth = i * 360 / args.num_images
771
+ bpy.context.view_layer.update()
772
+ set_camera_mvdream(azimuth, 0, distance)
773
+ render_and_save(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
774
+ extract_depth(os.path.join(args.output_dir, object_uid, 'depth'))
775
+ # #### smpl
776
+ reset_scene()
777
+ load_object(smpl_file)
778
+
779
+ lights = [obj for obj in bpy.context.scene.objects if obj.type == 'LIGHT']
780
+ for light in lights:
781
+ bpy.data.objects.remove(light, do_unlink=True)
782
+
783
+ scale, offset = normalize_scene()
784
+
785
+ try:
786
+ # some objects' normals are affected by textures
787
+ mesh_objects = [obj for obj in scene_meshes()]
788
+ main_bsdf_name = 'BsdfPrincipled'
789
+ normal_name = 'Normal'
790
+ for obj in mesh_objects:
791
+ for mat in obj.data.materials:
792
+ for node in mat.node_tree.nodes:
793
+ if main_bsdf_name in node.bl_idname:
794
+ principled_bsdf = node
795
+ # remove links, we don't want add normal textures
796
+ if principled_bsdf.inputs[normal_name].links:
797
+ mat.node_tree.links.remove(principled_bsdf.inputs[normal_name].links[0])
798
+ except Exception as e:
799
+ print("don't know why")
800
+ # create an empty object to track
801
+ empty = bpy.data.objects.new("Empty", None)
802
+ scene.collection.objects.link(empty)
803
+ cam_constraint.target = empty
804
+
805
+ subject_width = 1.0
806
+
807
+ normal_file_output.base_path = os.path.join(args.output_dir, object_uid, 'smpl_normal')
808
+ for i in range(args.num_images):
809
+ # change the camera to orthogonal
810
+ cam.data.type = 'ORTHO'
811
+ cam.data.ortho_scale = subject_width
812
+ distance = 1.5
813
+ azimuth = i * 360 / args.num_images
814
+ bpy.context.view_layer.update()
815
+ set_camera_mvdream(azimuth, 0, distance)
816
+ render_and_save_smpl(i * (args.random_images+1), object_uid, -1, azimuth, 0, distance, ortho=True)
817
+
818
+
819
+ np.save(os.path.join(args.output_dir, object_uid, 'meta.npy'), np.asarray([scale, offset[0], offset[1], offset[2]], np.float32)) # scale plus the full xyz offset
820
+
821
+
822
+ if __name__ == "__main__":
823
+ try:
824
+ start_i = time.time()
825
+ if args.object_path.startswith("http"):
826
+ local_path = download_object(args.object_path)
827
+ else:
828
+ local_path = args.object_path
829
+ save_images(local_path, args.smpl_path)
830
+ end_i = time.time()
831
+ print("Finished", local_path, "in", end_i - start_i, "seconds")
832
+ # delete the object if it was downloaded
833
+ if args.object_path.startswith("http"):
834
+ os.remove(local_path)
835
+ except Exception as e:
836
+ print("Failed to render", args.object_path)
837
+ print(e)
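The ```camera/*.npy``` files written by ```render_and_save``` above store a Python dict of intrinsics and extrinsics. Below is a minimal sketch of reading one back and projecting a world-space point into the rendered view; the path is illustrative, and the axis/sign conventions follow the OpenCV-style extrinsic built in ```get_3x4_RT_matrix_from_blender```.
```
import numpy as np

# The dict was saved with np.save, so allow_pickle is needed to read it back.
params = np.load("views/camera/000.npy", allow_pickle=True).item()
K = params["intrinsic"]    # 3x3 intrinsic (ortho K uses resolution / ortho_scale)
RT = params["extrinsic"]   # 3x4 world-to-camera matrix

p_world = np.array([0.0, 0.0, 0.0, 1.0])   # homogeneous world-space point
p_cam = RT @ p_world                        # camera-space (x, y, z)

if params["camera"] == "ortho":
    # Orthographic: the pixel position does not depend on depth.
    u = K[0, 0] * p_cam[0] + K[0, 2]
    v = K[1, 1] * p_cam[1] + K[1, 2]
else:
    uvw = K @ p_cam
    u, v = uvw[0] / uvw[2], uvw[1] / uvw[2]
print("pixel coordinates:", u, v)
```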
blender/check_render.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import json
4
+ from icecream import ic
5
+
6
+
7
+ def check_render(dataset, st=None, end=None):
8
+ total_lists = []
9
+ with open(dataset+'.json', 'r') as f:
10
+ glb_list = json.load(f)
11
+ for x in glb_list:
12
+ total_lists.append(x.split('/')[-2] )
13
+
14
+ if st is not None:
15
+ end = min(end, len(total_lists))
16
+ total_lists = total_lists[st:end]
17
+ glb_list = glb_list[st:end]
18
+
19
+ save_dir = '/data/lipeng/human_8view_with_smplx/'+dataset
20
+ unrendered = set(total_lists) - set(os.listdir(save_dir))
21
+
22
+ num_finish = 0
23
+ num_failed = len(unrendered)
24
+ failed_case = []
25
+ for case in os.listdir(save_dir):
26
+ if not os.path.exists(os.path.join(save_dir, case, 'smpl_normal', '007.png')):
27
+ failed_case.append(case)
28
+ num_failed += 1
29
+ else:
30
+ num_finish += 1
31
+ ic(num_failed)
32
+ ic(num_finish)
33
+
34
+
35
+ need_render = []
36
+ for full_path in glb_list:
37
+ for case in failed_case:
38
+ if case in full_path:
39
+ need_render.append(full_path)
40
+
41
+ with open('need_render.json', 'w') as f:
42
+ json.dump(need_render, f, indent=4)
43
+
44
+ if __name__ == '__main__':
45
+ dataset = 'THuman2.1'
46
+ check_render(dataset)
blender/count.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import json
3
+ def find_files(directory, extensions):
4
+ results = []
5
+ for foldername, subfolders, filenames in os.walk(directory):
6
+ for filename in filenames:
7
+ if filename.endswith(extensions):
8
+ file_path = os.path.abspath(os.path.join(foldername, filename))
9
+ results.append(file_path)
10
+ return results
11
+
12
+ def count_customhumans(root):
13
+ directory_path = ['CustomHumans/mesh']
14
+
15
+ extensions = ('.ply', '.obj')
16
+
17
+ lists = []
18
+ for dataset_path in directory_path:
19
+ dir = os.path.join(root, dataset_path)
20
+ file_paths = find_files(dir, extensions)
21
+ # import pdb;pdb.set_trace()
22
+ dataset_name = dataset_path.split('/')[0]
23
+ for file_path in file_paths:
24
+ lists.append(file_path.replace(root, ""))
25
+ with open(f'{dataset_name}.json', 'w') as f:
26
+ json.dump(lists, f, indent=4)
27
+
28
+ def count_thuman21(root):
29
+ directory_path = ['THuman2.1/mesh']
30
+ extensions = ('.ply', '.obj')
31
+ lists = []
32
+ for dataset_path in directory_path:
33
+ dir = os.path.join(root, dataset_path)
34
+ file_paths = find_files(dir, extensions)
35
+ dataset_name = dataset_path.split('/')[0]
36
+ for file_path in file_paths:
37
+ lists.append(file_path.replace(root, ""))
38
+ with open(f'{dataset_name}.json', 'w') as f:
39
+ json.dump(lists, f, indent=4)
40
+
41
+ if __name__ == '__main__':
42
+ root = '/data/lipeng/human_scan/'
43
+ # count_customhumans(root)
44
+ count_thuman21(root)
blender/distribute.py ADDED
@@ -0,0 +1,149 @@
1
+ import glob
2
+ import json
3
+ import multiprocessing
4
+ import shutil
5
+ import subprocess
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+ import os
10
+
11
+ import boto3
12
+
13
+
14
+ from glob import glob
15
+
16
+ import argparse
17
+
18
+ parser = argparse.ArgumentParser(description='distributed rendering')
19
+
20
+ parser.add_argument('--workers_per_gpu', type=int, default=10,
21
+ help='number of workers per gpu.')
22
+ parser.add_argument('--input_models_path', type=str, default='/data/lipeng/human_scan/',
23
+ help='Path to a json file containing a list of 3D object files.')
24
+ parser.add_argument('--num_gpus', type=int, default=-1,
25
+ help='number of gpus to use. -1 means all available gpus.')
26
+ parser.add_argument('--gpu_list',nargs='+', type=int,
27
+ help='the avalaible gpus')
28
+ parser.add_argument('--resolution', type=int, default=512,
29
+ help='')
30
+ parser.add_argument('--random_images', type=int, default=0)
31
+ parser.add_argument('--start_i', type=int, default=0,
32
+ help='the index of first object to be rendered.')
33
+ parser.add_argument('--end_i', type=int, default=-1,
34
+ help='the index of the last object to be rendered.')
35
+
36
+ parser.add_argument('--data_dir', type=str, default='/data/lipeng/human_scan/',
37
+ help='Path to a json file containing a list of 3D object files.')
38
+
39
+ parser.add_argument('--json_path', type=str, default='2K2K.json')
40
+
41
+ parser.add_argument('--save_dir', type=str, default='/data/lipeng/human_8view',
42
+ help='Path to a json file containing a list of 3D object files.')
43
+
44
+ parser.add_argument('--ortho_scale', type=float, default=1.,
45
+ help='ortho rendering usage; how large the object is')
46
+
47
+
48
+ args = parser.parse_args()
49
+
50
+ def parse_obj_list(xs):
51
+ cases = []
52
+ # print(xs[:2])
53
+
54
+ for x in xs:
55
+ if 'THuman3.0' in x:
56
+ # print(apath)
57
+ splits = x.split('/')
58
+ x = os.path.join('THuman3.0', splits[-2])
59
+ elif 'THuman2.1' in x:
60
+ splits = x.split('/')
61
+ x = os.path.join('THuman2.1', splits[-2])
62
+ elif 'CustomHumans' in x:
63
+ splits = x.split('/')
64
+ x = os.path.join('CustomHumans', splits[-2])
65
+ elif '1M' in x:
66
+ splits = x.split('/')
67
+ x = os.path.join('2K2K', splits[-2])
68
+ elif 'realistic_8k_model' in x:
69
+ splits = x.split('/')
70
+ x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
71
+ cases.append(f'{args.save_dir}/{x}')
72
+ return cases
73
+
74
+
75
+ with open(args.json_path, 'r') as f:
76
+ glb_list = json.load(f)
77
+
78
+ # glb_list = ['THuman2.1/mesh/1618/1618.obj']
79
+ # glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
80
+ # glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj']
81
+ # glb_list = ['1M/01968/01968.ply', '1M/00103/00103.ply']
82
+ # glb_list = ['realistic_8k_model/01aab099a2fe4af7be120110a385105d.glb']
83
+
84
+ total_num_glbs = len(glb_list)
85
+
86
+
87
+
88
+ def worker(
89
+ queue: multiprocessing.JoinableQueue,
90
+ count: multiprocessing.Value,
91
+ gpu: int,
92
+ s3: Optional[boto3.client],
93
+ ) -> None:
94
+ print("Worker started")
95
+ while True:
96
+ case, save_p = queue.get()
97
+ src_path = os.path.join(args.data_dir, case)
98
+ smpl_path = src_path.replace('mesh', 'smplx', 1)
99
+
100
+ command = ('blender -b -P blender_render_human_ortho.py'
101
+ f' -- --object_path {src_path}'
102
+ f' --smpl_path {smpl_path}'
103
+ f' --output_dir {save_p} --engine CYCLES'
104
+ f' --resolution {args.resolution}'
105
+ f' --random_images {args.random_images}'
106
+ )
107
+
108
+ print(command)
109
+ subprocess.run(command, shell=True)
110
+
111
+ with count.get_lock():
112
+ count.value += 1
113
+
114
+ queue.task_done()
115
+
116
+
117
+ if __name__ == "__main__":
118
+ # args = tyro.cli(Args)
119
+
120
+ s3 = None
121
+ queue = multiprocessing.JoinableQueue()
122
+ count = multiprocessing.Value("i", 0)
123
+
124
+ # Start worker processes on each of the GPUs
125
+ for gpu_i in range(args.num_gpus):
126
+ for worker_i in range(args.workers_per_gpu):
127
+ worker_i = gpu_i * args.workers_per_gpu + worker_i
128
+ process = multiprocessing.Process(
129
+ target=worker, args=(queue, count, args.gpu_list[gpu_i], s3)
130
+ )
131
+ process.daemon = True
132
+ process.start()
133
+
134
+ # Add items to the queue
135
+
136
+ save_dirs = parse_obj_list(glb_list)
137
+ args.end_i = len(save_dirs) if args.end_i > len(save_dirs) or args.end_i==-1 else args.end_i
138
+
139
+ for case_sub, save_dir in zip(glb_list[args.start_i:args.end_i], save_dirs[args.start_i:args.end_i]):
140
+ queue.put([case_sub, save_dir])
141
+
142
+
143
+
144
+ # Wait for all tasks to be completed
145
+ queue.join()
146
+
147
+ # Add sentinels to the queue to stop the worker processes
148
+ for i in range(args.num_gpus * args.workers_per_gpu):
149
+ queue.put(None)
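One thing to note: ```worker``` receives a ```gpu``` index, but the spawned Blender command is not pinned to that device. A minimal sketch of restricting each child process to one GPU is shown below; this is an assumption about the intended behaviour (the render script configures Cycles for CUDA), not the repository's implementation.
```
import os
import subprocess

def run_on_gpu(command: str, gpu: int) -> None:
    # Expose only the selected CUDA device to the child Blender process.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    subprocess.run(command, shell=True, env=env, check=False)
```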
blender/rename_smpl_files.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ from glob import glob
4
+
5
+ def rename_customhumans():
6
+ root = '/data/lipeng/human_scan/CustomHumans/smplx/'
7
+ file_paths = glob(os.path.join(root, '*/*_smpl.obj'))
8
+ for file_path in tqdm(file_paths):
9
+ new_path = file_path.replace('_smpl', '')
10
+ os.rename(file_path, new_path)
11
+
12
+ def rename_thuman21():
13
+ root = '/data/lipeng/human_scan/THuman2.1/smplx/'
14
+ file_paths = glob(os.path.join(root, '*/*.obj'))
15
+ for file_path in tqdm(file_paths):
16
+ obj_name = file_path.split('/')[-2]
17
+ folder_name = os.path.dirname(file_path)
18
+ new_path = os.path.join(folder_name, obj_name+'.obj')
19
+ # print(new_path)
20
+ # print(file_path)
21
+ os.rename(file_path, new_path)
22
+
23
+ if __name__ == '__main__':
24
+ rename_thuman21()
25
+ rename_customhumans()
blender/render.sh ADDED
@@ -0,0 +1,4 @@
1
+ #### install environment
2
+ # ~/pkgs/blender-3.6.4/3.6/python/bin/python3.10 -m pip install openexr opencv-python
3
+
4
+ python render_human.py
blender/render_human.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import json
3
+ import math
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ import threading
6
+ from tqdm import tqdm
7
+
8
+ # from glcontext import egl
9
+ # egl.create_context()
10
+ # exit(0)
11
+
12
+ LOCAL_RANK = 0
13
+
14
+ num_processes = 4
15
+ NODE_RANK = int(os.getenv("SLURM_PROCID", "0")) # default to rank 0 when not launched under SLURM
16
+ WORLD_SIZE = 1
17
+ NODE_NUM=1
18
+ # NODE_RANK = int(os.getenv("SLURM_NODEID"))
19
+ IS_MAIN = False
20
+ if NODE_RANK == 0 and LOCAL_RANK == 0:
21
+ IS_MAIN = True
22
+
23
+ GLOBAL_RANK = NODE_RANK * (WORLD_SIZE//NODE_NUM) + LOCAL_RANK
24
+
25
+
26
+ # json_path = "object_lists/Thuman2.0.json"
27
+ # json_path = "object_lists/THuman3.0.json"
28
+ json_path = "object_lists/CustomHumans.json"
29
+ data_dir = '/aifs4su/mmcode/lipeng'
30
+ save_dir = '/aifs4su/mmcode/lipeng/human_8view_new'
31
+ def parse_obj_list(x):
32
+ if 'THuman3.0' in x:
33
+ # print(apath)
34
+ splits = x.split('/')
35
+ x = os.path.join('THuman3.0', splits[-2])
36
+ elif 'Thuman2.0' in x:
37
+ splits = x.split('/')
38
+ x = os.path.join('Thuman2.0', splits[-2])
39
+ elif 'CustomHumans' in x:
40
+ splits = x.split('/')
41
+ x = os.path.join('CustomHumans', splits[-2])
42
+ # print(splits[-2])
43
+ elif '1M' in x:
44
+ splits = x.split('/')
45
+ x = os.path.join('2K2K', splits[-2])
46
+ elif 'realistic_8k_model' in x:
47
+ splits = x.split('/')
48
+ x = os.path.join('realistic_8k_model', splits[-1].split('.')[0])
49
+ return f'{save_dir}/{x}'
50
+
51
+ with open(json_path, 'r') as f:
52
+ glb_list = json.load(f)
53
+
54
+ # glb_list = ['Thuman2.0/0011/0011.obj']
55
+ # glb_list = ['THuman3.0/00024_1/00024_0006/mesh.obj']
56
+ # glb_list = ['CustomHumans/mesh/0383_00070_02_00061/mesh-f00061.obj']
57
+ # glb_list = ['realistic_8k_model/1d41f2a72f994306b80e632f1cc8233f.glb']
58
+
59
+ total_num_glbs = len(glb_list)
60
+
61
+ num_glbs_local = int(math.ceil(total_num_glbs / WORLD_SIZE))
62
+ start_idx = GLOBAL_RANK * num_glbs_local
63
+ end_idx = start_idx + num_glbs_local
64
+ # print(start_idx, end_idx)
65
+ local_glbs = glb_list[start_idx:end_idx]
66
+ if IS_MAIN:
67
+ pbar = tqdm(total=len(local_glbs))
68
+ lock = threading.Lock()
69
+
70
+ def process_human(glb_path):
71
+ src_path = os.path.join(data_dir, glb_path)
72
+ save_path = parse_obj_list(glb_path)
73
+ # print(save_path)
74
+ command = ('blender -b -P blender_render_human_script.py'
75
+ f' -- --object_path {src_path}'
76
+ f' --output_dir {save_path} ')
77
+ # 1>/dev/null
78
+ # print(command)
79
+ os.system(command)
80
+
81
+ if IS_MAIN:
82
+ with lock:
83
+ pbar.update(1)
84
+
85
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
86
+ executor.map(process_human, local_glbs)
87
+
88
+
blender/render_single.sh ADDED
@@ -0,0 +1,7 @@
1
+ # debug single sample
2
+ blender -b -P blender_render_human_ortho.py \
3
+ -- --object_path /data/lipeng/human_scan/THuman2.1/mesh/0011/0011.obj \
4
+ --smpl_path /data/lipeng/human_scan/THuman2.1/smplx/0011/0011.obj \
5
+ --output_dir debug --engine CYCLES \
6
+ --resolution 768 \
7
+ --random_images 0
blender/utils.py ADDED
@@ -0,0 +1,128 @@
1
+ import datetime
2
+ import pytz
3
+ import traceback
4
+ from torchvision.utils import make_grid
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ import os
10
+ from tqdm import tqdm
11
+ import cv2
12
+ import imageio
13
+ def get_time_for_log():
14
+ return datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime(
15
+ "%Y%m%d %H:%M:%S")
16
+
17
+
18
+ def get_trace_for_log():
19
+ return str(traceback.format_exc())
20
+
21
+ def make_grid_(imgs, save_file, nrow=10, pad_value=1):
22
+ if isinstance(imgs, list):
23
+ if isinstance(imgs[0], Image.Image):
24
+ imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs]
25
+ elif isinstance(imgs[0], np.ndarray):
26
+ imgs = [torch.from_numpy(img/255.) for img in imgs]
27
+ imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2)
28
+ if isinstance(imgs, np.ndarray):
29
+ imgs = torch.from_numpy(imgs)
30
+
31
+ img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value)
32
+ img_grid = img_grid.permute(1, 2, 0).numpy()
33
+ img_grid = (img_grid * 255).astype(np.uint8)
34
+ img_grid = Image.fromarray(img_grid)
35
+ img_grid.save(save_file)
36
+
37
+ def draw_caption(img, text, pos, size=100, color=(128, 128, 128)):
38
+ draw = ImageDraw.Draw(img)
39
+ # font = ImageFont.truetype(size= size)
40
+ font = ImageFont.load_default()
41
+ font = font.font_variant(size=size)
42
+ draw.text(pos, text, color, font=font)
43
+ return img
44
+
45
+
46
+ def txt2json(txt_file, json_file):
47
+ with open(txt_file, 'r') as f:
48
+ items = f.readlines()
49
+ items = [x.strip() for x in items]
50
+
51
+ with open(json_file, 'w') as f:
52
+ json.dump(items, f)
53
+
54
+ def process_thuman_texture():
55
+ path = '/aifs4su/mmcode/lipeng/Thuman2.0'
56
+ cases = os.listdir(path)
57
+ for case in tqdm(cases):
58
+ mtl = os.path.join(path, case, 'material0.mtl')
59
+ with open(mtl, 'r') as f:
60
+ lines = f.read()
61
+ lines = lines.replace('png', 'jpeg')
62
+ with open(mtl, 'w') as f:
63
+ f.write(lines)
64
+
65
+
66
+ #### for debug
67
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
68
+
69
+
70
+ def get_intrinsic_from_fov(fov, H, W, bs=-1):
71
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
72
+ intrinsic = np.identity(3, dtype=np.float32)
73
+ intrinsic[0, 0] = focal_length
74
+ intrinsic[1, 1] = focal_length
75
+ intrinsic[0, 2] = W / 2.0
76
+ intrinsic[1, 2] = H / 2.0
77
+
78
+ if bs > 0:
79
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
80
+
81
+ return torch.from_numpy(intrinsic)
82
+
83
+ def read_data(data_dir, i):
84
+ """
85
+ Return:
86
+ rgb: (H, W, 3) torch.float32
87
+ depth: (H, W, 1) torch.float32
88
+ mask: (H, W, 1) torch.float32
89
+ c2w: (4, 4) torch.float32
90
+ intrinsic: (3, 3) torch.float32
91
+ """
92
+ background_color = torch.tensor([0.0, 0.0, 0.0])
93
+
94
+ rgb_name = os.path.join(data_dir, 'render_%04d.webp' % i)
95
+ depth_name = os.path.join(data_dir, 'depth_%04d.exr' % i)
96
+
97
+ img = torch.from_numpy(
98
+ np.asarray(
99
+ Image.fromarray(imageio.v2.imread(rgb_name))
100
+ .convert("RGBA")
101
+ )
102
+ / 255.0
103
+ ).float()
104
+ mask = img[:, :, -1:]
105
+ rgb = img[:, :, :3] * mask + background_color[
106
+ None, None, :
107
+ ] * (1 - mask)
108
+
109
+ depth = torch.from_numpy(
110
+ cv2.imread(depth_name, cv2.IMREAD_UNCHANGED)[..., 0, None]
111
+ )
112
+ mask[depth > 100.0] = 0.0
113
+ depth[~(mask > 0.5)] = 0.0 # set invalid depth to 0
114
+
115
+ meta_path = os.path.join(data_dir, 'meta.json')
116
+ with open(meta_path, 'r') as f:
117
+ meta = json.load(f)
118
+
119
+ c2w = torch.as_tensor(
120
+ meta['locations'][i]["transform_matrix"],
121
+ dtype=torch.float32,
122
+ )
123
+
124
+ H, W = rgb.shape[:2]
125
+ fovy = meta["camera_angle_x"]
126
+ intrinsic = get_intrinsic_from_fov(fovy, H=H, W=W)
127
+
128
+ return rgb, depth, mask, c2w, intrinsic
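A quick, illustrative sanity check of get_intrinsic_from_fov: for a square render the principal point sits at the image center and the focal length follows 0.5 * H / tan(0.5 * fov). The import path and the 40-degree FOV below are assumptions for the sketch only:

    import numpy as np
    from utils import get_intrinsic_from_fov  # assumes blender/ is on sys.path

    K = get_intrinsic_from_fov(np.deg2rad(40.0), H=768, W=768)
    print(K)                 # 3x3 tensor, fx = fy = 0.5*768/tan(20 deg), cx = cy = 384
    K_batched = get_intrinsic_from_fov(np.deg2rad(40.0), H=768, W=768, bs=4)
    print(K_batched.shape)   # torch.Size([4, 3, 3])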
configs/inference-768-6view.yaml ADDED
@@ -0,0 +1,72 @@
1
+ pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1-unclip'
2
+ revision: null
3
+
4
+ num_views: 7
5
+ with_smpl: false
6
+ validation_dataset:
7
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
8
+ root_dir: 'examples/shhq'
9
+ num_views: ${num_views}
10
+ bg_color: 'white'
11
+ img_wh: [768, 768]
12
+ num_validation_samples: 1000
13
+ crop_size: 740
14
+ margin_size: 50
15
+ smpl_folder: 'smpl_image_pymaf'
16
+
17
+
18
+ save_dir: 'mv_results'
19
+ save_mode: 'rgba' # 'concat', 'rgba', 'rgb'
20
+ seed: 42
21
+ validation_batch_size: 1
22
+ dataloader_num_workers: 1
23
+ local_rank: -1
24
+
25
+ pipe_kwargs:
26
+ num_views: ${num_views}
27
+
28
+ validation_guidance_scales: 3.0
29
+ pipe_validation_kwargs:
30
+ num_inference_steps: 40
31
+ eta: 1.0
32
+
33
+ validation_grid_nrow: ${num_views}
34
+
35
+ unet_from_pretrained_kwargs:
36
+ unclip: true
37
+ sdxl: false
38
+ num_views: ${num_views}
39
+ sample_size: 96
40
+ zero_init_conv_in: false # modify
41
+
42
+ projection_camera_embeddings_input_dim: 2 # 2 for elevation and 6 for focal_length
43
+ zero_init_camera_projection: false
44
+ num_regress_blocks: 3
45
+
46
+ cd_attention_last: false
47
+ cd_attention_mid: false
48
+ multiview_attention: true
49
+ sparse_mv_attention: true
50
+ selfattn_block: self_rowwise
51
+ mvcd_attention: true
52
+
53
+ recon_opt:
54
+ res_path: out
55
+ save_glb: False
56
+ # camera setting
57
+ num_view: 6
58
+ scale: 4
59
+ mode: ortho
60
+ resolution: 1024
61
+ cam_path: 'mvdiffusion/data/six_human_pose'
62
+ # optimization
63
+ iters: 700
64
+ clr_iters: 200
65
+ debug: false
66
+ snapshot_step: 50
67
+ lr_clr: 2e-3
68
+ gpu_id: 0
69
+
70
+ replace_hand: false
71
+
72
+ enable_xformers_memory_efficient_attention: true
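Values such as ${num_views} in this file are OmegaConf interpolations that resolve against the top-level keys. A minimal sketch of loading and resolving the config (paths as in this repo; nothing else assumed):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/inference-768-6view.yaml")
    OmegaConf.resolve(cfg)                    # expand ${num_views} and friends in place
    print(cfg.validation_dataset.num_views)   # 7, pulled from the top-level num_views
    print(cfg.recon_opt.iters)                # 700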
configs/remesh.yaml ADDED
@@ -0,0 +1,18 @@
1
+ res_path: out
2
+ save_glb: False
3
+ imgs_path: examples/debug
4
+ mv_path: ./
5
+ # camera setting
6
+ num_view: 6
7
+ scale: 4
8
+ mode: ortho
9
+ resolution: 1024
10
+ cam_path: 'mvdiffusion/data/six_human_pose'
11
+ # optimization
12
+ iters: 700
13
+ clr_iters: 200
14
+ debug: false
15
+ snapshot_step: 50
16
+ lr_clr: 2e-3
17
+ gpu_id: 0
18
+ replace_hand: false
configs/train-768-6view-onlyscan_face.yaml ADDED
@@ -0,0 +1,145 @@
1
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
2
+ pretrained_unet_path: null
3
+ revision: null
4
+ with_smpl: false
5
+ data_common:
6
+ root_dir: /aifs4su/mmcode/lipeng/human_8view_new/
7
+ predict_relative_views: [0, 1, 2, 4, 6, 7]
8
+ num_validation_samples: 8
9
+ img_wh: [768, 768]
10
+ read_normal: true
11
+ read_color: true
12
+ read_depth: false
13
+ exten: .png
14
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
15
+ object_list:
16
+ - data_lists/human_only_scan.json
17
+ invalid_list:
18
+ -
19
+ train_dataset:
20
+ root_dir: ${data_common.root_dir}
21
+ azi_interval: 45.0
22
+ random_views: 3
23
+ predict_relative_views: ${data_common.predict_relative_views}
24
+ bg_color: three_choices
25
+ object_list: ${data_common.object_list}
26
+ invalid_list: ${data_common.invalid_list}
27
+ img_wh: ${data_common.img_wh}
28
+ validation: false
29
+ num_validation_samples: ${data_common.num_validation_samples}
30
+ read_normal: ${data_common.read_normal}
31
+ read_color: ${data_common.read_color}
32
+ read_depth: ${data_common.read_depth}
33
+ load_cache: false
34
+ exten: ${data_common.exten}
35
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
36
+ side_views_rate: 0.3
37
+ elevation_list: null
38
+ validation_dataset:
39
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
40
+ root_dir: examples/debug
41
+ num_views: ${num_views}
42
+ bg_color: white
43
+ img_wh: ${data_common.img_wh}
44
+ num_validation_samples: 1000
45
+ crop_size: 740
46
+ validation_train_dataset:
47
+ root_dir: ${data_common.root_dir}
48
+ azi_interval: 45.0
49
+ random_views: 3
50
+ predict_relative_views: ${data_common.predict_relative_views}
51
+ bg_color: white
52
+ object_list: ${data_common.object_list}
53
+ invalid_list: ${data_common.invalid_list}
54
+ img_wh: ${data_common.img_wh}
55
+ validation: false
56
+ num_validation_samples: ${data_common.num_validation_samples}
57
+ read_normal: ${data_common.read_normal}
58
+ read_color: ${data_common.read_color}
59
+ read_depth: ${data_common.read_depth}
60
+ num_samples: ${data_common.num_validation_samples}
61
+ load_cache: false
62
+ exten: ${data_common.exten}
63
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
64
+ elevation_list: null
65
+ output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5
66
+ checkpoint_prefix: ../human_checkpoint_backup/
67
+ seed: 42
68
+ train_batch_size: 2
69
+ validation_batch_size: 1
70
+ validation_train_batch_size: 1
71
+ max_train_steps: 30000
72
+ gradient_accumulation_steps: 2
73
+ gradient_checkpointing: true
74
+ learning_rate: 0.0001
75
+ scale_lr: false
76
+ lr_scheduler: piecewise_constant
77
+ step_rules: 1:2000,0.5
78
+ lr_warmup_steps: 10
79
+ snr_gamma: 5.0
80
+ use_8bit_adam: false
81
+ allow_tf32: true
82
+ use_ema: true
83
+ dataloader_num_workers: 32
84
+ adam_beta1: 0.9
85
+ adam_beta2: 0.999
86
+ adam_weight_decay: 0.01
87
+ adam_epsilon: 1.0e-08
88
+ max_grad_norm: 1.0
89
+ prediction_type: null
90
+ logging_dir: logs
91
+ vis_dir: vis
92
+ mixed_precision: fp16
93
+ report_to: wandb
94
+ local_rank: 0
95
+ checkpointing_steps: 2500
96
+ checkpoints_total_limit: 2
97
+ resume_from_checkpoint: latest
98
+ enable_xformers_memory_efficient_attention: true
99
+ validation_steps: 2500 #
100
+ validation_sanity_check: true
101
+ tracker_project_name: PSHuman
102
+ trainable_modules: null
103
+
104
+
105
+ use_classifier_free_guidance: true
106
+ condition_drop_rate: 0.05
107
+ scale_input_latents: true
108
+ regress_elevation: false
109
+ regress_focal_length: false
110
+ elevation_loss_weight: 1.0
111
+ focal_loss_weight: 0.0
112
+ pipe_kwargs:
113
+ num_views: ${num_views}
114
+ pipe_validation_kwargs:
115
+ eta: 1.0
116
+
117
+ unet_from_pretrained_kwargs:
118
+ unclip: true
119
+ num_views: ${num_views}
120
+ sample_size: 96
121
+ zero_init_conv_in: true
122
+ regress_elevation: ${regress_elevation}
123
+ regress_focal_length: ${regress_focal_length}
124
+ num_regress_blocks: 2
125
+ camera_embedding_type: e_de_da_sincos
126
+ projection_camera_embeddings_input_dim: 2
127
+ zero_init_camera_projection: true # modified
128
+ init_mvattn_with_selfattn: false
129
+ cd_attention_last: false
130
+ cd_attention_mid: false
131
+ multiview_attention: true
132
+ sparse_mv_attention: true
133
+ selfattn_block: self_rowwise
134
+ mvcd_attention: true
135
+ addition_downsample: false
136
+ use_face_adapter: false
137
+
138
+ validation_guidance_scales:
139
+ - 3.0
140
+ validation_grid_nrow: ${num_views}
141
+ camera_embedding_lr_mult: 1.0
142
+ plot_pose_acc: false
143
+ num_views: 7
144
+ pred_type: joint
145
+ drop_type: drop_as_a_whole
configs/train-768-6view-onlyscan_face_smplx.yaml ADDED
@@ -0,0 +1,154 @@
1
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-2-1-unclip
2
+ pretrained_unet_path: null
3
+ revision: null
4
+ with_smpl: true
5
+ data_common:
6
+ root_dir: /aifs4su/mmcode/lipeng/human_8view_with_smplx/
7
+ predict_relative_views: [0, 1, 2, 4, 6, 7]
8
+ num_validation_samples: 8
9
+ img_wh: [768, 768]
10
+ read_normal: true
11
+ read_color: true
12
+ read_depth: false
13
+ exten: .png
14
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_7view
15
+ object_list:
16
+ - data_lists/human_only_scan_with_smplx.json # modified
17
+ invalid_list:
18
+ -
19
+ with_smpl: ${with_smpl}
20
+
21
+ train_dataset:
22
+ root_dir: ${data_common.root_dir}
23
+ azi_interval: 45.0
24
+ random_views: 0
25
+ predict_relative_views: ${data_common.predict_relative_views}
26
+ bg_color: three_choices
27
+ object_list: ${data_common.object_list}
28
+ invalid_list: ${data_common.invalid_list}
29
+ img_wh: ${data_common.img_wh}
30
+ validation: false
31
+ num_validation_samples: ${data_common.num_validation_samples}
32
+ read_normal: ${data_common.read_normal}
33
+ read_color: ${data_common.read_color}
34
+ read_depth: ${data_common.read_depth}
35
+ load_cache: false
36
+ exten: ${data_common.exten}
37
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
38
+ side_views_rate: 0.3
39
+ elevation_list: null
40
+ with_smpl: ${with_smpl}
41
+
42
+ validation_dataset:
43
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
44
+ root_dir: examples/debug
45
+ num_views: ${num_views}
46
+ bg_color: white
47
+ img_wh: ${data_common.img_wh}
48
+ num_validation_samples: 1000
49
+ margin_size: 10
50
+ # crop_size: 720
51
+
52
+ validation_train_dataset:
53
+ root_dir: ${data_common.root_dir}
54
+ azi_interval: 45.0
55
+ random_views: 0
56
+ predict_relative_views: ${data_common.predict_relative_views}
57
+ bg_color: white
58
+ object_list: ${data_common.object_list}
59
+ invalid_list: ${data_common.invalid_list}
60
+ img_wh: ${data_common.img_wh}
61
+ validation: false
62
+ num_validation_samples: ${data_common.num_validation_samples}
63
+ read_normal: ${data_common.read_normal}
64
+ read_color: ${data_common.read_color}
65
+ read_depth: ${data_common.read_depth}
66
+ num_samples: ${data_common.num_validation_samples}
67
+ load_cache: false
68
+ exten: ${data_common.exten}
69
+ prompt_embeds_path: ${data_common.prompt_embeds_path}
70
+ elevation_list: null
71
+ with_smpl: ${with_smpl}
72
+
73
+ output_dir: output/unit-unclip-768-6view-onlyscan-onlyortho-faceinself-scale0.5-smplx
74
+ checkpoint_prefix: ../human_checkpoint_backup/
75
+ seed: 42
76
+ train_batch_size: 2
77
+ validation_batch_size: 1
78
+ validation_train_batch_size: 1
79
+ max_train_steps: 30000
80
+ gradient_accumulation_steps: 2
81
+ gradient_checkpointing: true
82
+ learning_rate: 0.0001
83
+ scale_lr: false
84
+ lr_scheduler: piecewise_constant
85
+ step_rules: 1:2000,0.5
86
+ lr_warmup_steps: 10
87
+ snr_gamma: 5.0
88
+ use_8bit_adam: false
89
+ allow_tf32: true
90
+ use_ema: true
91
+ dataloader_num_workers: 32
92
+ adam_beta1: 0.9
93
+ adam_beta2: 0.999
94
+ adam_weight_decay: 0.01
95
+ adam_epsilon: 1.0e-08
96
+ max_grad_norm: 1.0
97
+ prediction_type: null
98
+ logging_dir: logs
99
+ vis_dir: vis
100
+ mixed_precision: fp16
101
+ report_to: wandb
102
+ local_rank: 0
103
+ checkpointing_steps: 5000
104
+ checkpoints_total_limit: 2
105
+ resume_from_checkpoint: latest
106
+ enable_xformers_memory_efficient_attention: true
107
+ validation_steps: 2500 #
108
+ validation_sanity_check: true
109
+ tracker_project_name: PSHuman
110
+ trainable_modules: null
111
+
112
+ use_classifier_free_guidance: true
113
+ condition_drop_rate: 0.05
114
+ scale_input_latents: true
115
+ regress_elevation: false
116
+ regress_focal_length: false
117
+ elevation_loss_weight: 1.0
118
+ focal_loss_weight: 0.0
119
+ pipe_kwargs:
120
+ num_views: ${num_views}
121
+ pipe_validation_kwargs:
122
+ eta: 1.0
123
+
124
+ unet_from_pretrained_kwargs:
125
+ unclip: true
126
+ num_views: ${num_views}
127
+ sample_size: 96
128
+ zero_init_conv_in: true
129
+ regress_elevation: ${regress_elevation}
130
+ regress_focal_length: ${regress_focal_length}
131
+ num_regress_blocks: 2
132
+ camera_embedding_type: e_de_da_sincos
133
+ projection_camera_embeddings_input_dim: 2
134
+ zero_init_camera_projection: true # modified
135
+ init_mvattn_with_selfattn: false
136
+ cd_attention_last: false
137
+ cd_attention_mid: false
138
+ multiview_attention: true
139
+ sparse_mv_attention: true
140
+ selfattn_block: self_rowwise
141
+ mvcd_attention: true
142
+ addition_downsample: false
143
+ use_face_adapter: false
144
+ in_channels: 12
145
+
146
+
147
+ validation_guidance_scales:
148
+ - 3.0
149
+ validation_grid_nrow: ${num_views}
150
+ camera_embedding_lr_mult: 1.0
151
+ plot_pose_acc: false
152
+ num_views: 7
153
+ pred_type: joint
154
+ drop_type: drop_as_a_whole
core/opt.py ADDED
@@ -0,0 +1,197 @@
1
+ from copy import deepcopy
2
+ import time
3
+ import torch
4
+ import torch_scatter
5
+ from core.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
6
+
7
+ @torch.no_grad()
8
+ def remesh(
9
+ vertices_etc:torch.Tensor, #V,D
10
+ faces:torch.Tensor, #F,3 long
11
+ min_edgelen:torch.Tensor, #V
12
+ max_edgelen:torch.Tensor, #V
13
+ flip:bool,
14
+ max_vertices=1e6
15
+ ):
16
+
17
+ # dummies
18
+ vertices_etc,faces = prepend_dummies(vertices_etc,faces)
19
+ vertices = vertices_etc[:,:3] #V,3
20
+ nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
21
+ min_edgelen = torch.concat((nan_tensor,min_edgelen))
22
+ max_edgelen = torch.concat((nan_tensor,max_edgelen))
23
+
24
+ # collapse
25
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
26
+ edge_length = calc_edge_length(vertices,edges) #E
27
+ face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
28
+ vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
29
+ face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
30
+ shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
31
+ priority = face_collapse.float() + shortness
32
+ vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
33
+
34
+ # split
35
+ if vertices.shape[0]<max_vertices:
36
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
37
+ vertices = vertices_etc[:,:3] #V,3
38
+ edge_length = calc_edge_length(vertices,edges) #E
39
+ splits = edge_length > max_edgelen[edges].mean(dim=-1)
40
+ vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
41
+
42
+ vertices_etc,faces = pack(vertices_etc,faces)
43
+ vertices = vertices_etc[:,:3]
44
+
45
+ if flip:
46
+ edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
47
+ flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
48
+
49
+ return remove_dummies(vertices_etc,faces)
50
+
51
+ def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
52
+ """lerp with adam's bias correction"""
53
+ c_prev = 1-weight**(step-1)
54
+ c = 1-weight**step
55
+ a_weight = weight*c_prev/c
56
+ b_weight = (1-weight)/c
57
+ a.mul_(a_weight).add_(b, alpha=b_weight)
58
+
59
+
60
+ class MeshOptimizer:
61
+ """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
62
+
63
+ def __init__(self,
64
+ vertices:torch.Tensor, #V,3
65
+ faces:torch.Tensor, #F,3
66
+ lr=0.3, #learning rate
67
+ betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
68
+ gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
69
+ nu_ref=0.3, #reference velocity for edge length controller
70
+ edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
71
+ edge_len_tol=.5, #edge length tolerance for split and collapse
72
+ gain=.2, #gain value for edge length controller
73
+ laplacian_weight=.02, #for laplacian smoothing/regularization
74
+ ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
75
+ grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
76
+ remesh_interval=1, #larger intervals are faster but with worse mesh quality
77
+ local_edgelen=True, #set to False to use a global scalar reference edge length instead
78
+ remesh_milestones= [500], #list of steps at which to remesh
79
+ # total_steps=1000, #total number of steps
80
+ ):
81
+ self._vertices = vertices
82
+ self._faces = faces
83
+ self._lr = lr
84
+ self._betas = betas
85
+ self._gammas = gammas
86
+ self._nu_ref = nu_ref
87
+ self._edge_len_lims = edge_len_lims
88
+ self._edge_len_tol = edge_len_tol
89
+ self._gain = gain
90
+ self._laplacian_weight = laplacian_weight
91
+ self._ramp = ramp
92
+ self._grad_lim = grad_lim
93
+ # self._remesh_interval = remesh_interval
94
+ # self._remseh_milestones = [ for remesh_milestones]
95
+ self._local_edgelen = local_edgelen
96
+ self._step = 0
97
+ self._start = time.time()
98
+
99
+ V = self._vertices.shape[0]
100
+ # prepare continuous tensor for all vertex-based data
101
+ self._vertices_etc = torch.zeros([V,9],device=vertices.device)
102
+ self._split_vertices_etc()
103
+ self.vertices.copy_(vertices) #initialize vertices
104
+ self._vertices.requires_grad_()
105
+ self._ref_len.fill_(edge_len_lims[1])
106
+
107
+ @property
108
+ def vertices(self):
109
+ return self._vertices
110
+
111
+ @property
112
+ def faces(self):
113
+ return self._faces
114
+
115
+ def _split_vertices_etc(self):
116
+ self._vertices = self._vertices_etc[:,:3]
117
+ self._m2 = self._vertices_etc[:,3]
118
+ self._nu = self._vertices_etc[:,4]
119
+ self._m1 = self._vertices_etc[:,5:8]
120
+ self._ref_len = self._vertices_etc[:,8]
121
+
122
+ with_gammas = any(g!=0 for g in self._gammas)
123
+ self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]
124
+
125
+ def zero_grad(self):
126
+ self._vertices.grad = None
127
+
128
+ @torch.no_grad()
129
+ def step(self):
130
+
131
+ eps = 1e-8
132
+
133
+ self._step += 1
134
+ # spatial smoothing
135
+ edges,_ = calc_edges(self._faces) #E,2
136
+ E = edges.shape[0]
137
+ edge_smooth = self._smooth[edges] #E,2,S
138
+ neighbor_smooth = torch.zeros_like(self._smooth) #V,S
139
+ torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
140
+ #apply optional smoothing of m1,m2,nu
141
+ if self._gammas[0]:
142
+ self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
143
+ if self._gammas[1]:
144
+ self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
145
+ if self._gammas[2]:
146
+ self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
147
+
148
+ #add laplace smoothing to gradients
149
+ laplace = self._vertices - neighbor_smooth[:,:3]
150
+ grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
151
+
152
+ #gradient clipping
153
+ if self._step>1:
154
+ grad_lim = self._m1.abs().mul_(self._grad_lim)
155
+ grad.clamp_(min=-grad_lim,max=grad_lim)
156
+
157
+ # moment updates
158
+ lerp_unbiased(self._m1, grad, self._betas[0], self._step)
159
+ lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
160
+
161
+ velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
162
+ speed = velocity.norm(dim=-1) #V
163
+
164
+ if self._betas[2]:
165
+ lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
166
+ else:
167
+ self._nu.copy_(speed) #V
168
+ # update vertices
169
+ ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
170
+ self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
171
+
172
+ # update target edge length
173
+ if self._step < 500:
174
+ self._remesh_interval = 4
175
+ elif self._step < 800:
176
+ self._remesh_interval = 2
177
+ else:
178
+ self._remesh_interval = 1
179
+
180
+ if self._step % self._remesh_interval == 0:
181
+ if self._local_edgelen:
182
+ len_change = (1 + (self._nu - self._nu_ref) * self._gain)
183
+ else:
184
+ len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
185
+ self._ref_len *= len_change
186
+ self._ref_len.clamp_(*self._edge_len_lims)
187
+
188
+ def remesh(self, flip:bool=True)->tuple[torch.Tensor,torch.Tensor]:
189
+ min_edge_len = self._ref_len * (1 - self._edge_len_tol)
190
+ max_edge_len = self._ref_len * (1 + self._edge_len_tol)
191
+
192
+ self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip)
193
+
194
+ self._split_vertices_etc()
195
+ self._vertices.requires_grad_()
196
+
197
+ return self._vertices, self._faces
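Per the class docstring, MeshOptimizer is driven like a regular PyTorch optimizer with an extra remesh() call after each step. A minimal usage sketch; load_initial_mesh and compute_loss are placeholders standing in for whatever supplies the starting mesh and the differentiable objective:

    import torch
    from core.opt import MeshOptimizer

    verts, faces = load_initial_mesh()           # (V,3) float32 and (F,3) long, on GPU
    opt = MeshOptimizer(verts, faces, edge_len_lims=(0.01, 0.1))
    verts, faces = opt.vertices, opt.faces       # use the internal, grad-enabled views

    for step in range(1000):
        opt.zero_grad()
        loss = compute_loss(verts, faces)        # e.g. a multi-view normal/color error
        loss.backward()
        opt.step()                               # Adam-style update + edge-length controller
        verts, faces = opt.remesh()              # split/collapse/flip; re-fetch the new tensors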
core/remesh.py ADDED
@@ -0,0 +1,359 @@
1
+ import torch
2
+ import torch.nn.functional as tfunc
3
+ import torch_scatter
4
+
5
+ def prepend_dummies(
6
+ vertices:torch.Tensor, #V,D
7
+ faces:torch.Tensor, #F,3 long
8
+ )->tuple[torch.Tensor,torch.Tensor]:
9
+ """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
10
+ V,D = vertices.shape
11
+ vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
12
+ faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
13
+ return vertices,faces
14
+
15
+ def remove_dummies(
16
+ vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
17
+ faces:torch.Tensor, #F,3 long - first face all zeros
18
+ )->tuple[torch.Tensor,torch.Tensor]:
19
+ """remove dummy elements added with prepend_dummies()"""
20
+ return vertices[1:],faces[1:]-1
21
+
22
+
23
+ def calc_edges(
24
+ faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
25
+ with_edge_to_face: bool = False
26
+ ) -> tuple[torch.Tensor, ...]:
27
+ """
28
+ returns tuple of
29
+ - edges E,2 long, 0 for unused, lower vertex index first
30
+ - face_to_edge F,3 long
31
+ - (optional) edge_to_face shape=E,[left,right],[face,side]
32
+
33
+ o-<-----e1 e0,e1...edge, e0<e1
34
+ | /A L,R....left and right face
35
+ | L / | both triangles ordered counter clockwise
36
+ | / R | normals pointing out of screen
37
+ V/ |
38
+ e0---->-o
39
+ """
40
+
41
+ F = faces.shape[0]
42
+
43
+ # make full edges, lower vertex index first
44
+ face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F,3,2
45
+ full_edges = face_edges.reshape(F*3,2)
46
+ sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 TODO min/max faster?
47
+
48
+ # make unique edges
49
+ edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
50
+ E = edges.shape[0]
51
+ face_to_edge = full_to_unique.reshape(F,3) #F,3
52
+
53
+ if not with_edge_to_face:
54
+ return edges, face_to_edge
55
+
56
+ is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
57
+ edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
58
+ scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
59
+ edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
60
+ edge_to_face[0] = 0
61
+ return edges, face_to_edge, edge_to_face
62
+
63
+ def calc_edge_length(
64
+ vertices:torch.Tensor, #V,3 first may be dummy
65
+ edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
66
+ )->torch.Tensor: #E
67
+
68
+ full_vertices = vertices[edges] #E,2,3
69
+ a,b = full_vertices.unbind(dim=1) #E,3
70
+ return torch.norm(a-b,p=2,dim=-1)
71
+
72
+ def calc_face_normals(
73
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
74
+ faces:torch.Tensor, #F,3 long, first face may be all zero
75
+ normalize:bool=False,
76
+ )->torch.Tensor: #F,3
77
+ """
78
+ n
79
+ |
80
+ c0 corners ordered counterclockwise when
81
+ / \ looking onto surface (in neg normal direction)
82
+ c1---c2
83
+ """
84
+ full_vertices = vertices[faces] #F,C=3,3
85
+ v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
86
+ face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
87
+ if normalize:
88
+ face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) #TODO inplace?
89
+ return face_normals #F,3
90
+
91
+ def calc_vertex_normals(
92
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
93
+ faces:torch.Tensor, #F,3 long, first face may be all zero
94
+ face_normals:torch.Tensor=None, #F,3, not normalized
95
+ )->torch.Tensor: #V,3
96
+
97
+ F = faces.shape[0]
98
+
99
+ if face_normals is None:
100
+ face_normals = calc_face_normals(vertices,faces)
101
+
102
+ vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
103
+ vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
104
+ vertex_normals = vertex_normals.sum(dim=1) #V,3
105
+ return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
106
+
107
+ def calc_face_ref_normals(
108
+ faces:torch.Tensor, #F,3 long, 0 for unused
109
+ vertex_normals:torch.Tensor, #V,3 first unused
110
+ normalize:bool=False,
111
+ )->torch.Tensor: #F,3
112
+ """calculate reference normals for face flip detection"""
113
+ full_normals = vertex_normals[faces] #F,C=3,3
114
+ ref_normals = full_normals.sum(dim=1) #F,3
115
+ if normalize:
116
+ ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
117
+ return ref_normals
118
+
119
+ def pack(
120
+ vertices:torch.Tensor, #V,3 first unused and nan
121
+ faces:torch.Tensor, #F,3 long, 0 for unused
122
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
123
+ """removes unused elements in vertices and faces"""
124
+ V = vertices.shape[0]
125
+
126
+ # remove unused faces
127
+ used_faces = faces[:,0]!=0
128
+ used_faces[0] = True
129
+ faces = faces[used_faces] #sync
130
+
131
+ # remove unused vertices
132
+ used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
133
+ used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') #TODO int faster?
134
+ used_vertices = used_vertices.any(dim=1)
135
+ used_vertices[0] = True
136
+ vertices = vertices[used_vertices] #sync
137
+
138
+ # update used faces
139
+ ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
140
+ V1 = used_vertices.sum()
141
+ ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
142
+ faces = ind[faces]
143
+
144
+ return vertices,faces
145
+
146
+ def split_edges(
147
+ vertices:torch.Tensor, #V,3 first unused
148
+ faces:torch.Tensor, #F,3 long, 0 for unused
149
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
150
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
151
+ splits, #E bool
152
+ pack_faces:bool=True,
153
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
154
+
155
+ # c2 c2 c...corners = faces
156
+ # . . . . s...side_vert, 0 means no split
157
+ # . . .N2 . S...shrunk_face
158
+ # . . . . Ni...new_faces
159
+ # s2 s1 s2|c2...s1|c1
160
+ # . . . . .
161
+ # . . . S . .
162
+ # . . . . N1 .
163
+ # c0...(s0=0)....c1 s0|c0...........c1
164
+ #
165
+ # pseudo-code:
166
+ # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
167
+ # split = side_vert!=0 example:[False,True,True]
168
+ # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
169
+ # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
170
+ # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
171
+
172
+ V = vertices.shape[0]
173
+ F = faces.shape[0]
174
+ S = splits.sum().item() #sync
175
+
176
+ if S==0:
177
+ return vertices,faces
178
+
179
+ edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
180
+ edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
181
+ side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
182
+ split_edges = edges[splits] #S sync
183
+
184
+ #vertices
185
+ split_vertices = vertices[split_edges].mean(dim=1) #S,3
186
+ vertices = torch.concat((vertices,split_vertices),dim=0)
187
+
188
+ #faces
189
+ side_split = side_vert!=0 #F,3
190
+ shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
191
+ new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
192
+ faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
193
+ if pack_faces:
194
+ mask = faces[:,0]!=0
195
+ mask[0] = True
196
+ faces = faces[mask] #F',3 sync
197
+
198
+ return vertices,faces
199
+
200
+ def collapse_edges(
201
+ vertices:torch.Tensor, #V,3 first unused
202
+ faces:torch.Tensor, #F,3 long 0 for unused
203
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
204
+ priorities:torch.Tensor, #E float
205
+ stable:bool=False, #only for unit testing
206
+ )->tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
207
+
208
+ V = vertices.shape[0]
209
+
210
+ # check spacing
211
+ _,order = priorities.sort(stable=stable) #E
212
+ rank = torch.zeros_like(order)
213
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
214
+ vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
215
+ edge_rank = rank #E
216
+ for i in range(3):
217
+ torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
218
+ edge_rank,_ = vert_rank[edges].max(dim=-1) #E
219
+ candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
220
+
221
+ # check connectivity
222
+ vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
223
+ vert_connections[candidates[:,0]] = 1 #start
224
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
225
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
226
+ vert_connections[candidates] = 0 #clear start and end
227
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
228
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
229
+ collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
230
+
231
+ # mean vertices
232
+ vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) #TODO dim?
233
+
234
+ # update faces
235
+ dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
236
+ dest[collapses[:,1]] = dest[collapses[:,0]]
237
+ faces = dest[faces] #F,3 TODO optimize?
238
+ c0,c1,c2 = faces.unbind(dim=-1)
239
+ collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
240
+ faces[collapsed] = 0
241
+
242
+ return vertices,faces
243
+
244
+ def calc_face_collapses(
245
+ vertices:torch.Tensor, #V,3 first unused
246
+ faces:torch.Tensor, #F,3 long, 0 for unused
247
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
248
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
249
+ edge_length:torch.Tensor, #E
250
+ face_normals:torch.Tensor, #F,3
251
+ vertex_normals:torch.Tensor, #V,3 first unused
252
+ min_edge_length:torch.Tensor=None, #V
253
+ area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
254
+ shortest_probability = 0.8
255
+ )->torch.Tensor: #E edges to collapse
256
+
257
+ E = edges.shape[0]
258
+ F = faces.shape[0]
259
+
260
+ # face flips
261
+ ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
262
+ face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
263
+
264
+ # small faces
265
+ if min_edge_length is not None:
266
+ min_face_length = min_edge_length[faces].mean(dim=-1) #F
267
+ min_area = min_face_length**2 * area_ratio #F
268
+ face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
269
+ face_collapses[0] = False
270
+
271
+ # faces to edges
272
+ face_length = edge_length[face_to_edge] #F,3
273
+
274
+ if shortest_probability<1:
275
+ #select shortest edge with shortest_probability chance
276
+ randlim = round(2/(1-shortest_probability))
277
+ rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
278
+ sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
279
+ local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
280
+ else:
281
+ local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
282
+
283
+ edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
284
+ edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
285
+ edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) #TODO legal for bool?
286
+
287
+ return edge_collapses.bool()
288
+
289
+ def flip_edges(
290
+ vertices:torch.Tensor, #V,3 first unused
291
+ faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
292
+ edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
293
+ edge_to_face:torch.Tensor, #E,[left,right],[face,side]
294
+ with_border:bool=True, #handle border edges (D=4 instead of D=6)
295
+ with_normal_check:bool=True, #check face normal flips
296
+ stable:bool=False, #only for unit testing
297
+ ):
298
+ V = vertices.shape[0]
299
+ E = edges.shape[0]
300
+ device=vertices.device
301
+ vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
302
+ vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
303
+ neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
304
+ neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
305
+ edge_is_inside = neighbors.all(dim=-1) #E
306
+
307
+ if with_border:
308
+ # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
309
+ # need to use float for masks in order to use scatter(reduce='multiply')
310
+ vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
311
+ src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
312
+ vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
313
+ vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
314
+ vertex_degree -= 2 * vertex_is_inside #V long
315
+
316
+ neighbor_degrees = vertex_degree[neighbors] #E,LR=2
317
+ edge_degrees = vertex_degree[edges] #E,2
318
+ #
319
+ # loss = Sum_over_affected_vertices((new_degree-6)**2)
320
+ # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
321
+ # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
322
+ # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
323
+ #
324
+ loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
325
+ candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
326
+ loss_change = loss_change[candidates] #E'
327
+ if loss_change.shape[0]==0:
328
+ return
329
+
330
+ edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
331
+ _,order = loss_change.sort(descending=True, stable=stable) #E'
332
+ rank = torch.zeros_like(order)
333
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
334
+ vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
335
+ torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
336
+ vertex_rank,_ = vertex_rank.max(dim=-1) #V
337
+ neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
338
+ flip = rank==neighborhood_rank #E'
339
+
340
+ if with_normal_check:
341
+ # cl-<-----e1 e0,e1...edge, e0<e1
342
+ # | /A L,R....left and right face
343
+ # | L / | both triangles ordered counter clockwise
344
+ # | / R | normals pointing out of screen
345
+ # V/ |
346
+ # e0---->-cr
347
+ v = vertices[edges_neighbors] #E",4,3
348
+ v = v - v[:,0:1] #make relative to e0
349
+ e1 = v[:,1]
350
+ cl = v[:,2]
351
+ cr = v[:,3]
352
+ n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
353
+ flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
354
+ flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
355
+
356
+ flip_edges_neighbors = edges_neighbors[flip] #E",4
357
+ flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
358
+ flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
359
+ faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
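A toy check of the shape conventions used throughout this module ((V,3) float vertices, (F,3) long faces), run on a single tetrahedron; torch_scatter must be installed since it is imported at module level:

    import torch
    from core.remesh import calc_edges, calc_face_normals, calc_vertex_normals

    verts = torch.tensor([[0., 0., 0.],
                          [1., 0., 0.],
                          [0., 1., 0.],
                          [0., 0., 1.]])
    faces = torch.tensor([[0, 2, 1],
                          [0, 1, 3],
                          [0, 3, 2],
                          [1, 2, 3]], dtype=torch.long)

    edges, face_to_edge = calc_edges(faces)      # (6,2) unique edges, (4,3) face-to-edge lookup
    fn = calc_face_normals(verts, faces)         # (4,3), unnormalized (length = 2x triangle area)
    vn = calc_vertex_normals(verts, faces, fn)   # (4,3), unit length
    print(edges.shape, face_to_edge.shape, vn.shape)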
econdataset.py ADDED
@@ -0,0 +1,370 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam
18
+ from lib.pixielib.utils.config import cfg as pixie_cfg
19
+ from lib.pixielib.pixie import PIXIE
20
+ import lib.smplx as smplx
21
+ # from lib.pare.pare.core.tester import PARETester
22
+ from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis
23
+ from lib.pymaf.utils.imutils import process_image
24
+ from lib.common.imutils import econ_process_image
25
+ from lib.pymaf.core import path_config
26
+ from lib.pymaf.models import pymaf_net
27
+ from lib.common.config import cfg
28
+ from lib.common.render import Render
29
+ from lib.dataset.body_model import TetraSMPLModel
30
+ from lib.dataset.mesh_util import get_visibility
31
+ from utils.smpl_util import SMPLX
32
+ import os.path as osp
33
+ import os
34
+ import torch
35
+ import numpy as np
36
+ import random
37
+ from termcolor import colored
38
+ from PIL import ImageFile
39
+ from torchvision.models import detection
40
+
41
+
42
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
43
+
44
+
45
+ class SMPLDataset():
46
+
47
+ def __init__(self, cfg, device):
48
+
49
+ random.seed(1993)
50
+
51
+ self.image_dir = cfg['image_dir']
52
+ self.seg_dir = cfg['seg_dir']
53
+ self.hps_type = cfg['hps_type']
54
+ self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
55
+ self.smpl_gender = 'neutral'
56
+ self.colab = cfg['colab']
57
+
58
+ self.device = device
59
+
60
+ keep_lst = [f"{self.image_dir}/{i}" for i in sorted(os.listdir(self.image_dir))]
61
+ img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp']
62
+ keep_lst = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
63
+
64
+ self.subject_list = [item for item in keep_lst if item.split(".")[-1] in img_fmts]
65
+
66
+ if self.colab:
67
+ self.subject_list = [self.subject_list[0]]
68
+
69
+ # smpl related
70
+ self.smpl_data = SMPLX()
71
+
72
+ # smpl-smplx correspondence
73
+ self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
74
+ self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68]
75
+ self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(model_path=self.smpl_data.
76
+ model_dir,
77
+ gender=smpl_gender,
78
+ model_type=smpl_type,
79
+ ext='npz')
80
+
81
+ # Load SMPL model
82
+ self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device)
83
+ self.faces = self.smpl_model.faces
84
+
85
+ if self.hps_type == 'pymaf':
86
+ self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
87
+ self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True)
88
+ self.hps.eval()
89
+
90
+ elif self.hps_type == 'pare':
91
+ self.hps = PARETester(path_config.CFG, path_config.CKPT).model
92
+ elif self.hps_type == 'pixie':
93
+ self.hps = PIXIE(config=pixie_cfg, device=self.device)
94
+ self.smpl_model = self.hps.smplx
95
+ elif self.hps_type == 'hybrik':
96
+ smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
97
+ self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG,
98
+ smpl_path=smpl_path,
99
+ data_path=path_config.hybrik_data_dir)
100
+ self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'),
101
+ strict=False)
102
+ self.hps.to(self.device)
103
+ elif self.hps_type == 'bev':
104
+ try:
105
+ import bev
106
+ except ImportError:
107
+ print('Could not find bev, installing simple-romp==1.0.3 via pip')
108
+ os.system('pip install simple-romp==1.0.3')
109
+ import bev
110
+ settings = bev.main.default_settings
111
+ # change the argparse settings of bev here if you prefer other settings.
112
+ settings.mode = 'image'
113
+ settings.GPU = int(str(self.device).split(':')[1])
114
+ settings.show_largest = True
115
+ # settings.show = True # uncomment this to show the original BEV predictions
116
+ self.hps = bev.BEV(settings)
117
+
118
+ self.detector=detection.maskrcnn_resnet50_fpn(pretrained=True)
119
+ self.detector.eval()
120
+ print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
121
+
122
+ self.render = Render(size=512, device=device)
123
+
124
+ def __len__(self):
125
+ return len(self.subject_list)
126
+
127
+ def compute_vis_cmap(self, smpl_verts, smpl_faces):
128
+
129
+ (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
130
+ smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
131
+ smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type)
132
+
133
+ return {
134
+ 'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
135
+ 'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
136
+ 'smpl_verts': smpl_verts.unsqueeze(0)
137
+ }
138
+
139
+ def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale):
140
+
141
+ smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
142
+ tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz')
143
+ smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
144
+
145
+ pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
146
+ smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])
147
+
148
+ verts = np.concatenate([smpl_model.verts, smpl_model.verts_added],
149
+ axis=0) * scale.item() + trans.detach().cpu().numpy()
150
+ faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'),
151
+ dtype=np.int32) - 1
152
+
153
+ pad_v_num = int(8000 - verts.shape[0])
154
+ pad_f_num = int(25100 - faces.shape[0])
155
+
156
+ verts = np.pad(verts,
157
+ ((0, pad_v_num),
158
+ (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5
159
+ faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant',
160
+ constant_values=0.0).astype(np.int32)
161
+
162
+ verts[:, 2] *= -1.0
163
+
164
+ voxel_dict = {
165
+ 'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
166
+ 'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
167
+ 'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
168
+ 'pad_f_num': torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
169
+ }
170
+
171
+ return voxel_dict
172
+
173
+ def __getitem__(self, index):
174
+
175
+ img_path = self.subject_list[index]
176
+ img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
177
+ print(img_name)
178
+ # smplx_param_path=f'./data/thuman2/smplx/{img_name[:-2]}.pkl'
179
+ # smplx_param = np.load(smplx_param_path, allow_pickle=True)
180
+
181
+ if self.seg_dir is None:
182
+ img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
183
+ img_path, self.hps_type, 512, self.device)
184
+
185
+ data_dict = {
186
+ 'name': img_name,
187
+ 'image': img_icon.to(self.device).unsqueeze(0),
188
+ 'ori_image': img_ori,
189
+ 'mask': img_mask,
190
+ 'uncrop_param': uncrop_param
191
+ }
192
+
193
+ else:
194
+ img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
195
+ img_path,
196
+ self.hps_type,
197
+ 512,
198
+ self.device,
199
+ seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
200
+ data_dict = {
201
+ 'name': img_name,
202
+ 'image': img_icon.to(self.device).unsqueeze(0),
203
+ 'ori_image': img_ori,
204
+ 'mask': img_mask,
205
+ 'uncrop_param': uncrop_param,
206
+ 'segmentations': segmentations
207
+ }
208
+
209
+ arr_dict=econ_process_image(img_path,self.hps_type,True,512,self.detector)
210
+ data_dict['hands_visibility']=arr_dict['hands_visibility']
211
+
212
+ with torch.no_grad():
213
+ # import ipdb; ipdb.set_trace()
214
+ preds_dict = self.hps.forward(img_hps)
215
+
216
+ data_dict['smpl_faces'] = torch.Tensor(self.faces.astype(np.int64)).long().unsqueeze(0).to(
217
+ self.device)
218
+
219
+ if self.hps_type == 'pymaf':
220
+ output = preds_dict['smpl_out'][-1]
221
+ scale, tranX, tranY = output['theta'][0, :3]
222
+ data_dict['betas'] = output['pred_shape']
223
+ data_dict['body_pose'] = output['rotmat'][:, 1:]
224
+ data_dict['global_orient'] = output['rotmat'][:, 0:1]
225
+ data_dict['smpl_verts'] = output['verts'] # unsure whether the scale matches
226
+ data_dict["type"] = "smpl"
227
+
228
+ elif self.hps_type == 'pare':
229
+ data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
230
+ data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
231
+ data_dict['betas'] = preds_dict['pred_shape']
232
+ data_dict['smpl_verts'] = preds_dict['smpl_vertices']
233
+ scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
234
+ data_dict["type"] = "smpl"
235
+
236
+ elif self.hps_type == 'pixie':
237
+ data_dict.update(preds_dict)
238
+ data_dict['body_pose'] = preds_dict['body_pose']
239
+ data_dict['global_orient'] = preds_dict['global_pose']
240
+ data_dict['betas'] = preds_dict['shape']
241
+ data_dict['smpl_verts'] = preds_dict['vertices']
242
+ scale, tranX, tranY = preds_dict['cam'][0, :3]
243
+ data_dict["type"] = "smplx"
244
+
245
+ elif self.hps_type == 'hybrik':
246
+ data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
247
+ data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
248
+ data_dict['betas'] = preds_dict['pred_shape']
249
+ data_dict['smpl_verts'] = preds_dict['pred_vertices']
250
+ scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
251
+ scale = scale * 2
252
+ data_dict["type"] = "smpl"
253
+
254
+ elif self.hps_type == 'bev':
255
+ data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to(
256
+ self.device).float()
257
+ pred_thetas = batch_rodrigues(
258
+ torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
259
+ data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
260
+ data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
261
+ data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to(
262
+ self.device).float()
263
+ tranX = preds_dict['cam_trans'][0, 0]
264
+ tranY = preds_dict['cam'][0, 1] + 0.28
265
+ scale = preds_dict['cam'][0, 0] * 1.1
266
+ data_dict["type"] = "smpl"
267
+
268
+ data_dict['scale'] = scale
269
+ data_dict['trans'] = torch.tensor([tranX, tranY, 0.0]).unsqueeze(0).to(self.device).float()
270
+
271
+ # data_dict info (key-shape):
272
+ # scale, tranX, tranY - tensor.float
273
+ # betas - [1,10] / [1, 200]
274
+ # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
275
+ # global_orient - [1, 1, 3, 3]
276
+ # smpl_verts - [1, 6890, 3] / [1, 10475, 3]
277
+
278
+ # from rot_mat to rot_6d for better optimization
279
+ N_body = data_dict["body_pose"].shape[1]
280
+ data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1)
281
+ data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1)
282
+
283
+ return data_dict
284
+
285
+ def render_normal(self, verts, faces):
286
+
287
+ # render optimized mesh (normal, T_normal, image [-1,1])
288
+ self.render.load_meshes(verts, faces)
289
+ return self.render.get_rgb_image()
290
+
291
+ def render_depth(self, verts, faces):
292
+
293
+ # render optimized mesh (normal, T_normal, image [-1,1])
294
+ self.render.load_meshes(verts, faces)
295
+ return self.render.get_depth_map(cam_ids=[0, 2])
296
+
297
+ def visualize_alignment(self, data):
298
+
299
+ import vedo
300
+ import trimesh
301
+
302
+ if self.hps_type != 'pixie':
303
+ smpl_out = self.smpl_model(betas=data['betas'],
304
+ body_pose=data['body_pose'],
305
+ global_orient=data['global_orient'],
306
+ pose2rot=False)
307
+ smpl_verts = ((smpl_out.vertices + data['trans']) *
308
+ data['scale']).detach().cpu().numpy()[0]
309
+ else:
310
+ smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
311
+ expression_params=data['exp'],
312
+ body_pose=data['body_pose'],
313
+ global_pose=data['global_orient'],
314
+ jaw_pose=data['jaw_pose'],
315
+ left_hand_pose=data['left_hand_pose'],
316
+ right_hand_pose=data['right_hand_pose'])
317
+
318
+ smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0]
319
+
320
+ smpl_verts *= np.array([1.0, -1.0, -1.0])
321
+ faces = data['smpl_faces'][0].detach().cpu().numpy()
322
+
323
+ image_P = data['image']
324
+ image_F, image_B = self.render_normal(smpl_verts, faces)
325
+
326
+ # create plot
327
+ vp = vedo.Plotter(title="", size=(1500, 1500))
328
+ vis_list = []
329
+
330
+ image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
331
+ image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
332
+ image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
333
+
334
+ vis_list.append(
335
+ vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos(
336
+ -1.0, -1.0, 1.0))
337
+ vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5))
338
+ vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0))
339
+
340
+ # create a mesh
341
+ mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
342
+ mesh.visual.vertex_colors = [200, 200, 0]
343
+ vis_list.append(mesh)
344
+
345
+ vp.show(*vis_list, bg="white", axes=1, interactive=True)
346
+
347
+
348
+ if __name__ == '__main__':
349
+
350
+ cfg.merge_from_file("./configs/icon-filter.yaml")
351
+ cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')
352
+
353
+ cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False]
354
+
355
+ cfg.merge_from_list(cfg_show_list)
356
+ cfg.freeze()
357
+
358
+
359
+ device = torch.device('cuda:0')
360
+
361
+ dataset = SMPLDataset(
362
+ {
363
+ 'image_dir': "./examples",
364
+ 'has_det': True, # w/ or w/o detection
365
+ 'hps_type': 'bev' # pymaf/pare/pixie/hybrik/bev
366
+ },
367
+ device)
368
+
369
+ for i in range(len(dataset)):
370
+ dataset.visualize_alignment(dataset[i])
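The __getitem__ above flattens each 3x3 rotation into its first two columns (the 6D representation, which behaves better under gradient-based optimization than axis-angle or quaternions). A minimal, self-contained sketch of the inverse mapping via Gram-Schmidt, written for illustration and independent of the repo's rot6d_to_rotmat:

    import torch
    import torch.nn.functional as F

    def rot6d_to_matrix(x6: torch.Tensor) -> torch.Tensor:
        """x6: (..., 6) taken from R[..., :, :2].reshape(..., 6); returns (..., 3, 3)."""
        m = x6.reshape(*x6.shape[:-1], 3, 2)
        a1, a2 = m[..., 0], m[..., 1]                     # the two stored columns
        b1 = F.normalize(a1, dim=-1)
        b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
        b3 = torch.cross(b1, b2, dim=-1)                  # third column from orthogonality
        return torch.stack([b1, b2, b3], dim=-1)

    # round-trip check on a 90-degree rotation about z
    R = torch.tensor([[0., -1., 0.],
                      [1.,  0., 0.],
                      [0.,  0., 1.]])
    x6 = R[:, :2].reshape(6)                              # same flattening as __getitem__
    print(torch.allclose(rot6d_to_matrix(x6), R, atol=1e-6))  # True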
examples/02986d0998ce01aa0aa67a99fbd1e09a.png ADDED
examples/16171.png ADDED
examples/26d2e846349647ff04c536816e0e8ca1.png ADDED
examples/30755.png ADDED
examples/3930.png ADDED
examples/4656716-3016170581.png ADDED
examples/663dcd6db19490de0b790da430bd5681.png ADDED
Git LFS Details
  • SHA256: b499922b6df6d6874fea68c571ff3271f68aa6bc40420396f4898e5c58d74dc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
examples/7332.png ADDED
examples/85891251f52a2399e660a63c2a7fdf40.png ADDED
examples/a689a48d23d6b8d58d67ff5146c6e088.png ADDED
examples/b0d178743c7e3e09700aaee8d2b1ec47.png ADDED
examples/case5.png ADDED
examples/d40776a1e1582179d97907d36f84d776.png ADDED
examples/durant.png ADDED
examples/eedb9018-e0eb-45be-33bd-5a0108ca0d8b.png ADDED
examples/f14f7d40b72062928461b21c6cc877407e69ee0c_high.png ADDED
examples/f6317ac1b0498f4e6ef9d12bd991a9bd1ff4ae04f898-IQTEBw_fw1200.png ADDED
examples/pexels-barbara-olsen-7869640.png ADDED
examples/pexels-julia-m-cameron-4145040.png ADDED
examples/pexels-marta-wave-6437749.png ADDED
examples/pexels-photo-6311555-removebg.png ADDED
examples/pexels-zdmit-6780091.png ADDED
inference.py ADDED
@@ -0,0 +1,221 @@
1
+ import argparse
2
+ import os
3
+ from typing import Dict, Optional, Tuple, List
4
+ from omegaconf import OmegaConf
5
+ from PIL import Image
6
+ from dataclasses import dataclass
7
+ from collections import defaultdict
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torchvision.utils import make_grid, save_image
11
+ from accelerate.utils import set_seed
12
+ from tqdm.auto import tqdm
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+ from rembg import remove, new_session
16
+ import pdb
17
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
18
+ from econdataset import SMPLDataset
19
+ from reconstruct import ReMesh
20
+ providers = [
21
+ ('CUDAExecutionProvider', {
22
+ 'device_id': 0,
23
+ 'arena_extend_strategy': 'kSameAsRequested',
24
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
25
+ 'cudnn_conv_algo_search': 'HEURISTIC',
26
+ })
27
+ ]
28
+ session = new_session(providers=providers)
29
+
30
+ weight_dtype = torch.float16
31
+ def tensor_to_numpy(tensor):
32
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
33
+
34
+
35
+ @dataclass
36
+ class TestConfig:
37
+ pretrained_model_name_or_path: str
38
+ revision: Optional[str]
39
+ validation_dataset: Dict
40
+ save_dir: str
41
+ seed: Optional[int]
42
+ validation_batch_size: int
43
+ dataloader_num_workers: int
44
+ # save_single_views: bool
45
+ save_mode: str
46
+ local_rank: int
47
+
48
+ pipe_kwargs: Dict
49
+ pipe_validation_kwargs: Dict
50
+ unet_from_pretrained_kwargs: Dict
51
+ validation_guidance_scales: float
52
+ validation_grid_nrow: int
53
+
54
+ num_views: int
55
+ enable_xformers_memory_efficient_attention: bool
56
+ with_smpl: Optional[bool]
57
+
58
+ recon_opt: Dict
59
+
60
+
61
+ def convert_to_numpy(tensor):
62
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
63
+
64
+ def convert_to_pil(tensor):
65
+ return Image.fromarray(convert_to_numpy(tensor))
66
+
67
+ def save_image(tensor, fp):  # note: shadows torchvision.utils.save_image imported above
68
+ ndarr = convert_to_numpy(tensor)
69
+ # pdb.set_trace()
70
+ save_image_numpy(ndarr, fp)
71
+ return ndarr
72
+
73
+ def save_image_numpy(ndarr, fp):
74
+ im = Image.fromarray(ndarr)
75
+ im.save(fp)
76
+
77
+ def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
78
+ pipeline.set_progress_bar_config(disable=True)
79
+
80
+ if cfg.seed is None:
81
+ generator = None
82
+ else:
83
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
84
+
85
+ images_cond, pred_cat = [], defaultdict(list)
86
+ for case_id, batch in tqdm(enumerate(dataloader)):
87
+ images_cond.append(batch['imgs_in'][:, 0])
88
+
89
+ imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
90
+ num_views = imgs_in.shape[1]
91
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
92
+ if cfg.with_smpl:
93
+ smpl_in = torch.cat([batch['smpl_imgs_in']]*2, dim=0)
94
+ smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
95
+ else:
96
+ smpl_in = None
97
+
98
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
99
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
100
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
101
+
102
+ with torch.autocast("cuda"):
103
+ # B*Nv images
104
+ guidance_scale = cfg.validation_guidance_scales
105
+ unet_out = pipeline(
106
+ imgs_in, None, prompt_embeds=prompt_embeddings,
107
+ dino_feature=None, smpl_in=smpl_in,
108
+ generator=generator, guidance_scale=guidance_scale, output_type='pt', num_images_per_prompt=1,
109
+ **cfg.pipe_validation_kwargs
110
+ )
111
+
112
+ out = unet_out.images
113
+ bsz = out.shape[0] // 2
114
+
115
+ normals_pred = out[:bsz]
116
+ images_pred = out[bsz:]
117
+ if cfg.save_mode == 'concat': ## save concatenated color and normal---------------------
118
+ pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1)) # b, 3, h, w
119
+ cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
120
+ os.makedirs(cur_dir, exist_ok=True)
121
+ for i in range(bsz//num_views):
122
+ scene = batch['filename'][i].split('.')[0]
123
+
124
+ img_in_ = images_cond[-1][i].to(out.device)
125
+ vis_ = [img_in_]
126
+ for j in range(num_views):
127
+ idx = i*num_views + j
128
+ normal = normals_pred[idx]
129
+ color = images_pred[idx]
130
+
131
+ vis_.append(color)
132
+ vis_.append(normal)
133
+
134
+ out_filename = f"{cur_dir}/{scene}.png"
135
+ vis_ = torch.stack(vis_, dim=0)
136
+ vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
137
+ save_image(vis_, out_filename)
138
+ elif cfg.save_mode == 'rgb':
139
+ for i in range(bsz//num_views):
140
+ scene = batch['filename'][i].split('.')[0]
141
+
142
+ img_in_ = images_cond[-1][i].to(out.device)
143
+ normals, colors = [], []
144
+ for j in range(num_views):
145
+ idx = i*num_views + j
146
+ normal = normals_pred[idx]
147
+ if j == 0:
148
+ color = imgs_in[0].to(out.device)
149
+ else:
150
+ color = images_pred[idx]
151
+ if j in [3, 4]:
152
+ normal = torch.flip(normal, dims=[2])
153
+ color = torch.flip(color, dims=[2])
154
+
155
+ colors.append(color)
156
+ if j == 6:
157
+ normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
158
+ normals.append(normal)
159
+
160
+ ## save color and normal---------------------
161
+ # normal_filename = f"normals_{view}_masked.png"
162
+ # rgb_filename = f"color_{view}_masked.png"
163
+ # save_image(normal, os.path.join(scene_dir, normal_filename))
164
+ # save_image(color, os.path.join(scene_dir, rgb_filename))
165
+ normals[0][:, :256, 256:512] = normals[-1]
166
+
167
+ colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
168
+ normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
169
+ pose = econdata.__getitem__(case_id)
170
+ carving.optimize_case(scene, pose, colors, normals)
171
+ torch.cuda.empty_cache()
172
+
173
+
174
+
175
+ def load_pshuman_pipeline(cfg):
176
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
177
+ pipeline.unet.enable_xformers_memory_efficient_attention()
178
+ if torch.cuda.is_available():
179
+ pipeline.to('cuda')
180
+ return pipeline
181
+
182
+ def main(
183
+ cfg: TestConfig
184
+ ):
185
+
186
+ # If passed along, set the training seed now.
187
+ if cfg.seed is not None:
188
+ set_seed(cfg.seed)
189
+ pipeline = load_pshuman_pipeline(cfg)
190
+
191
+
192
+ if cfg.with_smpl:
193
+ from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
194
+ else:
195
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset
196
+
197
+ # Get the dataset
198
+ validation_dataset = SingleImageDataset(
199
+ **cfg.validation_dataset
200
+ )
201
+ validation_dataloader = torch.utils.data.DataLoader(
202
+ validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers
203
+ )
204
+ dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
205
+ econdata = SMPLDataset(dataset_param, device='cuda')
206
+
207
+ carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
208
+ run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)
209
+
210
+
211
+ if __name__ == '__main__':
212
+ parser = argparse.ArgumentParser()
213
+ parser.add_argument('--config', type=str, required=True)
214
+ args, extras = parser.parse_known_args()
215
+ from utils.misc import load_config
216
+
217
+ # parse YAML config to OmegaConf
218
+ cfg = load_config(args.config, cli_args=extras)
219
+ schema = OmegaConf.structured(TestConfig)
220
+ cfg = OmegaConf.merge(schema, cfg)
221
+ main(cfg)
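A minimal sketch of the batch layout that run_inference() above expects from the validation dataloader; the view count and tensor sizes are illustrative placeholders (they depend on the dataset config), not values fixed by this file:

    import torch

    batch = {
        "imgs_in": torch.zeros(1, 7, 3, 768, 768),                # B, Nv, C, H, W input views
        "smpl_imgs_in": torch.zeros(1, 7, 3, 768, 768),           # only read when cfg.with_smpl is True
        "normal_prompt_embeddings": torch.zeros(1, 7, 77, 1024),  # B, Nv, N, C embeddings (normal branch)
        "color_prompt_embeddings": torch.zeros(1, 7, 77, 1024),   # B, Nv, N, C embeddings (color branch)
        "filename": ["pexels-barbara-olsen-7869640.png"],         # used to name the per-scene outputs
    }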
lib/__init__.py ADDED
File without changes
lib/common/__init__.py ADDED
File without changes
lib/common/cloth_extraction.py ADDED
@@ -0,0 +1,182 @@
1
+ import numpy as np
2
+ import json
3
+ import os
4
+ import itertools
5
+ import trimesh
6
+ from matplotlib.path import Path
7
+ from collections import Counter
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+
10
+
11
+ def load_segmentation(path, shape):
12
+ """
13
+ Load the clothing segmentation polygons for a given image
14
+ Arguments:
15
+ path: path to the segmentation json file
16
+ shape: shape of the output mask (currently unused)
17
+ Returns:
18
+ Returns a list of dicts, one per clothing item, with its type, type_id and polygon coordinates
19
+ """
20
+ with open(path) as json_file:
21
+ dict = json.load(json_file)
22
+ segmentations = []
23
+ for key, val in dict.items():
24
+ if not key.startswith('item'):
25
+ continue
26
+
27
+ # Each item can have multiple polygons. Combine them to one
28
+ # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
29
+ # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
30
+
31
+ coordinates = []
32
+ for segmentation_coord in val['segmentation']:
33
+ # The format before is [x1,y1, x2, y2, ....]
34
+ x = segmentation_coord[::2]
35
+ y = segmentation_coord[1::2]
36
+ xy = np.vstack((x, y)).T
37
+ coordinates.append(xy)
38
+
39
+ segmentations.append({
40
+ 'type': val['category_name'],
41
+ 'type_id': val['category_id'],
42
+ 'coordinates': coordinates
43
+ })
44
+
45
+ return segmentations
46
+
47
+
48
+ def smpl_to_recon_labels(recon, smpl, k=1):
49
+ """
50
+ Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
51
+ Arguments:
52
+ recon: trimesh object (fully clothed model)
53
+ smpl: trimesh object (smpl model)
54
+ k: number of nearest neighbours to use
55
+ Returns:
56
+ Returns a dictionary containing the bodypart and the corresponding indices
57
+ """
58
+ smpl_vert_segmentation = json.load(
59
+ open(
60
+ os.path.join(os.path.dirname(__file__),
61
+ 'smpl_vert_segmentation.json')))
62
+ n = smpl.vertices.shape[0]
63
+ y = np.array([None] * n)
64
+ for key, val in smpl_vert_segmentation.items():
65
+ y[val] = key
66
+
67
+ classifier = KNeighborsClassifier(n_neighbors=k)
68
+ classifier.fit(smpl.vertices, y)
69
+
70
+ y_pred = classifier.predict(recon.vertices)
71
+
72
+ recon_labels = {}
73
+ for key in smpl_vert_segmentation.keys():
74
+ recon_labels[key] = list(
75
+ np.argwhere(y_pred == key).flatten().astype(int))
76
+
77
+ return recon_labels
78
+
79
+
80
+ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
81
+ """
82
+ Extract a portion of a mesh using 2d segmentation coordinates
83
+ Arguments:
84
+ recon: fully clothed mesh
85
+ seg_coord: segmentation coordinates in 2D (NDC)
86
+ K: intrinsic matrix of the projection
87
+ R: rotation matrix of the projection
88
+ t: translation vector of the projection
89
+ Returns:
90
+ Returns a submesh using the segmentation coordinates
91
+ """
92
+ seg_coord = segmentation['coord_normalized']
93
+ mesh = trimesh.Trimesh(recon.vertices, recon.faces)
94
+ extrinsic = np.zeros((3, 4))
95
+ extrinsic[:3, :3] = R
96
+ extrinsic[:, 3] = t
97
+ P = K[:3, :3] @ extrinsic
98
+
99
+ P_inv = np.linalg.pinv(P)
100
+
101
+ # Each segmentation can contain multiple polygons
102
+ # We need to check them separately
103
+ points_so_far = []
104
+ faces = recon.faces
105
+ for polygon in seg_coord:
106
+ n = len(polygon)
107
+ coords_h = np.hstack((polygon, np.ones((n, 1))))
108
+ # Apply the inverse projection to the homogeneous 2D coordinates to get the corresponding 3D coordinates
109
+ XYZ = P_inv @ coords_h[:, :, None]
110
+ XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
111
+ XYZ = XYZ[:, :3] / XYZ[:, 3, None]
112
+
113
+ p = Path(XYZ[:, :2])
114
+
115
+ grid = p.contains_points(recon.vertices[:, :2])
116
+ indices = np.argwhere(grid)
117
+ points_so_far += list(indices.flatten())
118
+
119
+ if smpl is not None:
120
+ num_verts = recon.vertices.shape[0]
121
+ recon_labels = smpl_to_recon_labels(recon, smpl)
122
+ body_parts_to_remove = [
123
+ 'rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
124
+ 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand',
125
+ 'rightHand'
126
+ ]
127
+ type = segmentation['type_id']
128
+
129
+ # Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
130
+ # https://github.com/switchablenorms/DeepFashion2
131
+ # Short sleeve clothes
132
+ if type == 1 or type == 3 or type == 10:
133
+ body_parts_to_remove += ['leftForeArm', 'rightForeArm']
134
+ # No sleeves at all or lower body clothes
135
+ elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
136
+ body_parts_to_remove += [
137
+ 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm'
138
+ ]
139
+ # Shorts
140
+ elif type == 7:
141
+ body_parts_to_remove += [
142
+ 'leftLeg', 'rightLeg', 'leftForeArm', 'rightForeArm',
143
+ 'leftArm', 'rightArm'
144
+ ]
145
+
146
+ verts_to_remove = list(
147
+ itertools.chain.from_iterable(
148
+ [recon_labels[part] for part in body_parts_to_remove]))
149
+
150
+ label_mask = np.zeros(num_verts, dtype=bool)
151
+ label_mask[verts_to_remove] = True
152
+
153
+ seg_mask = np.zeros(num_verts, dtype=bool)
154
+ seg_mask[points_so_far] = True
155
+
156
+ # Remove points that belong to other bodyparts
157
+ # If a vertex in points_so_far also belongs to a body part to remove, it should be dropped
158
+ extra_verts_to_remove = np.logical_and(seg_mask, label_mask)
159
+
160
+ combine_mask = np.zeros(num_verts, dtype=bool)
161
+ combine_mask[points_so_far] = True
162
+ combine_mask[extra_verts_to_remove] = False
163
+
164
+ all_indices = np.argwhere(combine_mask == True).flatten()
165
+
166
+ i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
167
+ i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
168
+ i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
169
+
170
+ faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
171
+ mask = np.zeros(len(recon.faces), dtype=bool)
172
+ if len(faces_to_keep) > 0:
173
+ mask[faces_to_keep] = True
174
+
175
+ mesh.update_faces(mask)
176
+ mesh.remove_unreferenced_vertices()
177
+
178
+ # mesh.rezero()
179
+
180
+ return mesh
181
+
182
+ return None
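A usage sketch for the helpers above. Note that extract_cloth() reads segmentation['coord_normalized'], i.e. it expects polygon coordinates already normalized to NDC, while load_segmentation() stores raw pixel polygons under 'coordinates'; the paths, camera matrices and polygon below are placeholders:

    import numpy as np
    import trimesh
    from lib.common.cloth_extraction import extract_cloth, smpl_to_recon_labels

    recon = trimesh.load("results/subject_recon.obj", process=False)  # clothed reconstruction (placeholder path)
    smpl = trimesh.load("results/subject_smpl.obj", process=False)    # fitted SMPL mesh (placeholder path)

    labels = smpl_to_recon_labels(recon, smpl)  # {body part name: list of recon vertex indices}

    segmentation = {
        "type": "short sleeve top",  # illustrative DeepFashion2 category
        "type_id": 1,
        "coord_normalized": [np.array([[-0.4, -0.6], [0.4, -0.6], [0.4, 0.3], [-0.4, 0.3]])],
    }
    K = np.eye(3)                  # placeholder intrinsics
    R, t = np.eye(3), np.zeros(3)  # placeholder extrinsics
    cloth = extract_cloth(recon, segmentation, K, R, t, smpl=smpl)
    if cloth is not None:
        cloth.export("results/subject_cloth.obj")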
lib/common/config.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from yacs.config import CfgNode as CN
18
+ import os
19
+
20
+ _C = CN(new_allowed=True)
21
+
22
+ # needed by trainer
23
+ _C.name = 'default'
24
+ _C.gpus = [0]
25
+ _C.test_gpus = [1]
26
+ _C.root = "./data/"
27
+ _C.ckpt_dir = './data/ckpt/'
28
+ _C.resume_path = ''
29
+ _C.normal_path = ''
30
+ _C.corr_path = ''
31
+ _C.results_path = './data/results/'
32
+ _C.projection_mode = 'orthogonal'
33
+ _C.num_views = 1
34
+ _C.sdf = False
35
+ _C.sdf_clip = 5.0
36
+
37
+ _C.lr_G = 1e-3
38
+ _C.lr_C = 1e-3
39
+ _C.lr_N = 2e-4
40
+ _C.weight_decay = 0.0
41
+ _C.momentum = 0.0
42
+ _C.optim = 'Adam'
43
+ _C.schedule = [5, 10, 15]
44
+ _C.gamma = 0.1
45
+
46
+ _C.overfit = False
47
+ _C.resume = False
48
+ _C.test_mode = False
49
+ _C.test_uv = False
50
+ _C.draw_geo_thres = 0.60
51
+ _C.num_sanity_val_steps = 2
52
+ _C.fast_dev = 0
53
+ _C.get_fit = False
54
+ _C.agora = False
55
+ _C.optim_cloth = False
56
+ _C.optim_body = False
57
+ _C.mcube_res = 256
58
+ _C.clean_mesh = True
59
+ _C.remesh = False
60
+
61
+ _C.batch_size = 4
62
+ _C.num_threads = 8
63
+
64
+ _C.num_epoch = 10
65
+ _C.freq_plot = 0.01
66
+ _C.freq_show_train = 0.1
67
+ _C.freq_show_val = 0.2
68
+ _C.freq_eval = 0.5
69
+ _C.accu_grad_batch = 4
70
+
71
+ _C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']
72
+
73
+ _C.net = CN()
74
+ _C.net.gtype = 'HGPIFuNet'
75
+ _C.net.ctype = 'resnet18'
76
+ _C.net.classifierIMF = 'MultiSegClassifier'
77
+ _C.net.netIMF = 'resnet18'
78
+ _C.net.norm = 'group'
79
+ _C.net.norm_mlp = 'group'
80
+ _C.net.norm_color = 'group'
81
+ _C.net.hg_down = 'conv128' #'ave_pool'
82
+ _C.net.num_views = 1
83
+
84
+ # kernel_size, stride, dilation, padding
85
+
86
+ _C.net.conv1 = [7, 2, 1, 3]
87
+ _C.net.conv3x3 = [3, 1, 1, 1]
88
+
89
+ _C.net.num_stack = 4
90
+ _C.net.num_hourglass = 2
91
+ _C.net.hourglass_dim = 256
92
+ _C.net.voxel_dim = 32
93
+ _C.net.resnet_dim = 120
94
+ _C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
95
+ _C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
96
+ _C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
97
+ _C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
98
+ _C.net.res_layers = [2, 3, 4]
99
+ _C.net.filter_dim = 256
100
+ _C.net.smpl_dim = 3
101
+
102
+ _C.net.cly_dim = 3
103
+ _C.net.soft_dim = 64
104
+ _C.net.z_size = 200.0
105
+ _C.net.N_freqs = 10
106
+ _C.net.geo_w = 0.1
107
+ _C.net.norm_w = 0.1
108
+ _C.net.dc_w = 0.1
109
+ _C.net.C_cat_to_G = False
110
+
111
+ _C.net.skip_hourglass = True
112
+ _C.net.use_tanh = False
113
+ _C.net.soft_onehot = True
114
+ _C.net.no_residual = False
115
+ _C.net.use_attention = False
116
+
117
+ _C.net.prior_type = "sdf"
118
+ _C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
119
+ _C.net.use_filter = True
120
+ _C.net.use_cc = False
121
+ _C.net.use_PE = False
122
+ _C.net.use_IGR = False
123
+ _C.net.in_geo = ()
124
+ _C.net.in_nml = ()
125
+
126
+ _C.dataset = CN()
127
+ _C.dataset.root = ''
128
+ _C.dataset.set_splits = [0.95, 0.04]
129
+ _C.dataset.types = [
130
+ "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
131
+ ]
132
+ _C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
133
+ _C.dataset.rp_type = "pifu900"
134
+ _C.dataset.th_type = 'train'
135
+ _C.dataset.input_size = 512
136
+ _C.dataset.rotation_num = 3
137
+ _C.dataset.num_sample_ray=128 # volume rendering
138
+ _C.dataset.num_precomp = 10 # Number of segmentation classifiers
139
+ _C.dataset.num_multiseg = 500 # Number of categories per classifier
140
+ _C.dataset.num_knn = 10 # for loss/error
141
+ _C.dataset.num_knn_dis = 20 # for accuracy
142
+ _C.dataset.num_verts_max = 20000
143
+ _C.dataset.zray_type = False
144
+ _C.dataset.online_smpl = False
145
+ _C.dataset.noise_type = ['z-trans', 'pose', 'beta']
146
+ _C.dataset.noise_scale = [0.0, 0.0, 0.0]
147
+ _C.dataset.num_sample_geo = 10000
148
+ _C.dataset.num_sample_color = 0
149
+ _C.dataset.num_sample_seg = 0
150
+ _C.dataset.num_sample_knn = 10000
151
+
152
+ _C.dataset.sigma_geo = 5.0
153
+ _C.dataset.sigma_color = 0.10
154
+ _C.dataset.sigma_seg = 0.10
155
+ _C.dataset.thickness_threshold = 20.0
156
+ _C.dataset.ray_sample_num = 2
157
+ _C.dataset.semantic_p = False
158
+ _C.dataset.remove_outlier = False
159
+
160
+ _C.dataset.train_bsize = 1.0
161
+ _C.dataset.val_bsize = 1.0
162
+ _C.dataset.test_bsize = 1.0
163
+
164
+
165
+ def get_cfg_defaults():
166
+ """Get a yacs CfgNode object with default values for my_project."""
167
+ # Return a clone so that the defaults will not be altered
168
+ # This is for the "local variable" use pattern
169
+ return _C.clone()
170
+
171
+
172
+ # Alternatively, provide a way to import the defaults as
173
+ # a global singleton:
174
+ cfg = _C # users can `from config import cfg`
175
+
176
+ # cfg = get_cfg_defaults()
177
+ # cfg.merge_from_file('./configs/example.yaml')
178
+
179
+ # # Now override from a list (opts could come from the command line)
180
+ # opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
181
+ # cfg.merge_from_list(opts)
182
+
183
+
184
+ def update_cfg(cfg_file):
185
+ # cfg = get_cfg_defaults()
186
+ _C.merge_from_file(cfg_file)
187
+ # return cfg.clone()
188
+ return _C
189
+
190
+
191
+ def parse_args(args):
192
+ cfg_file = args.cfg_file
193
+ if args.cfg_file is not None:
194
+ cfg = update_cfg(args.cfg_file)
195
+ else:
196
+ cfg = get_cfg_defaults()
197
+
198
+ # if args.misc is not None:
199
+ # cfg.merge_from_list(args.misc)
200
+
201
+ return cfg
202
+
203
+
204
+ def parse_args_extend(args):
205
+ if args.resume:
206
+ if not os.path.exists(args.log_dir):
207
+ raise ValueError(
208
+ 'Experiment are set to resume mode, but log directory does not exist.'
209
+ )
210
+
211
+ # load log's cfg
212
+ cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
213
+ cfg = update_cfg(cfg_file)
214
+
215
+ if args.misc is not None:
216
+ cfg.merge_from_list(args.misc)
217
+ else:
218
+ parse_args(args)
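For reference, the rest of the repo consumes this as a read-only singleton: merge a YAML file and a list of overrides on top of the defaults, then freeze the node (this mirrors the pattern used by the visualization entry point further up):

    from lib.common.config import cfg

    cfg.merge_from_file("./configs/icon-filter.yaml")
    cfg.merge_from_list(["mcube_res", 512, "clean_mesh", False])
    cfg.freeze()
    print(cfg.mcube_res, cfg.net.prior_type)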
lib/common/imutils.py ADDED
@@ -0,0 +1,364 @@
1
+ import os
2
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
3
+ import cv2
4
+ import mediapipe as mp
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ from lib.pymafx.core import constants
10
+ from rembg import remove
11
+ # from rembg.session_factory import new_session
12
+ from torchvision import transforms
13
+ from kornia.geometry.transform import get_affine_matrix2d, warp_affine
14
+
15
+
16
+ def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
17
+ all_ops = []
18
+ if res is not None:
19
+ all_ops.append(transforms.Resize(size=res))
20
+ if not is_tensor:
21
+ all_ops.append(transforms.ToTensor())
22
+ if mean is not None and std is not None:
23
+ all_ops.append(transforms.Normalize(mean=mean, std=std))
24
+ return transforms.Compose(all_ops)
25
+
26
+
27
+ def get_affine_matrix_wh(w1, h1, w2, h2):
28
+
29
+ transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)
30
+ center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
31
+ scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0)
32
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
33
+
34
+ return M
35
+
36
+
37
+ def get_affine_matrix_box(boxes, w2, h2):
38
+
39
+ # boxes [left, top, right, bottom]
40
+ width = boxes[:, 2] - boxes[:, 0] #(N,)
41
+ height = boxes[:, 3] - boxes[:, 1] #(N,)
42
+ center = torch.tensor(
43
+ [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
44
+ ).T #(N,2)
45
+ scale = torch.min(torch.tensor([w2 / width, h2 / height]),
46
+ dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2)
47
+ transl = torch.cat([w2 / 2.0 - center[:, 0:1], h2 / 2.0 - center[:, 1:2]], dim=1) #(N,2)
48
+ M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.,]*transl.shape[0]))
49
+
50
+ return M
51
+
52
+
53
+ def load_img(img_file):
54
+
55
+ if img_file.endswith("exr"):
56
+ img = cv2.imread(img_file, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
57
+ else :
58
+ img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
59
+
60
+ # considering non 8-bit image
61
+ if img.dtype != np.uint8 :
62
+ img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
63
+
64
+ if len(img.shape) == 2:
65
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
66
+
67
+ if not img_file.endswith("png"):
68
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69
+ else:
70
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
71
+
72
+ return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]
73
+
74
+
75
+ def get_keypoints(image):
76
+ def collect_xyv(x, body=True):
77
+ lmk = x.landmark
78
+ all_lmks = []
79
+ for i in range(len(lmk)):
80
+ visibility = lmk[i].visibility if body else 1.0
81
+ all_lmks.append(torch.Tensor([lmk[i].x, lmk[i].y, lmk[i].z, visibility]))
82
+ return torch.stack(all_lmks).view(-1, 4)
83
+
84
+ mp_holistic = mp.solutions.holistic
85
+
86
+ with mp_holistic.Holistic(
87
+ static_image_mode=True,
88
+ model_complexity=2,
89
+ ) as holistic:
90
+ results = holistic.process(image)
91
+
92
+ fake_kps = torch.zeros(33, 4)
93
+
94
+ result = {}
95
+ result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
96
+ result["lhand"] = collect_xyv(
97
+ results.left_hand_landmarks, False
98
+ ) if results.left_hand_landmarks else fake_kps
99
+ result["rhand"] = collect_xyv(
100
+ results.right_hand_landmarks, False
101
+ ) if results.right_hand_landmarks else fake_kps
102
+ result["face"] = collect_xyv(
103
+ results.face_landmarks, False
104
+ ) if results.face_landmarks else fake_kps
105
+
106
+ return result
107
+
108
+
109
+ def get_pymafx(image, landmarks):
110
+
111
+ # image [3,512,512]
112
+
113
+ item = {
114
+ 'img_body':
115
+ F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
116
+ }
117
+
118
+ for part in ['lhand', 'rhand', 'face']:
119
+ kp2d = landmarks[part]
120
+ kp2d_valid = kp2d[kp2d[:, 3] > 0.]
121
+ if len(kp2d_valid) > 0:
122
+ bbox = [
123
+ min(kp2d_valid[:, 0]),
124
+ min(kp2d_valid[:, 1]),
125
+ max(kp2d_valid[:, 0]),
126
+ max(kp2d_valid[:, 1])
127
+ ]
128
+ center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
129
+ scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
130
+
131
+ # handle invalid part keypoints
132
+ if len(kp2d_valid) < 1 or scale_part < 0.01:
133
+ center_part = [0, 0]
134
+ scale_part = 0.5
135
+ kp2d[:, 3] = 0
136
+
137
+ center_part = torch.tensor(center_part).float()
138
+
139
+ theta_part = torch.zeros(1, 2, 3)
140
+ theta_part[:, 0, 0] = scale_part
141
+ theta_part[:, 1, 1] = scale_part
142
+ theta_part[:, :, -1] = center_part
143
+
144
+ grid = F.affine_grid(theta_part, torch.Size([1, 3, 224, 224]), align_corners=False)
145
+ img_part = F.grid_sample(image.unsqueeze(0), grid, align_corners=False).squeeze(0).float()
146
+
147
+ item[f'img_{part}'] = img_part
148
+
149
+ theta_i_inv = torch.zeros_like(theta_part)
150
+ theta_i_inv[:, 0, 0] = 1. / theta_part[:, 0, 0]
151
+ theta_i_inv[:, 1, 1] = 1. / theta_part[:, 1, 1]
152
+ theta_i_inv[:, :, -1] = -theta_part[:, :, -1] / theta_part[:, 0, 0].unsqueeze(-1)
153
+ item[f'{part}_theta_inv'] = theta_i_inv[0]
154
+
155
+ return item
156
+
157
+
158
+ def remove_floats(mask):
159
+
160
+ # 1. find all the contours
161
+ # 2. fillPoly "True" for the largest one
162
+ # 3. fillPoly "False" for its childrens
163
+
164
+ new_mask = np.zeros(mask.shape)
165
+ cnts, hier = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
166
+ cnt_index = sorted(range(len(cnts)), key=lambda k: cv2.contourArea(cnts[k]), reverse=True)
167
+ body_cnt = cnts[cnt_index[0]]
168
+ childs_cnt_idx = np.where(np.array(hier)[0, :, -1] == cnt_index[0])[0]
169
+ childs_cnt = [cnts[idx] for idx in childs_cnt_idx]
170
+ cv2.fillPoly(new_mask, [body_cnt], 1)
171
+ cv2.fillPoly(new_mask, childs_cnt, 0)
172
+
173
+ return new_mask
174
+
175
+
176
+ def econ_process_image(img_file, hps_type, single, input_res, detector):
177
+
178
+ img_raw, (in_height, in_width) = load_img(img_file)
179
+ tgt_res = input_res * 2
180
+ M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
181
+ img_square = warp_affine(
182
+ img_raw,
183
+ M_square[:, :2], (tgt_res, ) * 2,
184
+ mode='bilinear',
185
+ padding_mode='zeros',
186
+ align_corners=True
187
+ )
188
+
189
+ # detection for bbox
190
+ predictions = detector(img_square / 255.)[0]
191
+
192
+ if single:
193
+ top_score = predictions["scores"][predictions["labels"] == 1].max()
194
+ human_ids = torch.where(predictions["scores"] == top_score)[0]
195
+ else:
196
+ human_ids = torch.logical_and(predictions["labels"] == 1,
197
+ predictions["scores"] > 0.9).nonzero().squeeze(1)
198
+
199
+ boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
200
+ masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
201
+
202
+ M_crop = get_affine_matrix_box(boxes, input_res, input_res)
203
+
204
+ img_icon_lst = []
205
+ img_crop_lst = []
206
+ img_hps_lst = []
207
+ img_mask_lst = []
208
+ landmark_lst = []
209
+ hands_visibility_lst = []
210
+ img_pymafx_lst = []
211
+
212
+ uncrop_param = {
213
+ "ori_shape": [in_height, in_width],
214
+ "box_shape": [input_res, input_res],
215
+ "square_shape": [tgt_res, tgt_res],
216
+ "M_square": M_square,
217
+ "M_crop": M_crop
218
+ }
219
+
220
+ for idx in range(len(boxes)):
221
+
222
+ # mask out the pixels of others
223
+ if len(masks) > 1:
224
+ mask_detection = (masks[np.arange(len(masks)) != idx]).max(axis=0)
225
+ else:
226
+ mask_detection = masks[0] * 0.
227
+
228
+ img_square_rgba = torch.cat(
229
+ [img_square.squeeze(0).permute(1, 2, 0),
230
+ torch.tensor(mask_detection < 0.4) * 255],
231
+ dim=2
232
+ )
233
+
234
+ img_crop = warp_affine(
235
+ img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
236
+ M_crop[idx:idx + 1, :2], (input_res, ) * 2,
237
+ mode='bilinear',
238
+ padding_mode='zeros',
239
+ align_corners=True
240
+ ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
241
+
242
+ # get accurate person segmentation mask
243
+ img_rembg = remove(img_crop)  # optionally: remove(img_crop, post_process_mask=True)
244
+ img_mask = remove_floats(img_rembg[:, :, [3]])
245
+
246
+ mean_icon = std_icon = (0.5, 0.5, 0.5)
247
+ img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
248
+ img_icon = transform_to_tensor(512, mean_icon, std_icon)(
249
+ Image.fromarray(img_np)
250
+ ) * torch.tensor(img_mask).permute(2, 0, 1)
251
+ img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN,
252
+ constants.IMG_NORM_STD)(Image.fromarray(img_np))
253
+
254
+ landmarks = get_keypoints(img_np)
255
+
256
+ # get hands visibility
257
+ hands_visibility = [True, True]
258
+ if landmarks['lhand'][:, -1].mean() == 0.:
259
+ hands_visibility[0] = False
260
+ if landmarks['rhand'][:, -1].mean() == 0.:
261
+ hands_visibility[1] = False
262
+ hands_visibility_lst.append(hands_visibility)
263
+
264
+ if hps_type == 'pymafx':
265
+ img_pymafx_lst.append(
266
+ get_pymafx(
267
+ transform_to_tensor(512, constants.IMG_NORM_MEAN,
268
+ constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks
269
+ )
270
+ )
271
+
272
+ img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
273
+ img_icon_lst.append(img_icon)
274
+ img_hps_lst.append(img_hps)
275
+ img_mask_lst.append(torch.tensor(img_mask[..., 0]))
276
+ landmark_lst.append(landmarks['body'])
277
+
278
+ # required image tensors / arrays
279
+
280
+ # img_icon (tensor): (-1, 1), [3,512,512]
281
+ # img_hps (tensor): (-2.11, 2.44), [3,224,224]
282
+
283
+ # img_np (array): (0, 255), [512,512,3]
284
+ # img_rembg (array): (0, 255), [512,512,4]
285
+ # img_mask (array): (0, 1), [512,512,1]
286
+ # img_crop (array): (0, 255), [512,512,4]
287
+
288
+ return_dict = {
289
+ "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res]
290
+ "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res]
291
+ "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res]
292
+ "img_raw": img_raw, #[1, 3, H, W]
293
+ "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res]
294
+ "uncrop_param": uncrop_param,
295
+ "landmark": torch.stack(landmark_lst), #[N, 33, 4]
296
+ "hands_visibility": hands_visibility_lst,
297
+ }
298
+
299
+ img_pymafx = {}
300
+
301
+ if len(img_pymafx_lst) > 0:
302
+ for idx in range(len(img_pymafx_lst)):
303
+ for key in img_pymafx_lst[idx].keys():
304
+ if key not in img_pymafx.keys():
305
+ img_pymafx[key] = [img_pymafx_lst[idx][key]]
306
+ else:
307
+ img_pymafx[key] += [img_pymafx_lst[idx][key]]
308
+
309
+ for key in img_pymafx.keys():
310
+ img_pymafx[key] = torch.stack(img_pymafx[key]).float()
311
+
312
+ return_dict.update({"img_pymafx": img_pymafx})
313
+
314
+ return return_dict
315
+
316
+
317
+ def blend_rgb_norm(norms, data):
318
+
319
+ # norms [N, 3, res, res]
320
+ masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
321
+ norm_mask = F.interpolate(
322
+ torch.cat([norms, masks], dim=1).detach(),
323
+ size=data["uncrop_param"]["box_shape"],
324
+ mode="bilinear",
325
+ align_corners=False
326
+ )
327
+ final = data["img_raw"].type_as(norm_mask)
328
+
329
+ for idx in range(len(norms)):
330
+
331
+ norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0
332
+ mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1)
333
+
334
+ norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
335
+ mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
336
+
337
+ final = final * (1.0 - mask_ori) + norm_ori * mask_ori
338
+
339
+ return final.detach().cpu()
340
+
341
+
342
+ def unwrap(image, uncrop_param, idx):
343
+
344
+ device = image.device
345
+
346
+ img_square = warp_affine(
347
+ image,
348
+ torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device),
349
+ uncrop_param["square_shape"],
350
+ mode='bilinear',
351
+ padding_mode='zeros',
352
+ align_corners=True
353
+ )
354
+
355
+ img_ori = warp_affine(
356
+ img_square,
357
+ torch.inverse(uncrop_param["M_square"])[:, :2].to(device),
358
+ uncrop_param["ori_shape"],
359
+ mode='bilinear',
360
+ padding_mode='zeros',
361
+ align_corners=True
362
+ )
363
+
364
+ return img_ori
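A usage sketch for econ_process_image(). The detector is assumed here to be a torchvision instance-segmentation model, since the predictions dict is indexed with the torchvision detection keys ('boxes', 'labels', 'scores', 'masks'); the image path points at one of the example images added in this commit:

    import torch
    from torchvision.models.detection import maskrcnn_resnet50_fpn
    from lib.common.imutils import econ_process_image

    detector = maskrcnn_resnet50_fpn(weights="DEFAULT").eval()
    with torch.no_grad():
        data = econ_process_image(
            "./examples/durant.png",
            hps_type="pixie",   # same hps_type the inference entry point uses
            single=True,        # keep only the highest-scoring person
            input_res=512,
            detector=detector,
        )
    print(data["img_icon"].shape)  # [N, 3, 512, 512], normalized to (-1, 1)
    print(data["landmark"].shape)  # [N, 33, 4] mediapipe body landmarks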
lib/common/render.py ADDED
@@ -0,0 +1,398 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ from pytorch3d.renderer import (
18
+ BlendParams,
19
+ blending,
20
+ look_at_view_transform,
21
+ FoVOrthographicCameras,
22
+ PointLights,
23
+ RasterizationSettings,
24
+ PointsRasterizationSettings,
25
+ PointsRenderer,
26
+ AlphaCompositor,
27
+ PointsRasterizer,
28
+ MeshRenderer,
29
+ MeshRasterizer,
30
+ SoftPhongShader,
31
+ SoftSilhouetteShader,
32
+ TexturesVertex,
33
+ )
34
+ from pytorch3d.renderer.mesh import TexturesVertex
35
+ from pytorch3d.structures import Meshes
36
+ from lib.dataset.mesh_util import get_visibility, get_visibility_color
37
+
38
+ import lib.common.render_utils as util
39
+ import torch
40
+ import numpy as np
41
+ from PIL import Image
42
+ from tqdm import tqdm
43
+ import os
44
+ import cv2
45
+ import math
46
+ from termcolor import colored
47
+
48
+
49
+ def image2vid(images, vid_path):
50
+
51
+ w, h = images[0].size
52
+ videodims = (w, h)
53
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
54
+ video = cv2.VideoWriter(vid_path, fourcc, len(images) / 5.0, videodims)
55
+ for image in images:
56
+ video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
57
+ video.release()
58
+
59
+
60
+ def query_color(verts, faces, image, device, predicted_color):
61
+ """query colors from points and image
62
+
63
+ Args:
64
+ verts ([B, 3]): [query verts]
65
+ faces ([M, 3]): [query faces]
66
+ image ([B, 3, H, W]): [full image]
67
+
68
+ Returns:
69
+ [np.float]: [return colors]
70
+ """
71
+
72
+ verts = verts.float().to(device)
73
+ faces = faces.long().to(device)
74
+ predicted_color=predicted_color.to(device)
75
+ (xy, z) = verts.split([2, 1], dim=1)
76
+ visibility = get_visibility_color(xy, z, faces[:, [0, 2, 1]]).flatten()
77
+ uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
78
+ uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
79
+ colors = (torch.nn.functional.grid_sample(
80
+ image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) +
81
+ 1.0) * 0.5 * 255.0
82
+ colors[visibility == 0.0]=(predicted_color* 255.0)[visibility == 0.0]
83
+
84
+ return colors.detach().cpu()
85
+
86
+
87
+ class cleanShader(torch.nn.Module):
88
+
89
+ def __init__(self, device="cpu", cameras=None, blend_params=None):
90
+ super().__init__()
91
+ self.cameras = cameras
92
+ self.blend_params = blend_params if blend_params is not None else BlendParams(
93
+ )
94
+
95
+ def forward(self, fragments, meshes, **kwargs):
96
+ cameras = kwargs.get("cameras", self.cameras)
97
+ if cameras is None:
98
+ msg = "Cameras must be specified either at initialization \
99
+ or in the forward pass of TexturedSoftPhongShader"
100
+
101
+ raise ValueError(msg)
102
+
103
+ # get renderer output
104
+ blend_params = kwargs.get("blend_params", self.blend_params)
105
+ texels = meshes.sample_textures(fragments)
106
+ images = blending.softmax_rgb_blend(texels,
107
+ fragments,
108
+ blend_params,
109
+ znear=-256,
110
+ zfar=256)
111
+
112
+ return images
113
+
114
+
115
+ class Render:
116
+
117
+ def __init__(self, size=512, device=torch.device("cuda:0")):
118
+ self.device = device
119
+ self.size = size
120
+
121
+ # camera setting
122
+ self.dis = 100.0
123
+ self.scale = 100.0
124
+ self.mesh_y_center = 0.0
125
+
126
+ self.reload_cam()
127
+
128
+ self.type = "color"
129
+
130
+ self.mesh = None
131
+ self.deform_mesh = None
132
+ self.pcd = None
133
+ self.renderer = None
134
+ self.meshRas = None
135
+
136
+ self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
137
+
138
+ def reload_cam(self):
139
+
140
+ self.cam_pos = [
141
+ (0, self.mesh_y_center, self.dis),
142
+ (self.dis, self.mesh_y_center, 0),
143
+ (0, self.mesh_y_center, -self.dis),
144
+ (-self.dis, self.mesh_y_center, 0),
145
+ (0,self.mesh_y_center+self.dis,0),
146
+ (0,self.mesh_y_center-self.dis,0),
147
+ ]
148
+
149
+ def get_camera(self, cam_id):
150
+
151
+ if cam_id == 4:
152
+ R, T = look_at_view_transform(
153
+ eye=[self.cam_pos[cam_id]],
154
+ at=((0, self.mesh_y_center, 0), ),
155
+ up=((0, 0, 1), ),
156
+ )
157
+ elif cam_id == 5:
158
+ R, T = look_at_view_transform(
159
+ eye=[self.cam_pos[cam_id]],
160
+ at=((0, self.mesh_y_center, 0), ),
161
+ up=((0, 0, 1), ),
162
+ )
163
+
164
+ else:
165
+ R, T = look_at_view_transform(
166
+ eye=[self.cam_pos[cam_id]],
167
+ at=((0, self.mesh_y_center, 0), ),
168
+ up=((0, 1, 0), ),
169
+ )
170
+
171
+ camera = FoVOrthographicCameras(
172
+ device=self.device,
173
+ R=R,
174
+ T=T,
175
+ znear=100.0,
176
+ zfar=-100.0,
177
+ max_y=100.0,
178
+ min_y=-100.0,
179
+ max_x=100.0,
180
+ min_x=-100.0,
181
+ scale_xyz=(self.scale * np.ones(3), ),
182
+ )
183
+
184
+ return camera
185
+
186
+ def init_renderer(self, camera, type="clean_mesh", bg="gray"):
187
+
188
+ if "mesh" in type:
189
+
190
+ # rasterizer
191
+ self.raster_settings_mesh = RasterizationSettings(
192
+ image_size=self.size,
193
+ blur_radius=np.log(1.0 / 1e-4) * 1e-7,
194
+ faces_per_pixel=30,
195
+ )
196
+ self.meshRas = MeshRasterizer(
197
+ cameras=camera, raster_settings=self.raster_settings_mesh)
198
+
199
+ if bg == "black":
200
+ blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
201
+ elif bg == "white":
202
+ blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
203
+ elif bg == "gray":
204
+ blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
205
+
206
+ if type == "ori_mesh":
207
+
208
+ lights = PointLights(
209
+ device=self.device,
210
+ ambient_color=((0.8, 0.8, 0.8), ),
211
+ diffuse_color=((0.2, 0.2, 0.2), ),
212
+ specular_color=((0.0, 0.0, 0.0), ),
213
+ location=[[0.0, 200.0, 0.0]],
214
+ )
215
+
216
+ self.renderer = MeshRenderer(
217
+ rasterizer=self.meshRas,
218
+ shader=SoftPhongShader(
219
+ device=self.device,
220
+ cameras=camera,
221
+ lights=None,
222
+ blend_params=blendparam,
223
+ ),
224
+ )
225
+
226
+ if type == "silhouette":
227
+ self.raster_settings_silhouette = RasterizationSettings(
228
+ image_size=self.size,
229
+ blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
230
+ faces_per_pixel=50,
231
+ cull_backfaces=True,
232
+ )
233
+
234
+ self.silhouetteRas = MeshRasterizer(
235
+ cameras=camera,
236
+ raster_settings=self.raster_settings_silhouette)
237
+ self.renderer = MeshRenderer(rasterizer=self.silhouetteRas,
238
+ shader=SoftSilhouetteShader())
239
+
240
+ if type == "pointcloud":
241
+ self.raster_settings_pcd = PointsRasterizationSettings(
242
+ image_size=self.size, radius=0.006, points_per_pixel=10)
243
+
244
+ self.pcdRas = PointsRasterizer(
245
+ cameras=camera, raster_settings=self.raster_settings_pcd)
246
+ self.renderer = PointsRenderer(
247
+ rasterizer=self.pcdRas,
248
+ compositor=AlphaCompositor(background_color=(0, 0, 0)),
249
+ )
250
+
251
+ if type == "clean_mesh":
252
+
253
+ self.renderer = MeshRenderer(
254
+ rasterizer=self.meshRas,
255
+ shader=cleanShader(device=self.device,
256
+ cameras=camera,
257
+ blend_params=blendparam),
258
+ )
259
+
260
+ def VF2Mesh(self, verts, faces, vertex_texture = None):
261
+
262
+ if not torch.is_tensor(verts):
263
+ verts = torch.tensor(verts)
264
+ if not torch.is_tensor(faces):
265
+ faces = torch.tensor(faces)
266
+
267
+ if verts.ndimension() == 2:
268
+ verts = verts.unsqueeze(0).float()
269
+ if faces.ndimension() == 2:
270
+ faces = faces.unsqueeze(0).long()
271
+
272
+ verts = verts.to(self.device)
273
+ faces = faces.to(self.device)
274
+ if vertex_texture is not None:
275
+ vertex_texture = vertex_texture.to(self.device)
276
+
277
+ mesh = Meshes(verts, faces).to(self.device)
278
+
279
+ if vertex_texture is None:
280
+ mesh.textures = TexturesVertex(
281
+ verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5)#modify
282
+ else:
283
+ mesh.textures = TexturesVertex(
284
+ verts_features = vertex_texture.unsqueeze(0))#modify
285
+ return mesh
286
+
287
+ def load_meshes(self, verts, faces,offset=None, vertex_texture = None):
288
+ """load mesh into the pytorch3d renderer
289
+
290
+ Args:
291
+ verts ([N,3]): verts
292
+ faces ([N,3]): faces
293
+ offset ([N,3]): offset
294
+ """
295
+ if offset is not None:
296
+ verts = verts + offset
297
+
298
+ if isinstance(verts, list):
299
+ self.meshes = []
300
+ for V, F in zip(verts, faces):
301
+ if vertex_texture is None:
302
+ self.meshes.append(self.VF2Mesh(V, F))
303
+ else:
304
+ self.meshes.append(self.VF2Mesh(V, F, vertex_texture))
305
+ else:
306
+ if vertex_texture is None:
307
+ self.meshes = [self.VF2Mesh(verts, faces)]
308
+ else:
309
+ self.meshes = [self.VF2Mesh(verts, faces, vertex_texture)]
310
+
311
+ def get_depth_map(self, cam_ids=[0, 2]):
312
+
313
+ depth_maps = []
314
+ for cam_id in cam_ids:
315
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
316
+ fragments = self.meshRas(self.meshes[0])
317
+ depth_map = fragments.zbuf[..., 0].squeeze(0)
318
+ if cam_id == 2:
319
+ depth_map = torch.fliplr(depth_map)
320
+ depth_maps.append(depth_map)
321
+
322
+ return depth_maps
323
+
324
+ def get_rgb_image(self, cam_ids=[0, 2], bg='gray'):
325
+
326
+ images = []
327
+ for cam_id in range(len(self.cam_pos)):
328
+ if cam_id in cam_ids:
329
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", bg)
330
+ if len(cam_ids) == 4:
331
+ rendered_img = (self.renderer(
332
+ self.meshes[0])[0:1, :, :, :3].permute(0, 3, 1, 2) -
333
+ 0.5) * 2.0
334
+ else:
335
+ rendered_img = (self.renderer(
336
+ self.meshes[0])[0:1, :, :, :3].permute(0, 3, 1, 2) -
337
+ 0.5) * 2.0
338
+ if cam_id == 2 and len(cam_ids) == 2:
339
+ rendered_img = torch.flip(rendered_img, dims=[3])
340
+ images.append(rendered_img)
341
+
342
+ return images
343
+
344
+ def get_rendered_video(self, images, save_path):
345
+
346
+ self.cam_pos = []
347
+ for angle in range(360):
348
+ self.cam_pos.append((
349
+ 100.0 * math.cos(np.pi / 180 * angle),
350
+ self.mesh_y_center,
351
+ 100.0 * math.sin(np.pi / 180 * angle),
352
+ ))
353
+
354
+ old_shape = np.array(images[0].shape[:2])
355
+ new_shape = np.around(
356
+ (self.size / old_shape[0]) * old_shape).astype(int)
357
+
358
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
359
+ video = cv2.VideoWriter(save_path, fourcc, 10,
360
+ (self.size * len(self.meshes) +
361
+ new_shape[1] * len(images), self.size))
362
+
363
+ pbar = tqdm(range(len(self.cam_pos)))
364
+ pbar.set_description(
365
+ colored(f"exporting video {os.path.basename(save_path)}...",
366
+ "blue"))
367
+ for cam_id in pbar:
368
+ self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
369
+
370
+ img_lst = [
371
+ np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(
372
+ np.uint8)[:, :, [2, 1, 0]] for img in images
373
+ ]
374
+
375
+ for mesh in self.meshes:
376
+ rendered_img = ((self.renderer(mesh)[0, :, :, :3] *
377
+ 255.0).detach().cpu().numpy().astype(
378
+ np.uint8))
379
+
380
+ img_lst.append(rendered_img)
381
+ final_img = np.concatenate(img_lst, axis=1)
382
+ video.write(final_img)
383
+
384
+ video.release()
385
+ self.reload_cam()
386
+
387
+ def get_silhouette_image(self, cam_ids=[0, 2]):
388
+
389
+ images = []
390
+ for cam_id in range(len(self.cam_pos)):
391
+ if cam_id in cam_ids:
392
+ self.init_renderer(self.get_camera(cam_id), "silhouette")
393
+ rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
394
+ if cam_id == 2 and len(cam_ids) == 2:
395
+ rendered_img = torch.flip(rendered_img, dims=[2])
396
+ images.append(rendered_img)
397
+
398
+ return images
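A usage sketch for the Render class above. The FoV-orthographic cameras are built with scale_xyz = 100 and clipping planes at ±100, so mesh vertices are expected in a roughly [-1, 1] normalized space; the mesh path is a placeholder:

    import torch
    import trimesh
    from lib.common.render import Render

    mesh = trimesh.load("results/subject_recon.obj", process=False)
    render = Render(size=512, device=torch.device("cuda:0"))
    render.load_meshes(mesh.vertices, mesh.faces)

    rgb_views = render.get_rgb_image(cam_ids=[0, 2])               # list of [1, 3, 512, 512] tensors in [-1, 1]
    depth_front, depth_back = render.get_depth_map(cam_ids=[0, 2])
    silhouettes = render.get_silhouette_image(cam_ids=[0, 2])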