Huiwenshi committed · Commit 68cd723 (verified) · 1 Parent(s): 69d0567

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,42 @@
1
+ **/*~
2
+ **/*.bk
3
+ **/*.xx
4
+ **/*.so
5
+ **/*.ipynb
6
+ **/*.log
7
+ **/*.swp
8
+ **/*.zip
9
+ **/*.look
10
+ **/*.lock
11
+ **/*.think
12
+ **/dosth.sh
13
+ **/nohup.out
14
+ **/*polaris*
15
+ **/*egg*/
16
+ **/cl5/
17
+ **/tmp/
18
+ **/look/
19
+ **/temp/
20
+ **/build/
21
+ **/model/
22
+ **/log/
23
+ **/backup/
24
+ **/outputs/
25
+ **/work_dir/
26
+ **/work_dirs/
27
+ **/__pycache__/
28
+ **/.ipynb_checkpoints/
29
+ *.jpg
30
+ *.png
31
+ *.gif
32
+ ### PreCI ###
33
+ .codecc
34
+
35
+ app_hg.py
36
+ outputs
37
+ weights
38
+ .vscode/
39
+ baking
40
+ inference.py
41
+ third_party/weights
42
+ third_party/dust3r
README.md CHANGED
@@ -1,14 +1,5 @@
1
- ---
2
- title: Hunyuan3D-1.0
3
- emoji: 😻
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.5.0
8
- app_file: app_hg.py
9
- pinned: false
10
- short_description: Text-to-3D and Image-to-3D Generation
11
- ---
12
  <!-- ## **Hunyuan3D-1.0** -->
13
 
14
  <p align="center">
@@ -19,7 +10,7 @@ short_description: Text-to-3D and Image-to-3D Generation
19
 
20
  <div align="center">
21
  <a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
22
- <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent Hunyuan3D&color=blue&logo=github-pages"></a> &ensp;
23
  <a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a> &ensp;
24
  <a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a> &ensp;
25
  <a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a> &ensp;
@@ -101,6 +92,19 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
101
  # step 3. install other packages
102
  bash env_install.sh
103
  ```
104
  <details>
105
  <summary>💡Other tips for environment installation</summary>
106
 
@@ -204,6 +208,33 @@ bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/tes
204
  bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
205
  ```
206
 
207
  #### Using Gradio
208
 
209
  We have prepared two versions of multi-view generation, std and lite.
 
1
+ [English](README.md) | [简体中文](README_zh_cn.md)
2
+
3
  <!-- ## **Hunyuan3D-1.0** -->
4
 
5
  <p align="center">
 
10
 
11
  <div align="center">
12
  <a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
13
+ <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a> &ensp;
14
  <a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a> &ensp;
15
  <a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a> &ensp;
16
  <a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a> &ensp;
 
92
  # step 3. install other packages
93
  bash env_install.sh
94
  ```
95
+
96
+ Because of dust3r's license restrictions, we only provide instructions for installing it:
97
+
98
+ ```
99
+ cd third_party
100
+ git clone --recursive https://github.com/naver/dust3r.git
101
+
102
+ cd ../third_party/weights
103
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
104
+
105
+ ```
106
+
107
+
108
  <details>
109
  <summary>💡Other tips for environment installation</summary>
110
 
 
208
  bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
209
  ```
210
 
211
+ #### Baking related
212
+
213
+ ```bash
214
+ cd ./third_party
215
+ git clone --recursive https://github.com/naver/dust3r.git
216
+
217
+ mkdir -p weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
218
+ cd weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
219
+
220
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
221
+ cd ../../..
222
+ ```
223
+
224
+ Once the dust3r code and weights are downloaded, the following additional arguments are available:
225
+
226
+ | Argument | Default | Description |
227
+ |:------------------:|:---------:|:---------------------------------------------------:|
228
+ |`--do_bake` | False | Bake the multi-view images onto the mesh |
229
+ |`--bake_align_times` | 3 | Number of times the images are aligned with the mesh |
230
+
231
+
232
+ Note: to bake when running `main.py`, pass both `--do_bake` and `--do_texture_mapping`.
233
+
234
+ ```bash
235
+ python main.py ... --do_texture_mapping --do_bake (--do_render)
236
+ ```
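
Before passing `--do_bake`, it can help to confirm that the dust3r code and weights are where the commands above put them. The following is a minimal sketch, not part of the repository: the helper name `baking_assets_ready` is hypothetical, and the paths are simply the ones implied by the clone and `wget` commands in this section.

```python
from pathlib import Path

# Locations implied by the commands in this section (relative to the repo root).
DUST3R_DIR = Path("third_party/dust3r")
DUST3R_WEIGHTS = Path(
    "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt/"
    "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
)


def baking_assets_ready(repo_root="."):
    """Return True if both the dust3r checkout and its weight file exist."""
    root = Path(repo_root)
    return (root / DUST3R_DIR).is_dir() and (root / DUST3R_WEIGHTS).is_file()


if __name__ == "__main__":
    if baking_assets_ready():
        print("Baking assets found; --do_bake can be used.")
    else:
        print("Baking assets missing; run without --do_bake.")
```

If the check fails, rerun the download commands above or omit `--do_bake`.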
237
+
238
  #### Using Gradio
239
 
240
  We have prepared two versions of multi-view generation, std and lite.
README_zh_cn.md ADDED
@@ -0,0 +1,242 @@
1
+ [English](README.md) | [简体中文](README_zh_cn.md)
2
+
3
+ <!-- ## **Hunyuan3D-1.0** -->
4
+
5
+ <p align="center">
6
+ <img src="./assets/logo.png" height=200>
7
+ </p>
8
+
9
+ # Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation
10
+
11
+ <div align="center">
12
+ <a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a> &ensp;
13
+ <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a> &ensp;
14
+ <a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a> &ensp;
15
+ <a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a> &ensp;
16
+ <a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a> &ensp;
17
+ </div>
18
+
19
+
20
+ ## 🔥🔥🔥 News!!
21
+
22
+ * Nov 5, 2024: 💬 Image-to-3D generation is now supported. Try it with the [script](#using-gradio).
23
+ * Nov 5, 2024: 💬 Text-to-3D generation is now supported. Try it with the [script](#using-gradio).
24
+
25
+
26
+ ## 📑 Open-source Plan
27
+
28
+ - [x] Inference
29
+ - [x] Checkpoints
30
+ - [ ] Baking related
31
+ - [ ] Training
32
+ - [ ] ComfyUI
33
+ - [ ] Distillation Version
34
+ - [ ] TensorRT Version
35
+
36
+
37
+
38
+ ## **Abstract**
39
+ <p align="center">
40
+ <img src="./assets/teaser.png" height=450>
41
+ </p>
42
+
43
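+ To address the shortcomings of existing 3D generative models in generation speed and generalization, we open-source the Hunyuan3D-1.0 model, which helps 3D creators and artists automate the production of 3D assets. Our model uses a two-stage generation approach that produces a 3D asset in as little as 10 seconds while maintaining quality and controllability. In the first stage, we employ a multi-view diffusion model; the lite version generates multi-view images in about 4 seconds. These images capture rich texture and geometric priors of the 3D asset from different viewpoints, relaxing the task from single-view reconstruction to multi-view reconstruction. In the second stage, we introduce a feed-forward reconstruction model that takes the multi-view images from the first stage and reconstructs the 3D asset quickly and accurately in about 3 seconds. The reconstruction model learns to handle the noise and inconsistency introduced by multi-view diffusion and uses the information available in the condition image to recover the 3D structure efficiently. As a result, the model can generate a 3D asset from an arbitrary single-view input.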
44
+
45
+
46
+ ## 🎉 **Hunyuan3D-1.0 Architecture**
47
+
48
+ <p align="center">
49
+ <img src="./assets/overview_3.png" height=400>
50
+ </p>
51
+
52
+
53
+ ## 📈 Comparisons
54
+
55
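+ Compared with other open-source models, Hunyuan3D-1.0 received the highest user ratings on all 5 metrics. See the user study results below for details.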
56
+
57
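+ On an A100 GPU, the lite model takes only about 10 s to generate a 3D asset from a single image, while the standard model takes about 25 s. The scatter plot below shows that Tencent Hunyuan3D-1.0 strikes a reasonable balance between quality and speed.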
58
+
59
+ <p align="center">
60
+ <img src="./assets/radar.png" height=300>
61
+ <img src="./assets/runtime.png" height=300>
62
+ </p>
63
+
64
+ ## Get Started
65
+
66
+ #### Clone the repository
67
+
68
+ ```shell
69
+ git clone https://github.com/tencent/Hunyuan3D-1
70
+ cd Hunyuan3D-1
71
+ ```
72
+
73
+ #### Installation guide for Linux
74
+
75
+ The env_install.sh script shows how to set up the environment:
76
+
77
+ ```
78
+ # step 1. create the environment
79
+ conda create -n hunyuan3d-1 python=3.10  # 3.9, 3.11, or 3.12 also work
80
+ conda activate hunyuan3d-1
81
+
82
+ # step 2. install torch and related packages
83
+ which pip # check pip corresponds to python
84
+
85
+ # modify the cuda version according to your machine (recommended)
86
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
87
+
88
+ # step 3. install other packages
89
+ bash env_install.sh
90
+ ```
91
+
92
+ Because of dust3r's license restrictions, we only provide instructions for installing it:
93
+
94
+ ```
95
+ cd third_party
96
+ git clone --recursive https://github.com/naver/dust3r.git
97
+
98
+ cd ../third_party/weights
99
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
100
+
101
+ ```
102
+
103
+
104
+ <details>
105
+ <summary>💡Other tips for environment installation</summary>
106
+
107
+ Optionally, install xformers or flash_attn for acceleration:
108
+
109
+ ```
110
+ pip install xformers --index-url https://download.pytorch.org/whl/cu121
111
+ ```
112
+ ```
113
+ pip install flash_attn
114
+ ```
115
+
116
+ Most environment errors are caused by a mismatch between machine and packages. You can try manually specifying the version, as shown in the following successful cases:
117
+ ```
118
+ # python3.9
119
+ pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
120
+ ```
121
+
122
+ When installing pytorch3d, the gcc version should preferably be greater than 9, and the GPU driver should not be too old.
123
+
124
+ </details>
125
+
126
+ #### Download pretrained models
127
+
128
+ The models are available at [https://huggingface.co/tencent/Hunyuan3D-1](https://huggingface.co/tencent/Hunyuan3D-1):
129
+
130
+ + `Hunyuan3D-1/lite`, lite model for multi-view generation.
131
+ + `Hunyuan3D-1/std`, standard model for multi-view generation.
132
+ + `Hunyuan3D-1/svrm`, sparse-view reconstruction model.
133
+
134
+
135
+ To download the models from Hugging Face, first install huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
136
+
137
+ ```shell
138
+ python3 -m pip install "huggingface_hub[cli]"
139
+ ```
140
+
141
+ Then download the models with the following commands:
142
+
143
+ ```shell
144
+ mkdir weights
145
+ huggingface-cli download tencent/Hunyuan3D-1 --local-dir ./weights
146
+
147
+ mkdir weights/hunyuanDiT
148
+ huggingface-cli download Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled --local-dir ./weights/hunyuanDiT
149
+ ```
150
+
151
+ #### Inference
152
+ For text-to-3D generation, both Chinese and English prompts are supported. Run local inference with the following command:
153
+ ```bash
154
+ python3 main.py \
155
+ --text_prompt "a lovely rabbit" \
156
+ --save_folder ./outputs/test/ \
157
+ --max_faces_num 90000 \
158
+ --do_texture_mapping \
159
+ --do_render
160
+ ```
161
+
162
+ For image-to-3D generation, run local inference with the following command:
163
+ ```bash
164
+ python3 main.py \
165
+ --image_prompt "/path/to/your/image" \
166
+ --save_folder ./outputs/test/ \
167
+ --max_faces_num 90000 \
168
+ --do_texture_mapping \
169
+ --do_render
170
+ ```
171
+ More arguments are described below:
172
+
173
+ | Argument | Default | Description |
174
+ |:------------------:|:---------:|:---------------------------------------------------:|
175
+ |`--text_prompt` | None |The text prompt for 3D generation |
176
+ |`--image_prompt` | None |The image prompt for 3D generation |
177
+ |`--t2i_seed` | 0 |The random seed for generating images |
178
+ |`--t2i_steps` | 25 |The number of steps for sampling of text to image |
179
+ |`--gen_seed` | 0 |The random seed for 3D generation |
180
+ |`--gen_steps` | 50 |The number of steps for sampling of 3d generation |
181
+ |`--max_faces_num` | 90000 |The maximum number of faces in the 3D mesh |
182
+ |`--save_memory` | False |Move modules to the CPU automatically to save GPU memory |
183
+ |`--do_texture_mapping` | False |Change vertex shading to texture shading |
184
+ |`--do_render` | False |Render a GIF |
185
+
186
+
187
+ If GPU memory is limited, use the `--save_memory` option. The minimum GPU memory requirements are:
188
+ - Inference Std-pipeline requires 30GB VRAM (24G VRAM with --save_memory).
189
+ - Inference Lite-pipeline requires 22GB VRAM (18G VRAM with --save_memory).
190
+ - Note: --save_memory will increase inference time
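
The thresholds above can be turned into a small lookup that picks flags for a given amount of VRAM. This is only an illustrative sketch: the function name `choose_pipeline` and the preference order (standard quality first, then lite) are assumptions, and the flag names `--use_lite` and `--save_memory` are the ones shown with `app.py` elsewhere in this README.

```python
def choose_pipeline(vram_gb):
    """Return extra CLI flags for the most demanding pipeline that fits.

    Thresholds come from the list above; returns None when even the
    lite pipeline with --save_memory does not fit.
    """
    # (minimum VRAM in GB, flags), ordered from most to least demanding.
    options = [
        (30, []),                                # std
        (24, ["--save_memory"]),                 # std, slower
        (22, ["--use_lite"]),                    # lite
        (18, ["--use_lite", "--save_memory"]),   # lite, slower
    ]
    for min_vram, flags in options:
        if vram_gb >= min_vram:
            return flags
    return None


if __name__ == "__main__":
    for vram in (40, 24, 22, 18, 16):
        print(vram, choose_pipeline(vram))
```

When the function returns `None`, loading the modules separately may still be an option.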
191
+
192
+ ```bash
193
+ bash scripts/text_to_3d_std.sh
194
+ bash scripts/text_to_3d_lite.sh
195
+ bash scripts/image_to_3d_std.sh
196
+ bash scripts/image_to_3d_lite.sh
197
+ ```
198
+
199
+ If your GPU has 16 GB of memory, you can load the modules onto the GPU separately:
200
+ ```bash
201
+ bash scripts/text_to_3d_std_separately.sh 'a lovely rabbit' ./outputs/test # >= 16G
202
+ bash scripts/text_to_3d_lite_separately.sh 'a lovely rabbit' ./outputs/test # >= 14G
203
+ bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/test # >= 16G
204
+ bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
205
+ ```
206
+
207
+ #### Using Gradio
208
+
209
+ We provide both a standard and a lite version of the interface:
210
+
211
+ ```shell
212
+ # std
213
+ python3 app.py
214
+ python3 app.py --save_memory
215
+
216
+ # lite
217
+ python3 app.py --use_lite
218
+ python3 app.py --use_lite --save_memory
219
+ ```
220
+
221
+ The Gradio demo is then served at http://0.0.0.0:8080. Replace 0.0.0.0 with the IP address of the machine running the model.
222
+
223
+ ## Camera Parameters
224
+
225
+ The generated multi-view images use a fixed set of camera poses:
226
+
227
+ + Azimuth (relative to input view): `+0, +60, +120, +180, +240, +300`.
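
As a small worked example, the six relative offsets can be converted into absolute azimuths once the azimuth of the input view is known. The function name `absolute_azimuths` and the idea of supplying an input azimuth are illustrative assumptions; only the offsets come from the list above.

```python
# Azimuth offsets in degrees relative to the input view, as listed above.
RELATIVE_AZIMUTHS = [0, 60, 120, 180, 240, 300]


def absolute_azimuths(input_azimuth_deg=0.0):
    """Return the six view azimuths in degrees, wrapped into [0, 360)."""
    return [(input_azimuth_deg + offset) % 360 for offset in RELATIVE_AZIMUTHS]


if __name__ == "__main__":
    print(absolute_azimuths(30.0))
    # → [30.0, 90.0, 150.0, 210.0, 270.0, 330.0]
```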
228
+
229
+
230
+ ## Citation
231
+
232
+ If you find this repository helpful, please cite our work:
233
+ ```bibtex
234
+ @misc{yang2024tencent,
235
+ title={Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation},
236
+ author={Xianghui Yang and Huiwen Shi and Bowen Zhang and Fan Yang and Jiacheng Wang and Hongxu Zhao and Xinhai Liu and Xinzhou Wang and Qingxiang Lin and Jiaao Yu and Lifu Wang and Zhuo Chen and Sicong Liu and Yuhong Liu and Yong Yang and Di Wang and Jie Jiang and Chunchao Guo},
237
+ year={2024},
238
+ eprint={2411.02293},
239
+ archivePrefix={arXiv},
240
+ primaryClass={cs.CV}
241
+ }
242
+ ```
app.py CHANGED
@@ -32,9 +32,21 @@ import torch
32
  import numpy as np
33
  from PIL import Image
34
  from einops import rearrange
 
35
 
36
  from infer import seed_everything, save_gif
37
  from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
38
 
39
  warnings.simplefilter('ignore', category=UserWarning)
40
  warnings.simplefilter('ignore', category=FutureWarning)
@@ -58,33 +70,19 @@ CONST_MAX_QUEUE = 1
58
  CONST_SERVER = '0.0.0.0'
59
 
60
  CONST_HEADER = '''
61
- <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D
62
- Generationr</b></a></h2>
63
- Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/placeholder' target='_blank'>ArXiv</a>.
64
-
65
- ❗️❗️❗️**Important Notes:**
66
- - By default, our demo can export a .obj mesh with vertex colors or a .glb mesh.
67
- - If you select "texture mapping," it will export a .obj mesh with a texture map or a .glb mesh.
68
- - If you select "render GIF," it will export a GIF image rendering of the .glb file.
69
- - If the result is unsatisfactory, please try a different seed value (Default: 0).
70
  '''
71
 
72
- CONST_CITATION = r"""
73
- If HunYuan3D-1 is helpful, please help to ⭐ the <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/tencent/Hunyuan3D-1?style=social)](https://github.com/tencent/Hunyuan3D-1)
74
- ---
75
- 📝 **Citation**
76
- If you find our work useful for your research or applications, please cite using this bibtex:
77
- ```bibtex
78
- @misc{yang2024tencent,
79
- title={Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation},
80
- author={Xianghui Yang and Huiwen Shi and Bowen Zhang and Fan Yang and Jiacheng Wang and Hongxu Zhao and Xinhai Liu and Xinzhou Wang and Qingxiang Lin and Jiaao Yu and Lifu Wang and Zhuo Chen and Sicong Liu and Yuhong Liu and Yong Yang and Di Wang and Jie Jiang and Chunchao Guo},
81
- year={2024},
82
- eprint={2411.02293},
83
- archivePrefix={arXiv},
84
- primaryClass={cs.CV}
85
- }
86
- ```
87
- """
88
 
89
  ################################################################
90
  # prepare text examples and image examples
@@ -129,6 +127,13 @@ worker_v23 = Views2Mesh(
129
  )
130
  worker_gif = GifRenderer(args.device)
131
 
132
  def stage_0_t2i(text, image, seed, step):
133
  os.makedirs('./outputs/app_output', exist_ok=True)
134
  exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
@@ -153,11 +158,11 @@ def stage_0_t2i(text, image, seed, step):
153
  dst = worker_xbg(image, save_folder)
154
  return dst, save_folder
155
 
156
- def stage_1_xbg(image, save_folder):
157
  if isinstance(image, str):
158
  image = Image.open(image)
159
  dst = save_folder + '/img_nobg.png'
160
- rgba = worker_xbg(image)
161
  rgba.save(dst)
162
  return dst
163
 
@@ -181,12 +186,9 @@ def stage_3_v23(
181
  seed,
182
  save_folder,
183
  target_face_count = 30000,
184
- do_texture_mapping = True,
185
- do_render =True
186
  ):
187
- do_texture_mapping = do_texture_mapping or do_render
188
- obj_dst = save_folder + '/mesh_with_colors.obj'
189
- glb_dst = save_folder + '/mesh.glb'
190
  worker_v23(
191
  views_pil,
192
  cond_pil,
@@ -195,149 +197,268 @@ def stage_3_v23(
195
  target_face_count = target_face_count,
196
  do_texture_mapping = do_texture_mapping
197
  )
 
 
 
198
  return obj_dst, glb_dst
199
 
200
- def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
201
- if not do_render_gif: return None
202
- gif_dst = save_folder + '/output.gif'
203
- worker_gif(
204
- save_folder + '/mesh.obj',
205
- gif_dst_path = gif_dst
206
- )
207
  return gif_dst
208
  # ===============================================================
209
  # gradio display
210
  # ===============================================================
 
211
  with gr.Blocks() as demo:
212
  gr.Markdown(CONST_HEADER)
213
  with gr.Row(variant="panel"):
 
 
 
214
  with gr.Column(scale=2):
 
 
 
215
  with gr.Tab("Text to 3D"):
216
  with gr.Column():
217
- text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', lines=1, max_lines=10, label='Input text')
 
218
  with gr.Row():
219
- textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
220
- textgen_step = gr.Number(value=25, label="T2I step", precision=0)
221
- textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
222
- textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
223
- textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
224
-
225
  with gr.Row():
226
- textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
227
- textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
228
  textgen_submit = gr.Button("Generate", variant="primary")
229
 
230
  with gr.Row():
231
- gr.Examples(examples=example_ts, inputs=[text], label="Txt examples", examples_per_page=10)
232
 
 
 
233
  with gr.Tab("Image to 3D"):
234
- with gr.Column():
235
- input_image = gr.Image(label="Input image",
236
- width=256, height=256, type="pil",
237
- image_mode="RGBA", sources="upload",
238
- interactive=True)
239
- with gr.Row():
240
- imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
241
- imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
242
- imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
243
 
244
- with gr.Row():
245
- imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
246
- imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
247
- imggen_submit = gr.Button("Generate", variant="primary")
248
- with gr.Row():
249
- gr.Examples(
250
- examples=example_is,
251
- inputs=[input_image],
252
- label="Img examples",
253
- examples_per_page=10
254
- )
255
-
256
  with gr.Column(scale=3):
257
  with gr.Row():
258
  with gr.Column(scale=2):
259
- rem_bg_image = gr.Image(label="No backgraound image", type="pil",
260
- image_mode="RGBA", interactive=False)
 
 
 
 
261
  with gr.Column(scale=3):
262
- result_image = gr.Image(label="Multi views", type="pil", interactive=False)
263
-
264
- with gr.Row():
 
 
 
 
265
  result_3dobj = gr.Model3D(
266
  clear_color=[0.0, 0.0, 0.0, 0.0],
267
- label="Output Obj",
268
  show_label=True,
269
  visible=True,
270
  camera_position=[90, 90, None],
271
  interactive=False
272
  )
 
 
 
 
 
 
 
 
 
 
273
 
274
- result_3dglb = gr.Model3D(
275
  clear_color=[0.0, 0.0, 0.0, 0.0],
276
- label="Output Glb",
277
  show_label=True,
278
  visible=True,
279
  camera_position=[90, 90, None],
280
- interactive=False
281
- )
282
- result_gif = gr.Image(label="Rendered GIF", interactive=False)
283
 
284
- with gr.Row():
285
- gr.Markdown("""
286
- We recommend downloading and opening Glb with 3D software, such as Blender, MeshLab, etc.
287
-
288
- Limited by gradio, Obj file here only be shown as vertex shading, but Glb can be texture shading.
289
- """)
290
-
291
- #===============================================================
292
- # gradio running code
293
- #===============================================================
294
 
 
 
 
 
295
  none = gr.State(None)
296
  save_folder = gr.State()
297
  cond_image = gr.State()
298
  views_image = gr.State()
299
  text_image = gr.State()
300
 
 
301
  textgen_submit.click(
302
- fn=stage_0_t2i, inputs=[text, none, textgen_seed, textgen_step],
 
303
  outputs=[rem_bg_image, save_folder],
304
  ).success(
305
- fn=stage_2_i2v, inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
 
306
  outputs=[views_image, cond_image, result_image],
307
  ).success(
308
- fn=stage_3_v23, inputs=[views_image, cond_image, textgen_SEED, save_folder,
309
- textgen_max_faces, textgen_do_texture_mapping,
310
- textgen_do_render_gif],
311
- outputs=[result_3dobj, result_3dglb],
312
  ).success(
313
- fn=stage_4_gif, inputs=[result_3dglb, save_folder, textgen_do_render_gif],
 
 
 
 
 
314
  outputs=[result_gif],
315
  ).success(lambda: print('Text_to_3D Done ...'))
316
 
 
317
  imggen_submit.click(
318
- fn=stage_0_t2i, inputs=[none, input_image, textgen_seed, textgen_step],
 
319
  outputs=[text_image, save_folder],
320
  ).success(
321
- fn=stage_1_xbg, inputs=[text_image, save_folder],
 
322
  outputs=[rem_bg_image],
323
  ).success(
324
- fn=stage_2_i2v, inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
 
325
  outputs=[views_image, cond_image, result_image],
326
  ).success(
327
- fn=stage_3_v23, inputs=[views_image, cond_image, imggen_SEED, save_folder,
328
- imggen_max_faces, imggen_do_texture_mapping,
329
- imggen_do_render_gif],
330
- outputs=[result_3dobj, result_3dglb],
 
 
 
331
  ).success(
332
- fn=stage_4_gif, inputs=[result_3dglb, save_folder, imggen_do_render_gif],
 
333
  outputs=[result_gif],
334
  ).success(lambda: print('Image_to_3D Done ...'))
335
 
336
- #===============================================================
337
- # start gradio server
338
- #===============================================================
339
 
340
- gr.Markdown(CONST_CITATION)
341
  demo.queue(max_size=CONST_MAX_QUEUE)
342
  demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
343
 
 
32
  import numpy as np
33
  from PIL import Image
34
  from einops import rearrange
35
+ import pandas as pd
36
 
37
  from infer import seed_everything, save_gif
38
  from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
39
+ from third_party.check import check_bake_available
40
+
41
+ try:
42
+ from third_party.mesh_baker import MeshBaker
43
+ BAKE_AVAILEBLE = True
44
+ except Exception as err:
45
+ print(err)
46
+ print("Failed to import baking-related modules; running without baking.")
47
+ check_bake_available()
48
+ BAKE_AVAILEBLE = False
49
+
50
 
51
  warnings.simplefilter('ignore', category=UserWarning)
52
  warnings.simplefilter('ignore', category=FutureWarning)
 
70
  CONST_SERVER = '0.0.0.0'
71
 
72
  CONST_HEADER = '''
73
+ <h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
74
+ ⭐️Technical report: <a href='https://arxiv.org/pdf/2411.02293' target='_blank'>ArXiv</a>. ⭐️Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>.
75
  '''
76
 
77
+ CONST_NOTE = '''
78
+ ❗️❗️❗️Usage❗️❗️❗️<br>
79
+
80
+ Due to format limitations, *.obj meshes can only be exported with vertex colors. The "texture" mode only works with *.glb.<br>
81
+ Please click "Do Rendering" to export a GIF.<br>
82
+ You can click "Do Baking" to bake multi-view images onto the shape.<br>
83
+
84
+ If the results aren't satisfactory, please try a different random seed (default is 0).
85
+ '''
86
 
87
  ################################################################
88
  # prepare text examples and image examples
 
127
  )
128
  worker_gif = GifRenderer(args.device)
129
 
130
+
131
+ if BAKE_AVAILEBLE:
132
+ worker_baker = MeshBaker()
133
+
134
+
135
+ ### functional modules
136
+
137
  def stage_0_t2i(text, image, seed, step):
138
  os.makedirs('./outputs/app_output', exist_ok=True)
139
  exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
 
158
  dst = worker_xbg(image, save_folder)
159
  return dst, save_folder
160
 
161
+ def stage_1_xbg(image, save_folder, force_remove):
162
  if isinstance(image, str):
163
  image = Image.open(image)
164
  dst = save_folder + '/img_nobg.png'
165
+ rgba = worker_xbg(image, force=force_remove)
166
  rgba.save(dst)
167
  return dst
168
 
 
186
  seed,
187
  save_folder,
188
  target_face_count = 30000,
189
+ texture_color = 'texture'
 
190
  ):
191
+ do_texture_mapping = texture_color == 'texture'
 
 
192
  worker_v23(
193
  views_pil,
194
  cond_pil,
 
197
  target_face_count = target_face_count,
198
  do_texture_mapping = do_texture_mapping
199
  )
200
+ glb_dst = save_folder + '/mesh.glb' if do_texture_mapping else None
201
+ obj_dst = save_folder + '/mesh.obj'
202
+ obj_dst = save_folder + '/mesh_vertex_colors.obj' # Gradio can only display vertex shading
203
  return obj_dst, glb_dst
204
 
205
+ def stage_3p_baking(save_folder, color, bake):
206
+ if color == "texture" and bake:
207
+ obj_dst = worker_baker(save_folder)
208
+ glb_dst = obj_dst.replace(".obj", ".glb")
209
+ return glb_dst
210
+ else:
211
+ return None
212
+
213
+ def stage_4_gif(save_folder, color, bake, render):
214
+ if not render: return None
215
+ if os.path.exists(save_folder + '/view_1/bake/mesh.obj'):
216
+ obj_dst = save_folder + '/view_1/bake/mesh.obj'
217
+ elif os.path.exists(save_folder + '/view_0/bake/mesh.obj'):
218
+ obj_dst = save_folder + '/view_0/bake/mesh.obj'
219
+ elif os.path.exists(save_folder + '/mesh.obj'):
220
+ obj_dst = save_folder + '/mesh.obj'
221
+ else:
222
+ print(save_folder)
223
+ raise FileNotFoundError("mesh obj file not found")
224
+ gif_dst = obj_dst.replace(".obj", ".gif")
225
+ worker_gif(obj_dst, gif_dst_path=gif_dst)
226
  return gif_dst
227
+
228
+
229
+ def check_image_available(image):
230
+ if image.mode == "RGBA":
231
+ data = np.array(image)
232
+ alpha_channel = data[:, :, 3]
233
+ unique_alpha_values = np.unique(alpha_channel)
234
+ if len(unique_alpha_values) == 1:
235
+ msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
236
+ return msg, gr.update(value=True, interactive=False)
237
+ else:
238
+ msg = "The image has four channels, and you can choose to remove the background or not."
239
+ return msg, gr.update(value=False, interactive=True)
240
+ elif image.mode == "RGB":
241
+ msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
242
+ return msg, gr.update(value=True, interactive=False)
243
+ else:
244
+ raise ValueError(f"Unsupported image mode: {image.mode}")
245
+
246
+ def update_bake_render(color):
247
+ if color == "vertex":
248
+ return gr.update(value=False, interactive=False), gr.update(value=False, interactive=False)
249
+ else:
250
+ return gr.update(interactive=True), gr.update(interactive=True)
251
+
252
  # ===============================================================
253
  # gradio display
254
  # ===============================================================
255
+
256
  with gr.Blocks() as demo:
257
  gr.Markdown(CONST_HEADER)
258
  with gr.Row(variant="panel"):
259
+
260
+ ###### Input region
261
+
262
  with gr.Column(scale=2):
263
+
264
+ ### Text input region
265
+
266
  with gr.Tab("Text to 3D"):
267
  with gr.Column():
268
+ text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
269
+ lines=3, max_lines=20, label='Input text')
270
  with gr.Row():
271
+ textgen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
272
+ with gr.Row():
273
+ textgen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
274
+ if BAKE_AVAILEBLE:
275
+ textgen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
276
+ else:
277
+ textgen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
278
+
279
+ textgen_color.change(
280
+ fn=update_bake_render,
281
+ inputs=textgen_color,
282
+ outputs=[textgen_bake, textgen_render]
283
+ )
284
+
285
+ with gr.Row():
286
+ textgen_seed = gr.Number(value=0, label="T2I seed", precision=0, interactive=True)
287
+ textgen_step = gr.Number(value=25, label="T2I steps", precision=0,
288
+ minimum=10, maximum=50, interactive=True)
289
+ textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
290
+ textgen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
291
+ minimum=40, maximum=100, interactive=True)
292
+ textgen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
293
+ minimum=5000, maximum=1000000, interactive=True)
294
  with gr.Row():
 
 
295
  textgen_submit = gr.Button("Generate", variant="primary")
296
 
297
  with gr.Row():
298
+ gr.Examples(examples=example_ts, inputs=[text], label="Text examples", examples_per_page=10)
299
 
300
+ ### Image input region
301
+
302
  with gr.Tab("Image to 3D"):
303
+ with gr.Row():
304
+ input_image = gr.Image(label="Input image", width=256, height=256, type="pil",
305
+ image_mode="RGBA", sources="upload", interactive=True)
306
+ with gr.Row():
307
+ alert_message = gr.Markdown("") # for warning
308
+ with gr.Row():
309
+ imggen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
310
+ with gr.Row():
311
+ imggen_removebg = gr.Checkbox(label="Remove Background", value=True, interactive=True)
312
+ imggen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
313
+ if BAKE_AVAILEBLE:
314
+ imggen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
315
+ else:
316
+ imggen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
317
+
318
+ input_image.change(
319
+ fn=check_image_available,
320
+ inputs=input_image,
321
+ outputs=[alert_message, imggen_removebg]
322
+ )
323
+ imggen_color.change(
324
+ fn=update_bake_render,
325
+ inputs=imggen_color,
326
+ outputs=[imggen_bake, imggen_render]
327
+ )
328
+
329
+ with gr.Row():
330
+ imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
331
+ imggen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
332
+ minimum=40, maximum=100, interactive=True)
333
+ imggen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
334
+ minimum=5000, maximum=1000000, interactive=True)
335
+ with gr.Row():
336
+ imggen_submit = gr.Button("Generate", variant="primary")
337
+
338
+ with gr.Row():
339
+ gr.Examples(examples=example_is, inputs=[input_image],
340
+ label="Img examples", examples_per_page=10)
341
+
342
+ gr.Markdown(CONST_NOTE)
343
+
344
+ ###### Output region
345
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  with gr.Column(scale=3):
347
  with gr.Row():
348
  with gr.Column(scale=2):
349
+ rem_bg_image = gr.Image(
350
+ label="Image without background",
351
+ type="pil",
352
+ image_mode="RGBA",
353
+ interactive=False
354
+ )
355
  with gr.Column(scale=3):
356
+ result_image = gr.Image(
357
+ label="Multi-view images",
358
+ type="pil",
359
+ interactive=False
360
+ )
361
+
362
+ with gr.Row():
363
  result_3dobj = gr.Model3D(
364
  clear_color=[0.0, 0.0, 0.0, 0.0],
365
+ label="OBJ vertex color",
366
  show_label=True,
367
  visible=True,
368
  camera_position=[90, 90, None],
369
  interactive=False
370
  )
371
+ result_gif = gr.Image(label="GIF", interactive=False)
372
+
373
+ with gr.Row():
374
+ result_3dglb_texture = gr.Model3D(
375
+ clear_color=[0.0, 0.0, 0.0, 0.0],
376
+ label="GLB texture color",
377
+ show_label=True,
378
+ visible=True,
379
+ camera_position=[90, 90, None],
380
+ interactive=False)
381
 
382
+ result_3dglb_baked = gr.Model3D(
383
  clear_color=[0.0, 0.0, 0.0, 0.0],
384
+ label="GLB baked color",
385
  show_label=True,
386
  visible=True,
387
  camera_position=[90, 90, None],
388
+ interactive=False)
 
 
389
 
390
+ with gr.Row():
391
+ gr.Markdown(
392
+ "Due to Gradio limitations, OBJ files are displayed with vertex shading only, "
393
+ "while GLB files can be viewed with texture shading. <br>For the best experience, "
394
+ "we recommend downloading the GLB files and opening them with 3D software "
395
+ "like Blender or MeshLab."
396
+ )
 
 
 
397
 
398
+ #===============================================================
399
+ # gradio running code
400
+ #===============================================================
401
+
402
  none = gr.State(None)
403
  save_folder = gr.State()
404
  cond_image = gr.State()
405
  views_image = gr.State()
406
  text_image = gr.State()
407
 
408
+
409
  textgen_submit.click(
410
+ fn=stage_0_t2i,
411
+ inputs=[text, none, textgen_seed, textgen_step],
412
  outputs=[rem_bg_image, save_folder],
413
  ).success(
414
+ fn=stage_2_i2v,
415
+ inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
416
  outputs=[views_image, cond_image, result_image],
417
  ).success(
418
+ fn=stage_3_v23,
419
+ inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_color],
420
+ outputs=[result_3dobj, result_3dglb_texture],
 
421
  ).success(
422
+ fn=stage_3p_baking,
423
+ inputs=[save_folder, textgen_color, textgen_bake],
424
+ outputs=[result_3dglb_baked],
425
+ ).success(
426
+ fn=stage_4_gif,
427
+ inputs=[save_folder, textgen_color, textgen_bake, textgen_render],
428
  outputs=[result_gif],
429
  ).success(lambda: print('Text_to_3D Done ...'))
430
 
431
+
432
  imggen_submit.click(
433
+ fn=stage_0_t2i,
434
+ inputs=[none, input_image, textgen_seed, textgen_step],
435
  outputs=[text_image, save_folder],
436
  ).success(
437
+ fn=stage_1_xbg,
438
+ inputs=[text_image, save_folder, imggen_removebg],
439
  outputs=[rem_bg_image],
440
  ).success(
441
+ fn=stage_2_i2v,
442
+ inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
443
  outputs=[views_image, cond_image, result_image],
444
  ).success(
445
+ fn=stage_3_v23,
446
+ inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_color],
447
+ outputs=[result_3dobj, result_3dglb_texture],
448
+ ).success(
449
+ fn=stage_3p_baking,
450
+ inputs=[save_folder, imggen_color, imggen_bake],
451
+ outputs=[result_3dglb_baked],
452
  ).success(
453
+ fn=stage_4_gif,
454
+ inputs=[save_folder, imggen_color, imggen_bake, imggen_render],
455
  outputs=[result_gif],
456
  ).success(lambda: print('Image_to_3D Done ...'))
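Both event chains above link their stages with `.success(...)`, so each stage starts only after the previous one finishes without raising. As a rough, Gradio-free analogy of that control flow (the `run_chain` helper and the `fake_*` stages are hypothetical stand-ins, not part of the app), the behaviour can be sketched as:

```python
from typing import Callable, Dict, List, Tuple

Stage = Tuple[str, Callable[[Dict], None]]


def run_chain(stages: List[Stage], state: Dict) -> List[str]:
    """Run stages in order and stop at the first one that raises."""
    completed: List[str] = []
    for name, stage_fn in stages:
        try:
            stage_fn(state)
        except Exception as err:
            # Later stages are skipped, like listeners chained with .success().
            print(f"{name} failed: {err}")
            break
        completed.append(name)
    return completed


# Hypothetical stages that only mimic the shape of the real pipeline.
def fake_views(state: Dict) -> None:
    state["views"] = "views for " + state["image"]


def fake_mesh(state: Dict) -> None:
    raise RuntimeError("mesh export failed")


def fake_gif(state: Dict) -> None:
    state["gif"] = "output.gif"


demo_state = {"image": "cat.png"}
finished = run_chain(
    [("views", fake_views), ("mesh", fake_mesh), ("gif", fake_gif)],
    demo_state,
)
print(finished)
```

Because the mesh stage fails in this toy run, the GIF stage never executes. The same idea explains why a failure in an early stage of the app leaves the later output widgets untouched.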
457
 
458
+ #===============================================================
459
+ # start gradio server
460
+ #===============================================================
461
 
 
462
  demo.queue(max_size=CONST_MAX_QUEUE)
463
  demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
464
 
env_install.sh CHANGED
@@ -1,6 +1,6 @@
1
  pip3 install diffusers transformers
2
  pip3 install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops
3
- pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio
4
  pip3 install git+https://github.com/facebookresearch/pytorch3d@stable
5
  pip3 install git+https://github.com/NVlabs/nvdiffrast
6
  pip3 install open3d
 
1
  pip3 install diffusers transformers
2
  pip3 install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops
3
+ pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio roma
4
  pip3 install git+https://github.com/facebookresearch/pytorch3d@stable
5
  pip3 install git+https://github.com/NVlabs/nvdiffrast
6
  pip3 install open3d
infer/gif_render.py CHANGED
@@ -25,7 +25,7 @@
25
  import os, sys
26
  sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
27
 
28
- from svrm.ldm.vis_util import render
29
  from infer.utils import seed_everything, timing_decorator
30
 
31
  class GifRenderer():
@@ -40,14 +40,14 @@ class GifRenderer():
40
  self,
41
  obj_filename,
42
  elev=0,
43
- azim=0,
44
  resolution=512,
45
  gif_dst_path='',
46
  n_views=120,
47
  fps=30,
48
  rgb=True
49
  ):
50
- render(
51
  obj_filename,
52
  elev=elev,
53
  azim=azim,
 
25
  import os, sys
26
  sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
27
 
28
+ from svrm.ldm.vis_util import render_func
29
  from infer.utils import seed_everything, timing_decorator
30
 
31
  class GifRenderer():
 
40
  self,
41
  obj_filename,
42
  elev=0,
43
+ azim=None,
44
  resolution=512,
45
  gif_dst_path='',
46
  n_views=120,
47
  fps=30,
48
  rgb=True
49
  ):
50
+ render_func(
51
  obj_filename,
52
  elev=elev,
53
  azim=azim,
infer/image_to_views.py CHANGED
@@ -48,21 +48,26 @@ def save_gif(pils, save_path, df=False):
48
 
49
 
50
  class Image2Views():
51
- def __init__(self, device="cuda:0", use_lite=False, save_memory=False):
 
 
 
52
  self.device = device
53
  if use_lite:
 
54
  self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
55
- "./weights/mvd_lite",
56
  torch_dtype = torch.float16,
57
  use_safetensors = True,
58
  )
59
  else:
 
60
  self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
61
- "./weights/mvd_std",
62
  torch_dtype = torch.float16,
63
  use_safetensors = True,
64
  )
65
- self.pipe = self.pipe.to(device)
66
  self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
67
  self.save_memory = save_memory
68
  set_parameter_grad_false(self.pipe.unet)
 
48
 
49
 
50
  class Image2Views():
51
+ def __init__(self,
52
+ device="cuda:0", use_lite=False, save_memory=False,
53
+ std_pretrain='./weights/mvd_std', lite_pretrain='./weights/mvd_lite'
54
+ ):
55
  self.device = device
56
  if use_lite:
57
+ print("loading", lite_pretrain)
58
  self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
59
+ lite_pretrain,
60
  torch_dtype = torch.float16,
61
  use_safetensors = True,
62
  )
63
  else:
64
+ print("loading", std_pretrain)
65
  self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
66
+ std_pretrain,
67
  torch_dtype = torch.float16,
68
  use_safetensors = True,
69
  )
70
+ self.pipe = self.pipe if save_memory else self.pipe.to(device)
71
  self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
72
  self.save_memory = save_memory
73
  set_parameter_grad_false(self.pipe.unet)
infer/text_to_image.py CHANGED
@@ -46,8 +46,7 @@ class Text2Image():
46
  )
47
  set_parameter_grad_false(self.pipe.transformer)
48
  print('text2image transformer model', get_parameter_number(self.pipe.transformer))
49
- if not save_memory:
50
- self.pipe = self.pipe.to(device)
51
  self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
52
  "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
53
  "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
 
46
  )
47
  set_parameter_grad_false(self.pipe.transformer)
48
  print('text2image transformer model', get_parameter_number(self.pipe.transformer))
49
+ self.pipe = self.pipe if save_memory else self.pipe.to(device)
 
50
  self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
51
  "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
52
  "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
infer/utils.py CHANGED
@@ -21,7 +21,8 @@
21
  # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
  # fine-tuning enabling code and other elements of the foregoing made publicly available
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
 
25
  import os
26
  import time
27
  import random
@@ -30,6 +31,7 @@ import torch
30
  from torch.cuda.amp import autocast, GradScaler
31
  from functools import wraps
32
 
 
33
  def seed_everything(seed):
34
  '''
35
  seed everything
@@ -39,6 +41,7 @@ def seed_everything(seed):
39
  torch.manual_seed(seed)
40
  os.environ["PL_GLOBAL_SEED"] = str(seed)
41
 
 
42
  def timing_decorator(category: str):
43
  '''
44
  timing_decorator: record time
@@ -57,6 +60,7 @@ def timing_decorator(category: str):
57
  return wrapper
58
  return decorator
59
 
 
60
  def auto_amp_inference(func):
61
  '''
62
  with torch.cuda.amp.autocast()"
@@ -69,11 +73,13 @@ def auto_amp_inference(func):
69
  return output
70
  return wrapper
71
 
 
72
  def get_parameter_number(model):
73
  total_num = sum(p.numel() for p in model.parameters())
74
  trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
75
  return {'Total': total_num, 'Trainable': trainable_num}
76
 
 
77
  def set_parameter_grad_false(model):
78
  for p in model.parameters():
79
  p.requires_grad = False
 
21
  # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
  # fine-tuning enabling code and other elements of the foregoing made publicly available
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
+ import sys
25
+ import io
26
  import os
27
  import time
28
  import random
 
31
  from torch.cuda.amp import autocast, GradScaler
32
  from functools import wraps
33
 
34
+
35
  def seed_everything(seed):
36
  '''
37
  seed everything
 
41
  torch.manual_seed(seed)
42
  os.environ["PL_GLOBAL_SEED"] = str(seed)
43
 
44
+
45
  def timing_decorator(category: str):
46
  '''
47
  timing_decorator: record time
 
60
  return wrapper
61
  return decorator
62
 
63
+
64
  def auto_amp_inference(func):
65
  '''
66
  with torch.cuda.amp.autocast()"
 
73
  return output
74
  return wrapper
75
 
76
+
77
  def get_parameter_number(model):
78
  total_num = sum(p.numel() for p in model.parameters())
79
  trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
80
  return {'Total': total_num, 'Trainable': trainable_num}
81
 
82
+
83
  def set_parameter_grad_false(model):
84
  for p in model.parameters():
85
  p.requires_grad = False
infer/views_to_mesh.py CHANGED
@@ -47,11 +47,15 @@ class Views2Mesh():
47
  use_lite: lite version
48
  save_memory: cpu auto
49
  '''
50
- self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
51
- self.mv23d_predictor.model.eval()
52
- self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
53
  self.device = device
54
  self.save_memory = save_memory
 
 
 
 
 
 
 
55
  set_parameter_grad_false(self.mv23d_predictor.model)
56
  print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
57
 
@@ -109,7 +113,6 @@ class Views2Mesh():
109
  do_texture_mapping = do_texture_mapping
110
  )
111
  torch.cuda.empty_cache()
112
- return save_dir
113
 
114
 
115
  if __name__ == "__main__":
 
47
  use_lite: lite version
48
  save_memory: cpu auto
49
  '''
 
 
 
50
  self.device = device
51
  self.save_memory = save_memory
52
+ self.mv23d_predictor = MV23DPredictor(
53
+ mv23d_ckt_path,
54
+ mv23d_cfg_path,
55
+ device = "cpu" if save_memory else device
56
+ )
57
+ self.mv23d_predictor.model.eval()
58
+ self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
59
  set_parameter_grad_false(self.mv23d_predictor.model)
60
  print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
61
 
 
113
  do_texture_mapping = do_texture_mapping
114
  )
115
  torch.cuda.empty_cache()
 
116
 
117
 
118
  if __name__ == "__main__":
main.py CHANGED
@@ -24,16 +24,28 @@
24
 
25
  import os
26
  import warnings
27
- import torch
28
- from PIL import Image
29
  import argparse
30
-
31
- from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
 
32
 
33
  warnings.simplefilter('ignore', category=UserWarning)
34
  warnings.simplefilter('ignore', category=FutureWarning)
35
  warnings.simplefilter('ignore', category=DeprecationWarning)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def get_args():
38
  parser = argparse.ArgumentParser()
39
  parser.add_argument(
@@ -73,8 +85,8 @@ def get_args():
73
  "--gen_steps", default=50, type=int
74
  )
75
  parser.add_argument(
76
- "--max_faces_num", default=80000, type=int,
77
- help="max num of face, suggest 80000 for effect, 10000 for speed"
78
  )
79
  parser.add_argument(
80
  "--save_memory", default=False, action="store_true"
@@ -85,6 +97,13 @@ def get_args():
85
  parser.add_argument(
86
  "--do_render", default=False, action="store_true"
87
  )
 
 
 
 
 
 
 
88
  return parser.parse_args()
89
 
90
 
@@ -95,6 +114,7 @@ if __name__ == "__main__":
95
  assert args.text_prompt or args.image_prompt, "Either a text prompt or an image prompt must be given"
96
 
97
  # init model
 
98
  rembg_model = Removebg()
99
  image_to_views_model = Image2Views(
100
  device=args.device,
@@ -116,9 +136,18 @@ if __name__ == "__main__":
116
  device = args.device,
117
  save_memory = args.save_memory
118
  )
119
- if args.do_render:
 
 
 
 
 
 
 
120
  gif_renderer = GifRenderer(device=args.device)
121
-
 
 
122
  # ---- ----- ---- ---- ---- ----
123
 
124
  os.makedirs(args.save_folder, exist_ok=True)
@@ -136,7 +165,7 @@ if __name__ == "__main__":
136
 
137
  # stage 2, remove back ground
138
  res_rgba_pil = rembg_model(res_rgb_pil)
139
- res_rgb_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
140
 
141
  # stage 3, image to views
142
  (views_grid_pil, cond_img), view_pil_list = image_to_views_model(
@@ -155,10 +184,29 @@ if __name__ == "__main__":
155
  save_folder = args.save_folder,
156
  do_texture_mapping = args.do_texture_mapping
157
  )
158
-
159
- # stage 5, render gif
 
 
 
 
 
 
160
  if args.do_render:
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  gif_renderer(
162
- os.path.join(args.save_folder, 'mesh.obj'),
163
  gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
164
  )
 
24
 
25
  import os
26
  import warnings
 
 
27
  import argparse
28
+ import time
29
+ from PIL import Image
30
+ import torch
31
 
32
  warnings.simplefilter('ignore', category=UserWarning)
33
  warnings.simplefilter('ignore', category=FutureWarning)
34
  warnings.simplefilter('ignore', category=DeprecationWarning)
35
 
36
+ from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
37
+ # MeshBaker is imported inside the try block below, so a missing baking dependency cannot abort startup
38
+ from third_party.check import check_bake_available
39
+
40
+ try:
41
+ from third_party.mesh_baker import MeshBaker
42
+ assert check_bake_available()
43
+ BAKE_AVAILEBLE = True
44
+ except Exception as err:
45
+ print(err)
46
+ print("import baking related fail, run without baking")
47
+ BAKE_AVAILEBLE = False
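The `try`/`except` above treats baking as an optional feature: if the import or the file check fails, the script keeps running with `BAKE_AVAILEBLE = False`. A guard like this only helps when the optional import is attempted solely inside the `try`. A minimal standard-library sketch of the same pattern follows; the `optional_import` helper and the module names passed to it are illustrative assumptions, not code from this repository.

```python
import importlib
from typing import Any, Optional, Tuple


def optional_import(module_name: str) -> Tuple[Optional[Any], bool]:
    """Return (module, True) when the import works, else (None, False)."""
    try:
        return importlib.import_module(module_name), True
    except Exception as err:
        # Covers ImportError as well as errors raised while the module loads.
        print(f"optional module {module_name!r} unavailable: {err}")
        return None, False


# "json" stands in for a dependency that is present,
# the second name for one that is missing.
json_mod, has_json = optional_import("json")
missing_mod, has_missing = optional_import("module_that_does_not_exist_12345")
```

Callers can then branch on the boolean flag instead of wrapping every use of the optional module in its own `try` block.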
48
+
49
  def get_args():
50
  parser = argparse.ArgumentParser()
51
  parser.add_argument(
 
85
  "--gen_steps", default=50, type=int
86
  )
87
  parser.add_argument(
88
+ "--max_faces_num", default=90000, type=int,
89
+ help="maximum number of faces; 90000 is suggested for quality, 10000 for speed"
90
  )
91
  parser.add_argument(
92
  "--save_memory", default=False, action="store_true"
 
97
  parser.add_argument(
98
  "--do_render", default=False, action="store_true"
99
  )
100
+ parser.add_argument(
101
+ "--do_bake", default=False, action="store_true"
102
+ )
103
+ parser.add_argument(
104
+ "--bake_align_times", default=3, type=int,
105
+ help="number of alignment passes between the view images and the mesh; 1 to 6 is suggested"
106
+ )
107
  return parser.parse_args()
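The new `--do_bake` and `--bake_align_times` options combine with the availability flag: baking runs only when the user asks for it and the baking files are present. A small sketch of that interaction, assuming a trimmed-down parser (`build_parser` and the `bake_available` variable are hypothetical names used only for illustration):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Trimmed-down parser holding only the flags discussed here."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--do_render", default=False, action="store_true")
    parser.add_argument("--do_bake", default=False, action="store_true")
    parser.add_argument("--bake_align_times", default=3, type=int)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is repeatable.
args = build_parser().parse_args(["--do_bake", "--bake_align_times", "5"])

# Assumption for this sketch: pretend the baking files were not found.
bake_available = False
will_bake = args.do_bake and bake_available
```

With this combination, `--do_bake` is accepted but has no effect when the baking files are missing, so the only hint the user gets is the message printed at import time.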
108
 
109
 
 
114
  assert args.text_prompt or args.image_prompt, "Either a text prompt or an image prompt must be given"
115
 
116
  # init model
117
+ st = time.time()
118
  rembg_model = Removebg()
119
  image_to_views_model = Image2Views(
120
  device=args.device,
 
136
  device = args.device,
137
  save_memory = args.save_memory
138
  )
139
+
140
+ if args.do_bake and BAKE_AVAILEBLE:
141
+ mesh_baker = MeshBaker(
142
+ device = args.device,
143
+ align_times = args.bake_align_times
144
+ )
145
+
146
+ if args.do_render:
147
  gif_renderer = GifRenderer(device=args.device)
148
+
149
+ print(f"Init Models cost {time.time()-st}s")
150
+
151
  # ---- ----- ---- ---- ---- ----
152
 
153
  os.makedirs(args.save_folder, exist_ok=True)
 
165
 
166
  # stage 2, remove back ground
167
  res_rgba_pil = rembg_model(res_rgb_pil)
168
+ res_rgba_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
169
 
170
  # stage 3, image to views
171
  (views_grid_pil, cond_img), view_pil_list = image_to_views_model(
 
184
  save_folder = args.save_folder,
185
  do_texture_mapping = args.do_texture_mapping
186
  )
187
+
188
+ # stage 5, baking
189
+ mesh_file_for_render = None
190
+ if args.do_bake and BAKE_AVAILEBLE:
191
+ mesh_file_for_render = mesh_baker(args.save_folder)
192
+
193
+ # stage 6, render gif
194
+ # TODO: if the save folder is not empty at start, a stale mesh from an earlier run may be rendered by mistake
195
  if args.do_render:
196
+ if mesh_file_for_render and os.path.exists(mesh_file_for_render):
197
+ mesh_file_for_render = mesh_file_for_render
198
+ elif os.path.exists(os.path.join(args.save_folder, 'view_1/bake/mesh.obj')):
199
+ mesh_file_for_render = os.path.join(args.save_folder, 'view_1/bake/mesh.obj')
200
+ elif os.path.exists(os.path.join(args.save_folder, 'view_0/bake/mesh.obj')):
201
+ mesh_file_for_render = os.path.join(args.save_folder, 'view_0/bake/mesh.obj')
202
+ elif os.path.exists(os.path.join(args.save_folder, 'mesh.obj')):
203
+ mesh_file_for_render = os.path.join(args.save_folder, 'mesh.obj')
204
+ else:
205
+ raise FileNotFoundError("mesh_file_for_render not found")
206
+
207
+ print("Rendering 3d file:", mesh_file_for_render)
208
+
209
  gif_renderer(
210
+ mesh_file_for_render,
211
  gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
212
  )
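The render stage above picks the mesh to render by priority: the path returned by the baker, then `view_1/bake/mesh.obj`, then `view_0/bake/mesh.obj`, then the plain `mesh.obj`. The same lookup can be written as a small helper. This is a sketch under the assumption that only file existence matters; `pick_mesh_for_render` and `CANDIDATES` are hypothetical names that do not appear in the repository.

```python
import os
import tempfile
from typing import Optional

# Relative paths in the same priority order as the if/elif chain.
CANDIDATES = ("view_1/bake/mesh.obj", "view_0/bake/mesh.obj", "mesh.obj")


def pick_mesh_for_render(save_folder: str, baked_path: Optional[str] = None) -> str:
    """Return the first existing mesh file, preferring the baker's own output."""
    if baked_path and os.path.exists(baked_path):
        return baked_path
    for rel in CANDIDATES:
        path = os.path.join(save_folder, rel)
        if os.path.exists(path):
            return path
    raise FileNotFoundError("mesh_file_for_render not found")


with tempfile.TemporaryDirectory() as folder:
    # Only the unbaked mesh exists at first.
    open(os.path.join(folder, "mesh.obj"), "w").close()
    first = pick_mesh_for_render(folder)

    # A baked mesh appears and now takes priority.
    os.makedirs(os.path.join(folder, "view_0", "bake"))
    open(os.path.join(folder, "view_0", "bake", "mesh.obj"), "w").close()
    second = pick_mesh_for_render(folder)

    first_rel = os.path.relpath(first, folder)
    second_rel = os.path.relpath(second, folder)
```

The second lookup also shows the risk of reusing a save folder: a baked mesh left over from an earlier run wins over a fresh `mesh.obj`.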
requirements.txt CHANGED
@@ -22,3 +22,4 @@ git+https://github.com/facebookresearch/pytorch3d@stable
22
  git+https://github.com/NVlabs/nvdiffrast
23
  open3d
24
  ninja
 
 
22
  git+https://github.com/NVlabs/nvdiffrast
23
  open3d
24
  ninja
25
+ roma
svrm/ldm/models/svrm.py CHANGED
@@ -46,7 +46,7 @@ from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext
46
 
47
  from ..utils.ops import scale_tensor
48
  from ..util import count_params, instantiate_from_config
49
- from ..vis_util import render
50
 
51
 
52
  def unwrap_uv(v_pos, t_pos_idx):
@@ -58,7 +58,6 @@ def unwrap_uv(v_pos, t_pos_idx):
58
  indices = indices.astype(np.int64, casting="same_kind")
59
  return uvs, indices
60
 
61
-
62
  def uv_padding(image, hole_mask, uv_padding_size = 2):
63
  return cv2.inpaint(
64
  (image.detach().cpu().numpy() * 255).astype(np.uint8),
@@ -120,14 +119,16 @@ class SVRMModel(torch.nn.Module):
120
  out_dir = 'outputs/test'
121
  ):
122
  """
123
- color_type: 0 for ray texture, 1 for vertices texture
124
  """
125
 
126
- obj_vertext_path = os.path.join(out_dir, 'mesh_with_colors.obj')
127
- obj_path = os.path.join(out_dir, 'mesh.obj')
128
- obj_texture_path = os.path.join(out_dir, 'texture.png')
129
- obj_mtl_path = os.path.join(out_dir, 'texture.mtl')
130
- glb_path = os.path.join(out_dir, 'mesh.glb')
 
 
131
 
132
  st = time.time()
133
 
@@ -204,15 +205,13 @@ class SVRMModel(torch.nn.Module):
204
  mesh = trimesh.load_mesh(obj_vertext_path)
205
  print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
206
  st = time.time()
207
-
208
  if not do_texture_mapping:
209
- shutil.copy(obj_vertext_path, obj_path)
210
- mesh.export(glb_path, file_type='glb')
211
- return None
212
 
213
-
214
- ########## export texture ########
215
-
216
 
217
  st = time.time()
218
 
@@ -238,12 +237,9 @@ class SVRMModel(torch.nn.Module):
238
 
239
  # Interpolate world space position
240
  gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
241
-
242
  with torch.no_grad():
243
  gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
244
-
245
  tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
246
-
247
  tex_map = tex_map.float().squeeze(0) # (0, 1)
248
  tex_map = tex_map.view((texture_res, texture_res, 3))
249
  img = uv_padding(tex_map, hole_mask)
@@ -257,7 +253,7 @@ class SVRMModel(torch.nn.Module):
257
  fid.write('newmtl material_0\n')
258
  fid.write("Ka 1.000 1.000 1.000\n")
259
  fid.write("Kd 1.000 1.000 1.000\n")
260
- fid.write("Ks 0.000 0.000 0.000\n")
261
  fid.write("d 1.0\n")
262
  fid.write("illum 2\n")
263
  fid.write(f'map_Kd texture.png\n')
@@ -278,4 +274,5 @@ class SVRMModel(torch.nn.Module):
278
  mesh = trimesh.load_mesh(obj_path)
279
  mesh.export(glb_path, file_type='glb')
280
  print(f"=====> generate mesh with texture shading time: {time.time() - st}")
 
281
 
 
46
 
47
  from ..utils.ops import scale_tensor
48
  from ..util import count_params, instantiate_from_config
49
+ from ..vis_util import render_func
50
 
51
 
52
  def unwrap_uv(v_pos, t_pos_idx):
 
58
  indices = indices.astype(np.int64, casting="same_kind")
59
  return uvs, indices
60
 
 
61
  def uv_padding(image, hole_mask, uv_padding_size = 2):
62
  return cv2.inpaint(
63
  (image.detach().cpu().numpy() * 255).astype(np.uint8),
 
119
  out_dir = 'outputs/test'
120
  ):
121
  """
122
+ do_texture_mapping: True for ray texture, False for vertices texture
123
  """
124
 
125
+ obj_vertext_path = os.path.join(out_dir, 'mesh_vertex_colors.obj')
126
+
127
+ if do_texture_mapping:
128
+ obj_path = os.path.join(out_dir, 'mesh.obj')
129
+ obj_texture_path = os.path.join(out_dir, 'texture.png')
130
+ obj_mtl_path = os.path.join(out_dir, 'texture.mtl')
131
+ glb_path = os.path.join(out_dir, 'mesh.glb')
132
 
133
  st = time.time()
134
 
 
205
  mesh = trimesh.load_mesh(obj_vertext_path)
206
  print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
207
  st = time.time()
208
+
209
  if not do_texture_mapping:
210
+ return obj_vertext_path, None
 
 
211
 
212
+ ###########################################################
213
+ #------------- export texture -----------------------
214
+ ###########################################################
215
 
216
  st = time.time()
217
 
 
237
 
238
  # Interpolate world space position
239
  gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
 
240
  with torch.no_grad():
241
  gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
 
242
  tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
 
243
  tex_map = tex_map.float().squeeze(0) # (0, 1)
244
  tex_map = tex_map.view((texture_res, texture_res, 3))
245
  img = uv_padding(tex_map, hole_mask)
 
253
  fid.write('newmtl material_0\n')
254
  fid.write("Ka 1.000 1.000 1.000\n")
255
  fid.write("Kd 1.000 1.000 1.000\n")
256
+ fid.write("Ks 0.500 0.500 0.500\n")
257
  fid.write("d 1.0\n")
258
  fid.write("illum 2\n")
259
  fid.write(f'map_Kd texture.png\n')
 
274
  mesh = trimesh.load_mesh(obj_path)
275
  mesh.export(glb_path, file_type='glb')
276
  print(f"=====> generate mesh with texture shading time: {time.time() - st}")
277
+ return obj_path, glb_path
278
 
svrm/ldm/modules/attention.py CHANGED
@@ -246,8 +246,11 @@ class CrossAttention(nn.Module):
246
  class FlashAttention(nn.Module):
247
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
248
  super().__init__()
249
- print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
250
- f"{heads} heads.")
 
 
 
251
  inner_dim = dim_head * heads
252
  context_dim = default(context_dim, query_dim)
253
  self.scale = dim_head ** -0.5
@@ -269,7 +272,12 @@ class FlashAttention(nn.Module):
269
  k = self.to_k(context).to(dtype)
270
  v = self.to_v(context).to(dtype)
271
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
272
- out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q
 
 
 
 
 
273
  out = rearrange(out, 'b n h d -> b n (h d)', h=h)
274
  return self.to_out(out.float())
275
 
@@ -277,8 +285,11 @@ class MemoryEfficientCrossAttention(nn.Module):
277
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
278
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
279
  super().__init__()
280
- print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
281
- f"{heads} heads.")
 
 
 
282
  inner_dim = dim_head * heads
283
  context_dim = default(context_dim, query_dim)
284
 
@@ -327,10 +338,12 @@ class BasicTransformerBlock(nn.Module):
327
  super().__init__()
328
  self.disable_self_attn = disable_self_attn
329
  self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
330
- context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
 
331
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
332
  self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
333
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
 
334
  self.norm1 = Fp32LayerNorm(dim)
335
  self.norm2 = Fp32LayerNorm(dim)
336
  self.norm3 = Fp32LayerNorm(dim)
@@ -451,7 +464,3 @@ class ImgToTriplaneTransformer(nn.Module):
451
  x = self.norm(x)
452
  return x
453
 
454
-
455
-
456
-
457
-
 
246
  class FlashAttention(nn.Module):
247
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
248
  super().__init__()
249
+ # print(
250
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
251
+ # "context_dim is {context_dim} and using "
252
+ # f"{heads} heads."
253
+ # )
254
  inner_dim = dim_head * heads
255
  context_dim = default(context_dim, query_dim)
256
  self.scale = dim_head ** -0.5
 
272
  k = self.to_k(context).to(dtype)
273
  v = self.to_v(context).to(dtype)
274
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
275
+ out = flash_attn_func(q, k, v,
276
+ dropout_p=self.dropout,
277
+ softmax_scale=None,
278
+ causal=False,
279
+ window_size=(-1, -1)
280
+ ) # out is same shape to q
281
  out = rearrange(out, 'b n h d -> b n (h d)', h=h)
282
  return self.to_out(out.float())
283
 
 
285
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
286
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
287
  super().__init__()
288
+ # print(
289
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
290
+ # "context_dim is {context_dim} and using "
291
+ # f"{heads} heads."
292
+ # )
293
  inner_dim = dim_head * heads
294
  context_dim = default(context_dim, query_dim)
295
 
 
338
  super().__init__()
339
  self.disable_self_attn = disable_self_attn
340
  self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
341
+ context_dim=context_dim if self.disable_self_attn else None)
342
+ # is a self-attention if not self.disable_self_attn
343
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
344
  self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
345
+ heads=n_heads, dim_head=d_head, dropout=dropout)
346
+ # is self-attn if context is none
347
  self.norm1 = Fp32LayerNorm(dim)
348
  self.norm2 = Fp32LayerNorm(dim)
349
  self.norm3 = Fp32LayerNorm(dim)
 
464
  x = self.norm(x)
465
  return x
466
 
 
 
 
 
svrm/ldm/vis_util.py CHANGED
@@ -27,10 +27,10 @@ from pytorch3d.renderer import (
27
  )
28
 
29
 
30
- def render(
31
  obj_filename,
32
  elev=0,
33
- azim=0,
34
  resolution=512,
35
  gif_dst_path='',
36
  n_views=120,
@@ -49,7 +49,7 @@ def render(
49
  mesh = load_objs_as_meshes([obj_filename], device=device)
50
  meshes = mesh.extend(n_views)
51
 
52
- if gif_dst_path != '':
53
  elev = torch.linspace(elev, elev, n_views+1)[:-1]
54
  azim = torch.linspace(0, 360, n_views+1)[:-1]
55
 
@@ -76,16 +76,15 @@ def render(
76
  )
77
  images = renderer(meshes)
78
 
79
- # single frame rendering
80
- if gif_dst_path == '':
81
- frame = images[0, ..., :3] if rgb else images[0, ...]
82
- frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
83
- return frame
 
 
 
 
 
 
84
 
85
- # orbit frames rendering
86
- with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
87
- for i in range(n_views):
88
- frame = images[i, ..., :3] if rgb else images[i, ...]
89
- frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
90
- writer.append_data(frame)
91
- return gif_dst_path
 
27
  )
28
 
29
 
30
+ def render_func(
31
  obj_filename,
32
  elev=0,
33
+ azim=None,
34
  resolution=512,
35
  gif_dst_path='',
36
  n_views=120,
 
49
  mesh = load_objs_as_meshes([obj_filename], device=device)
50
  meshes = mesh.extend(n_views)
51
 
52
+ if azim is None:
53
  elev = torch.linspace(elev, elev, n_views+1)[:-1]
54
  azim = torch.linspace(0, 360, n_views+1)[:-1]
55
 
 
76
  )
77
  images = renderer(meshes)
78
 
79
+ if gif_dst_path != '':
80
+ with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
81
+ for i in range(n_views):
82
+ frame = images[i, ..., :3] if rgb else images[i, ...]
83
+ frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
84
+ writer.append_data(frame)
85
+
86
+ frame = images[..., :3] if rgb else images
87
+ frames = [Image.fromarray((fra.cpu().squeeze(0) * 255).numpy().astype("uint8")) for fra in frame]
88
+ return frames
89
+
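When `azim` is `None`, `render_func` builds a full orbit with `torch.linspace(0, 360, n_views+1)[:-1]`. Dropping the last sample keeps 0 and 360 degrees from both being rendered, which would presumably repeat a frame in a looping GIF. The same angles can be computed without torch; `orbit_azimuths` is a hypothetical helper written only to illustrate the spacing.

```python
from typing import List


def orbit_azimuths(n_views: int) -> List[float]:
    """Evenly spaced azimuths in [0, 360), one per rendered view."""
    # Same values as linspace(0, 360, n_views + 1) with the endpoint removed.
    return [360.0 * i / n_views for i in range(n_views)]


angles = orbit_azimuths(4)
default_orbit = orbit_azimuths(120)
```

With the default `n_views=120`, consecutive frames are 3 degrees apart.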
90
 
 
 
 
 
 
 
 
svrm/predictor.py CHANGED
@@ -33,7 +33,7 @@ from omegaconf import OmegaConf
33
  from torchvision import transforms
34
  from safetensors.torch import save_file, load_file
35
  from .ldm.util import instantiate_from_config
36
- from .ldm.vis_util import render
37
 
38
  class MV23DPredictor(object):
39
  def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
@@ -46,9 +46,7 @@ class MV23DPredictor(object):
46
  self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
47
  self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
48
 
49
- st = time.time()
50
  self.model = self.init_model(ckpt_path, cfg_path)
51
- print(f"=====> mv23d model init time: {time.time() - st}")
52
 
53
  self.input_view_transform = transforms.Compose([
54
  transforms.Resize(504, interpolation=Image.BICUBIC),
 
33
  from torchvision import transforms
34
  from safetensors.torch import save_file, load_file
35
  from .ldm.util import instantiate_from_config
36
+ from .ldm.vis_util import render_func
37
 
38
  class MV23DPredictor(object):
39
  def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
 
46
  self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
47
  self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
48
 
 
49
  self.model = self.init_model(ckpt_path, cfg_path)
 
50
 
51
  self.input_view_transform = transforms.Compose([
52
  transforms.Resize(504, interpolation=Image.BICUBIC),
third_party/check.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ import sys
3
+ import io
4
+
5
+ def check_bake_available():
6
+ is_ok = os.path.exists("./third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt/model.safetensors")
7
+ is_ok = is_ok and os.path.exists("./third_party/dust3r")
8
+ is_ok = is_ok and os.path.exists("./third_party/dust3r/dust3r")
9
+ is_ok = is_ok and os.path.exists("./third_party/dust3r/croco/models")
10
+ if is_ok:
11
+ print("Baking is available")
12
+ print("Baking is available")
13
+ print("Baking is available")
14
+ else:
15
+ print("Baking is unavailable, please download related files in README")
16
+ print("Baking is unavailable, please download related files in README")
17
+ print("Baking is unavailable, please download related files in README")
18
+ return is_ok
19
+
20
+
21
+
22
+ if __name__ == "__main__":
23
+
24
+ check_bake_available()
25
+
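The availability check above only reports a yes/no answer. A variant that lists which paths are missing may be easier to act on; the following is a minimal sketch under that assumption. The function name `missing_bake_paths` and the `root` parameter are hypothetical, while the four paths are the ones tested in `check_bake_available()`.

```python
from pathlib import Path

# Paths tested by check_bake_available() in the diff, relative to the repo root.
REQUIRED_BAKE_PATHS = (
    "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt/model.safetensors",
    "third_party/dust3r",
    "third_party/dust3r/dust3r",
    "third_party/dust3r/croco/models",
)


def missing_bake_paths(root="."):
    """Return the required paths that do not exist under `root`."""
    base = Path(root)
    return [p for p in REQUIRED_BAKE_PATHS if not (base / p).exists()]


if __name__ == "__main__":
    missing = missing_bake_paths()
    if missing:
        print("Baking is unavailable, missing:")
        for path in missing:
            print("  " + path)
    else:
        print("Baking is available")
```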
third_party/dust3r_utils.py ADDED
@@ -0,0 +1,366 @@
1
+ import sys
2
+ import io
3
+ import os
4
+ import cv2
5
+ import math
6
+ import numpy as np
7
+ from scipy.signal import medfilt
8
+ from scipy.spatial import KDTree
9
+ from matplotlib import pyplot as plt
10
+ from PIL import Image
11
+
12
+ from dust3r.inference import inference
13
+
14
+ from dust3r.utils.image import load_images# , resize_images
15
+ from dust3r.image_pairs import make_pairs
16
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
17
+ from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
18
+
19
+ from third_party.utils.camera_utils import remap_points
20
+ from third_party.utils.img_utils import rgba_to_rgb, resize_with_aspect_ratio
21
+ from third_party.utils.img_utils import compute_img_diff
22
+
23
+ from PIL.ImageOps import exif_transpose
24
+ import torchvision.transforms as tvf
25
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
26
+
27
+
28
+ def suppress_output(func):
29
+ def wrapper(*args, **kwargs):
30
+ original_stdout = sys.stdout
31
+ original_stderr = sys.stderr
32
+ sys.stdout = io.StringIO()
33
+ sys.stderr = io.StringIO()
34
+ try:
35
+ return func(*args, **kwargs)
36
+ finally:
37
+ sys.stdout = original_stdout
38
+ sys.stderr = original_stderr
39
+ return wrapper
40
+
41
+ def _resize_pil_image(img, long_edge_size):
42
+ S = max(img.size)
43
+ if S > long_edge_size:
44
+ interp = Image.LANCZOS
45
+ elif S <= long_edge_size:
46
+ interp = Image.BICUBIC
47
+ new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
48
+ return img.resize(new_size, interp)
49
+
50
+ def resize_images(imgs_list, size, square_ok=False):
51
+ """ open and convert all images in a list or folder to proper input format for DUSt3R
52
+ """
53
+ imgs = []
54
+ for img in imgs_list:
55
+ img = exif_transpose(Image.fromarray(img)).convert('RGB')
56
+ W1, H1 = img.size
57
+ if size == 224:
58
+ # resize short side to 224 (then crop)
59
+ img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
60
+ else:
61
+ # resize long side to 512
62
+ img = _resize_pil_image(img, size)
63
+ W, H = img.size
64
+ cx, cy = W//2, H//2
65
+ if size == 224:
66
+ half = min(cx, cy)
67
+ img = img.crop((cx-half, cy-half, cx+half, cy+half))
68
+ else:
69
+ halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
70
+ if not (square_ok) and W == H:
71
+ halfh = 3*halfw/4
72
+ img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
73
+
74
+ W2, H2 = img.size
75
+ imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
76
+ [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
77
+
78
+ return imgs
79
+
80
+ @suppress_output
81
+ def infer_match(images, model, vis=False, niter=300, lr=0.01, schedule='cosine', device="cuda:0"):
82
+ batch_size = 1
83
+ schedule = 'cosine'
84
+ lr = 0.01
85
+ niter = 300
86
+
87
+ images_packed = resize_images(images, size=512, square_ok=True)
88
+ # images_packed = images
89
+
90
+ pairs = make_pairs(images_packed, scene_graph='complete', prefilter=None, symmetrize=True)
91
+ output = inference(pairs, model, device, batch_size=batch_size, verbose=False)
92
+
93
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
94
+ loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
95
+
96
+ # retrieve useful values from scene:
97
+ imgs = scene.imgs
98
+ # focals = scene.get_focals()
99
+ # poses = scene.get_im_poses()
100
+ pts3d = scene.get_pts3d()
101
+ confidence_masks = scene.get_masks()
102
+
103
+ # visualize reconstruction
104
+ # scene.show()
105
+
106
+ # find 2D-2D matches between the two images
107
+ pts2d_list, pts3d_list = [], []
108
+ for i in range(2):
109
+ conf_i = confidence_masks[i].cpu().numpy()
110
+ pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
111
+ pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
112
+ if pts3d_list[-1].shape[0] == 0:
113
+ return np.zeros((0, 2)), np.zeros((0, 2))
114
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
115
+
116
+ matches_im1 = pts2d_list[1][reciprocal_in_P2]
117
+ matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
118
+
119
+ # visualize a few matches
120
+ if vis == True:
121
+ print(f'found {num_matches} matches')
122
+ n_viz = 20
123
+ match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
124
+ viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
125
+
126
+ H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
127
+ img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
128
+ img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
129
+ img = np.concatenate((img0, img1), axis=1)
130
+ plt.figure()
131
+ plt.imshow(img)
132
+ cmap = plt.get_cmap('jet')
133
+ for i in range(n_viz):
134
+ (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
135
+ plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
136
+ plt.show(block=True)
137
+
138
+ matches_im0 = remap_points(images[0].shape, matches_im0)
139
+ matches_im1 = remap_points(images[1].shape, matches_im1)
140
+ return matches_im0, matches_im1
141
+
142
+
143
+ def point_transform(H, pt):
144
+ """
145
+ @param: H is homography matrix of dimension (3x3)
146
+ @param: pt is the (x, y) point to be transformed
147
+
148
+ Return:
149
+ returns a transformed point ptrans = H*pt.
150
+ """
151
+ a = H[0, 0] * pt[0] + H[0, 1] * pt[1] + H[0, 2]
152
+ b = H[1, 0] * pt[0] + H[1, 1] * pt[1] + H[1, 2]
153
+ c = H[2, 0] * pt[0] + H[2, 1] * pt[1] + H[2, 2]
154
+ return [a / c, b / c]
155
+
156
+
157
+ def points_transform(H, pt_x, pt_y):
158
+ """
159
+ @param: H is homography matrix of dimension (3x3)
160
+ @param: pt_x, pt_y are the x and y coordinates of the point(s) to be transformed
161
+
162
+ Return:
163
+ returns a transformed point ptrans = H*pt.
164
+ """
165
+ a = H[0, 0] * pt_x + H[0, 1] * pt_y + H[0, 2]
166
+ b = H[1, 0] * pt_x + H[1, 1] * pt_y + H[1, 2]
167
+ c = H[2, 0] * pt_x + H[2, 1] * pt_y + H[2, 2]
168
+ return (a / c, b / c)
169
+
170
+
171
+ def motion_propagate(old_points, new_points, old_size, new_size, H_size=(21, 21)):
172
+ """
173
+ @param: old_points are points in old_frame that are
174
+ matched feature points with new_frame
175
+ @param: new_points are points in new_frame that are
176
+ matched feature points with old_frame
177
+ @param: old_frame is the frame to which
178
+ motion mesh needs to be obtained
179
+ @param: H is the homography between old and new points
180
+
181
+ Return:
182
+ returns a motion mesh in x-direction
183
+ and y-direction for old_frame
184
+ """
185
+ # spreads motion over the mesh for the old_frame
186
+ x_motion = np.zeros(H_size)
187
+ y_motion = np.zeros(H_size)
188
+ mesh_x_num, mesh_y_num = H_size[0], H_size[1]
189
+ pixels_x, pixels_y = (old_size[1]) / (mesh_x_num - 1), (old_size[0]) / (mesh_y_num - 1)
190
+ radius = max(pixels_x, pixels_y) * 5
191
+ sigma = radius / 3.0
192
+
193
+ H_global = None
194
+ if old_points.shape[0] > 3:
195
+ # pre-warping with global homography
196
+ H_global, _ = cv2.findHomography(old_points, new_points, cv2.RANSAC)
197
+ if H_global is None:
198
+ old_tmp = np.array([[0, 0], [0, old_size[0]], [old_size[1], 0], [old_size[1], old_size[0]]])
199
+ new_tmp = np.array([[0, 0], [0, new_size[0]], [new_size[1], 0], [new_size[1], new_size[0]]])
200
+ H_global, _ = cv2.findHomography(old_tmp, new_tmp, cv2.RANSAC)
201
+
202
+ for i in range(mesh_x_num):
203
+ for j in range(mesh_y_num):
204
+ pt = [pixels_x * i, pixels_y * j]
205
+ ptrans = point_transform(H_global, pt)
206
+ x_motion[i, j] = ptrans[0]
207
+ y_motion[i, j] = ptrans[1]
208
+
209
+ # distribute feature motion vectors
210
+ weighted_move_x = np.zeros(H_size)
211
+ weighted_move_y = np.zeros(H_size)
212
+ # build the KDTree
213
+ tree = KDTree(old_points)
214
+ # compute the weights and motion values
215
+ for i in range(mesh_x_num):
216
+ for j in range(mesh_y_num):
217
+ vertex = [pixels_x * i, pixels_y * j]
218
+ neighbor_indices = tree.query_ball_point(vertex, radius, workers=-1)
219
+ if len(neighbor_indices) > 0:
220
+ pts = old_points[neighbor_indices]
221
+ sts = new_points[neighbor_indices]
222
+ ptrans_x, ptrans_y = points_transform(H_global, pts[:, 0], pts[:, 1])
223
+ moves_x = sts[:, 0] - ptrans_x
224
+ moves_y = sts[:, 1] - ptrans_y
225
+
226
+ dists = np.sqrt((vertex[0] - pts[:, 0]) ** 2 + (vertex[1] - pts[:, 1]) ** 2)
227
+ weights_x = np.exp(-(dists ** 2) / (2 * sigma ** 2))
228
+ weights_y = np.exp(-(dists ** 2) / (2 * sigma ** 2))
229
+
230
+ weighted_move_x[i, j] = np.sum(weights_x * moves_x) / (np.sum(weights_x) + 0.1)
231
+ weighted_move_y[i, j] = np.sum(weights_y * moves_y) / (np.sum(weights_y) + 0.1)
232
+
233
+ x_motion_mesh = x_motion + weighted_move_x
234
+ y_motion_mesh = y_motion + weighted_move_y
235
+ '''
236
+ # apply median filter (f-1) on obtained motion for each vertex
237
+ x_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
238
+ y_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
239
+
240
+ for key in x_motion.keys():
241
+ try:
242
+ temp_x_motion[key].sort()
243
+ x_motion_mesh[key] = x_motion[key]+temp_x_motion[key][len(temp_x_motion[key])//2]
244
+ except KeyError:
245
+ x_motion_mesh[key] = x_motion[key]
246
+ try:
247
+ temp_y_motion[key].sort()
248
+ y_motion_mesh[key] = y_motion[key]+temp_y_motion[key][len(temp_y_motion[key])//2]
249
+ except KeyError:
250
+ y_motion_mesh[key] = y_motion[key]
251
+
252
+ # apply second median filter (f-2) over the motion mesh for outliers
253
+ #x_motion_mesh = medfilt(x_motion_mesh, kernel_size=[3, 3])
254
+ #y_motion_mesh = medfilt(y_motion_mesh, kernel_size=[3, 3])
255
+ '''
256
+ return x_motion_mesh, y_motion_mesh
257
+
258
+
259
+ def mesh_warp_points(points, x_motion_mesh, y_motion_mesh, img_size):
260
+ ptrans = []
261
+ mesh_x_num, mesh_y_num = x_motion_mesh.shape
262
+ pixels_x, pixels_y = (img_size[1]) / (mesh_x_num - 1), (img_size[0]) / (mesh_y_num - 1)
263
+ for pt in points:
264
+ i = int(pt[0] // pixels_x)
265
+ j = int(pt[1] // pixels_y)
266
+ src = [[i * pixels_x, j * pixels_y],
267
+ [(i + 1) * pixels_x, j * pixels_y],
268
+ [i * pixels_x, (j + 1) * pixels_y],
269
+ [(i + 1) * pixels_x, (j + 1) * pixels_y]]
270
+ src = np.asarray(src)
271
+ dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
272
+ [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
273
+ [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
274
+ [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
275
+ dst = np.asarray(dst)
276
+ H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
277
+ x, y = points_transform(H, pt[0], pt[1])
278
+ ptrans.append([x, y])
279
+
280
+ return np.array(ptrans)
281
+
282
+
283
+ def mesh_warp_frame(frame, x_motion_mesh, y_motion_mesh, resize):
284
+ """
285
+ @param: frame is the current frame
286
+ @param: x_motion_mesh is the motion_mesh to
287
+ be warped on frame along x-direction
288
+ @param: y_motion_mesh is the motion mesh to
289
+ be warped on frame along y-direction
290
+ @param: resize is the desired output size (tuple of width, height)
291
+
292
+ Returns:
293
+ returns a mesh warped frame according
294
+ to given motion meshes x_motion_mesh,
295
+ y_motion_mesh, resized to the specified size
296
+ """
297
+
298
+ map_x = np.zeros(resize, np.float32)
299
+ map_y = np.zeros(resize, np.float32)
300
+
301
+ mesh_x_num, mesh_y_num = x_motion_mesh.shape
302
+ pixels_x, pixels_y = (resize[1]) / (mesh_x_num - 1), (resize[0]) / (mesh_y_num - 1)
303
+
304
+ for i in range(mesh_x_num - 1):
305
+ for j in range(mesh_y_num - 1):
306
+ src = [[i * pixels_x, j * pixels_y],
307
+ [(i + 1) * pixels_x, j * pixels_y],
308
+ [i * pixels_x, (j + 1) * pixels_y],
309
+ [(i + 1) * pixels_x, (j + 1) * pixels_y]]
310
+ src = np.asarray(src)
311
+
312
+ dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
313
+ [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
314
+ [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
315
+ [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
316
+ dst = np.asarray(dst)
317
+ H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
318
+
319
+ start_x = math.ceil(pixels_x * i)
320
+ end_x = math.ceil(pixels_x * (i + 1))
321
+ start_y = math.ceil(pixels_y * j)
322
+ end_y = math.ceil(pixels_y * (j + 1))
323
+
324
+ x, y = np.meshgrid(range(start_x, end_x), range(start_y, end_y), indexing='ij')
325
+
326
+ map_x[y, x], map_y[y, x] = points_transform(H, x, y)
327
+
328
+ # deforms mesh and directly outputs the resized frame
329
+ resized_frame = cv2.remap(frame, map_x, map_y,
330
+ interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
331
+ borderValue=(255, 255, 255))
332
+ return resized_frame
333
+
334
+
335
+ def infer_warp_mesh_img(src, dst, model, vis=False):
336
+ if isinstance(src, str):
337
+ image1 = cv2.imread(src, cv2.IMREAD_UNCHANGED)
338
+ image2 = cv2.imread(dst, cv2.IMREAD_UNCHANGED)
339
+ image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
340
+ image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
341
+ elif isinstance(src, Image.Image):
342
+ image1 = np.array(src)
343
+ image2 = np.array(dst)
344
+ else:
345
+ assert isinstance(src, np.ndarray) and isinstance(dst, np.ndarray)
+ image1, image2 = src, dst
346
+
347
+ image1 = rgba_to_rgb(image1)
348
+ image2 = rgba_to_rgb(image2)
349
+
350
+ image1_padded = resize_with_aspect_ratio(image1, image2)
351
+ resized_image1 = cv2.resize(image1_padded, (image2.shape[1], image2.shape[0]), interpolation=cv2.INTER_AREA)
352
+
353
+ matches_im0, matches_im1 = infer_match([resized_image1, image2], model, vis=vis)
354
+ matches_im0 = matches_im0 * image1_padded.shape[0] / resized_image1.shape[0]
355
+
356
+ # print('Estimate Mesh Grid')
357
+ mesh_x, mesh_y = motion_propagate(matches_im1, matches_im0, image2.shape[:2], image1_padded.shape[:2])
358
+
359
+ aligned_image = mesh_warp_frame(image1_padded, mesh_x, mesh_y, (image2.shape[0], image2.shape[1]))
360
+
361
+ matches_im0_from_im1 = mesh_warp_points(matches_im1, mesh_x, mesh_y, (image2.shape[1], image2.shape[0]))
362
+
363
+ info = compute_img_diff(aligned_image, image2, matches_im0, matches_im0_from_im1, vis=vis)
364
+
365
+ return aligned_image, info
366
+
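The core of `motion_propagate` above is a two-step estimate per mesh vertex: pre-warp with a global homography, then add a Gaussian-weighted average of the residual motion of nearby matches. The sketch below restates that arithmetic in pure Python for a single vertex. It is an illustration, not the repository's implementation: the names `apply_homography` and `weighted_residual` are hypothetical, the neighbour search that the diff performs with a `KDTree` is assumed to have happened already, and nested lists replace NumPy arrays.

```python
import math


def apply_homography(H, x, y):
    """Map (x, y) through a 3x3 homography H given as nested lists."""
    a = H[0][0] * x + H[0][1] * y + H[0][2]
    b = H[1][0] * x + H[1][1] * y + H[1][2]
    c = H[2][0] * x + H[2][1] * y + H[2][2]
    return a / c, b / c


def weighted_residual(vertex, old_pts, new_pts, H, sigma, eps=0.1):
    """Gaussian-weighted mean residual motion around one mesh vertex.

    The residual of a match is new_pt - H(old_pt). Matches closer to the
    vertex receive a larger weight. `eps` mirrors the `+ 0.1` added to the
    denominator in the diff, which damps vertices with little support.
    """
    num_x = num_y = den = 0.0
    for (ox, oy), (nx, ny) in zip(old_pts, new_pts):
        px, py = apply_homography(H, ox, oy)
        d2 = (vertex[0] - ox) ** 2 + (vertex[1] - oy) ** 2
        w = math.exp(-d2 / (2 * sigma ** 2))
        num_x += w * (nx - px)
        num_y += w * (ny - py)
        den += w
    return num_x / (den + eps), num_y / (den + eps)


if __name__ == "__main__":
    identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
    print(apply_homography(identity, 3, 4))
    print(weighted_residual((0, 0), [(0, 0)], [(1.1, 0.0)], identity, sigma=1.0))
```

The final vertex position in the diff is the homography-warped vertex plus this residual, which is what `x_motion + weighted_move_x` expresses.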
third_party/gen_baking.py ADDED
@@ -0,0 +1,288 @@
1
+ import os, sys, time
2
+ from typing import List, Optional
3
+ from iopath.common.file_io import PathManager
4
+
5
+ import cv2
6
+ import imageio
7
+ import numpy as np
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torchvision import transforms
14
+
15
+ import trimesh
16
+ from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
17
+ from pytorch3d.ops import interpolate_face_attributes
18
+ from pytorch3d.common.datatypes import Device
19
+ from pytorch3d.structures import Meshes
20
+ from pytorch3d.renderer import (
21
+ look_at_view_transform,
22
+ FoVPerspectiveCameras,
23
+ PointLights,
24
+ DirectionalLights,
25
+ AmbientLights,
26
+ Materials,
27
+ RasterizationSettings,
28
+ MeshRenderer,
29
+ MeshRasterizer,
30
+ SoftPhongShader,
31
+ TexturesUV, TexturesAtlas,
32
+ TexturesVertex,
33
+ camera_position_from_spherical_angles,
34
+ BlendParams,
35
+ )
36
+
37
+
38
+ def erode_mask(src_mask, p=1 / 20.0):
39
+ monoMaskImage = cv2.split(src_mask)[0]
40
+ br = cv2.boundingRect(monoMaskImage)
41
+ k = int(min(br[2], br[3]) * p)
42
+ kernel = np.ones((k, k), dtype=np.uint8)
43
+ dst_mask = cv2.erode(src_mask, kernel, 1)
44
+ return dst_mask
45
+
46
+ def load_objs_as_meshes_fast(
47
+ verts,
48
+ faces,
49
+ aux,
50
+ device: Optional[Device] = None,
51
+ load_textures: bool = True,
52
+ create_texture_atlas: bool = False,
53
+ texture_atlas_size: int = 4,
54
+ texture_wrap: Optional[str] = "repeat",
55
+ path_manager: Optional[PathManager] = None,
56
+ ):
57
+ tex = None
58
+ if create_texture_atlas:
59
+ # TexturesAtlas type
60
+ tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
61
+ else:
62
+ # TexturesUV type
63
+ tex_maps = aux.texture_images
64
+ if tex_maps is not None and len(tex_maps) > 0:
65
+ verts_uvs = aux.verts_uvs.to(device) # (V, 2)
66
+ faces_uvs = faces.textures_idx.to(device) # (F, 3)
67
+ image = list(tex_maps.values())[0].to(device)[None]
68
+ tex = TexturesUV(verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image)
69
+ mesh = Meshes( verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex)
70
+ return mesh
71
+
72
+
73
+ def get_triangle_to_triangle(tri_1, tri_2, img_refined):
74
+ '''
75
+ args:
76
+ tri_1:
77
+ tri_2:
78
+ '''
79
+ r1 = cv2.boundingRect(tri_1)
80
+ r2 = cv2.boundingRect(tri_2)
81
+
82
+ tri_1_cropped = []
83
+ tri_2_cropped = []
84
+ for i in range(0, 3):
85
+ tri_1_cropped.append(((tri_1[i][1] - r1[1]), (tri_1[i][0] - r1[0])))
86
+ tri_2_cropped.append(((tri_2[i][1] - r2[1]), (tri_2[i][0] - r2[0])))
87
+
88
+ trans = cv2.getAffineTransform(np.float32(tri_1_cropped), np.float32(tri_2_cropped))
89
+
90
+ img_1_cropped = np.float32(img_refined[r1[0]:r1[0] + r1[2], r1[1]:r1[1] + r1[3]])
91
+
92
+ mask = np.zeros((r2[2], r2[3], 3), dtype=np.float32)
93
+
94
+ cv2.fillConvexPoly(mask, np.int32(tri_2_cropped), (1.0, 1.0, 1.0), 16, 0)
95
+
96
+ img_2_cropped = cv2.warpAffine(
97
+ img_1_cropped, trans, (r2[3], r2[2]), None,
98
+ flags = cv2.INTER_LINEAR,
99
+ borderMode = cv2.BORDER_REFLECT_101
100
+ )
101
+ return mask, img_2_cropped, r2
102
+
103
+
104
+ def back_projection(
105
+ obj_file,
106
+ init_texture_file,
107
+ front_view_file,
108
+ dst_dir,
109
+ render_resolution=512,
110
+ uv_resolution=600,
111
+ normalThreshold=0.3, # 0.3
112
+ rgb_thresh=820, # 520
113
+ views=None,
114
+ camera_dist=1.5,
115
+ erode_scale=1/100.0,
116
+ device="cuda:0"
117
+ ):
118
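+ # obj_file: OBJ file with UV coordinates
119
+ # init_texture_file: initial unwrapped UV texture map
120
+ # render_resolution: rendering resolution of the front view
121
+ # uv_resolution: texture map resolution
122
+ # normalThreshold: normal threshold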
123
+
124
+ os.makedirs(dst_dir, exist_ok=True)
125
+
126
+ if isinstance(front_view_file, str):
127
+ src = np.array(Image.open(front_view_file).convert("RGB"))
128
+ elif isinstance(front_view_file, Image.Image):
129
+ src = np.array(front_view_file.convert("RGB"))
130
+ else:
131
+ raise TypeError("front_view_file must be a file path or a PIL image")
132
+
133
+ image_size = (render_resolution, render_resolution)
134
+
135
+ init_texture = Image.open(init_texture_file)
136
+ init_texture = init_texture.convert("RGB")
137
+ # init_texture = init_texture.resize((uv_resolution, uv_resolution))
138
+ init_texture = np.array(init_texture).astype(np.float32)
139
+
140
+ print("load obj", obj_file)
141
+ verts, faces, aux = load_obj(obj_file, device=device)
142
+ mesh = load_objs_as_meshes_fast(verts, faces, aux, device=device)
143
+
144
+
145
+ t0 = time.time()
146
+ verts_uvs = aux.verts_uvs
147
+ triangle_uvs = verts_uvs[faces.textures_idx]
148
+ triangle_uvs = torch.cat([
149
+ ((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
150
+ (triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
151
+ ], dim=-1)
152
+ triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
153
+
154
+ # import ipdb;ipdb.set_trace()
155
+
156
+
157
+ R0, T0 = look_at_view_transform(camera_dist, views[0][0], views[0][1])
158
+
159
+ cameras = FoVPerspectiveCameras(device=device, R=R0, T=T0, fov=49.1)
160
+
161
+ camera_normal = camera_position_from_spherical_angles(1, views[0][0], views[0][1]).to(device)
162
+ screen_coords = cameras.transform_points_screen(verts, image_size=image_size)[:, :2]
163
+ screen_coords = torch.cat([screen_coords[..., 1, None], screen_coords[..., 0, None]], dim=-1)
164
+ triangle_screen_coords = np.round(np.float32(screen_coords[faces.verts_idx].cpu())) # numpy.ndarray (90000, 3, 2)
165
+ triangle_screen_coords = np.clip(triangle_screen_coords.astype(np.int64), 0, render_resolution-1)
166
+
167
+ renderer = MeshRenderer(
168
+ rasterizer=MeshRasterizer(
169
+ cameras=cameras,
170
+ raster_settings= RasterizationSettings(
171
+ image_size=image_size,
172
+ blur_radius=0.0,
173
+ faces_per_pixel=1,
174
+ ),
175
+ ),
176
+ shader=SoftPhongShader(
177
+ device=device,
178
+ cameras=cameras,
179
+ lights= AmbientLights(device=device),
180
+ blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
181
+ )
182
+ )
183
+
184
+ dst = renderer(mesh)
185
+ dst = (dst[..., :3] * 255).squeeze(0).cpu().numpy().astype(np.uint8)
186
+
187
+ src_mask = np.ones((src.shape[0], src.shape[1]), dst.dtype)
188
+ ids = np.where(dst.sum(-1) > 253 * 3)
189
+ ids2 = np.where(src.sum(-1) > 250 * 3)
190
+ src_mask[ids[0], ids[1]] = 0
191
+ src_mask[ids2[0], ids2[1]] = 0
192
+ src_mask = (src_mask > 0).astype(np.uint8) * 255
193
+
194
+ monoMaskImage = cv2.split(src_mask)[0] # reducing the mask to a monochrome
195
+ br = cv2.boundingRect(monoMaskImage) # bounding rect (x,y,width,height)
196
+ center = (br[0] + br[2] // 2, br[1] + br[3] // 2)
197
+
198
+ # seamlessClone
199
+ try:
200
+ images = cv2.seamlessClone(src, dst, src_mask, center, cv2.NORMAL_CLONE) # sharper (clearer) result
201
+ # images = cv2.seamlessClone(src, dst, src_mask, center, cv2.MIXED_CLONE)
202
+ except Exception as err:
203
+ print(f"\n\n Warning seamlessClone error: {err} \n\n")
204
+ images = src
205
+
206
+ Image.fromarray(src_mask).save(os.path.join(dst_dir, 'mask.jpeg'))
207
+ Image.fromarray(src).save(os.path.join(dst_dir, 'src.jpeg'))
208
+ Image.fromarray(dst).save(os.path.join(dst_dir, 'dst.jpeg'))
209
+ Image.fromarray(images).save(os.path.join(dst_dir, 'blend.jpeg'))
210
+
211
+ fragments_scaled = renderer.rasterizer(mesh) # pytorch3d.renderer.mesh.rasterizer.Fragments
212
+ faces_covered = fragments_scaled.pix_to_face.unique()[1:] # torch.Tensor torch.Size([30025])
213
+ face_normals = mesh.faces_normals_packed().to(device) # torch.Tensor torch.Size([90000, 3]) cuda:0
214
+
215
+ # faces: pytorch3d.io.obj_io.Faces
216
+ # faces.textures_idx: torch.Tensor torch.Size([90000, 3])
217
+ # verts_uvs: torch.Tensor torch.Size([49554, 2])
218
+ triangle_uvs = verts_uvs[faces.textures_idx]
219
+ triangle_uvs = [
220
+ ((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
221
+ (triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
222
+ ]
223
+ triangle_uvs = torch.cat(triangle_uvs, dim=-1) # numpy.ndarray (90000, 3, 2)
224
+ triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
225
+
226
+ t0 = time.time()
227
+
228
+
229
+ SOFT_NORM = True # faces at a large angle to the camera: True = down-weight them via coeff, False = skip them
230
+
231
+ for k in faces_covered:
232
+ # todo: accelerate this for-loop
233
+ # if cosine between face-camera is too low, skip current face baking
234
+ face_normal = face_normals[k]
235
+ cosine = torch.sum((face_normal * camera_normal) ** 2)
236
+ if not SOFT_NORM and cosine < normalThreshold: continue
237
+
238
+ # if coord in screen out of subject, skip current face baking
239
+ out_of_subject = src_mask[triangle_screen_coords[k][0][0], triangle_screen_coords[k][0][1]]==0
240
+ if out_of_subject: continue
241
+
242
+ coeff, img_2_cropped, r2 = get_triangle_to_triangle(triangle_screen_coords[k], triangle_uvs[k], images)
243
+
244
+ # if the color difference between the new and old texture is too large, skip baking this face
245
+ err = np.abs(init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]]- img_2_cropped)
246
+ err = (err * coeff).sum(-1)
247
+
248
+ # print(err.shape, np.max(err))
249
+ if (np.max(err) > rgb_thresh): continue
250
+
251
+ color_for_debug = None
252
+ # if (np.max(err) > 400): color_for_debug = [255, 0, 0]
253
+ # if (np.max(err) > 450): color_for_debug = [0, 255, 0]
254
+ # if (np.max(err) > 500): color_for_debug = [0, 0, 255]
255
+
256
+ coeff = coeff.clip(0, 1)
257
+
258
+ if SOFT_NORM:
259
+ coeff *= ((cosine.detach().cpu().numpy() - normalThreshold) / normalThreshold).clip(0,1)
260
+
261
+ coeff *= (((rgb_thresh - err[...,None]) / rgb_thresh)**0.4).clip(0,1)
262
+
263
+ if color_for_debug is None:
264
+ init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = \
265
+ init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] * ((1.0,1.0,1.0)-coeff) + img_2_cropped * coeff
266
+ else:
267
+ init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = color_for_debug
268
+
269
+ print(f'View baking time: {time.time() - t0}')
270
+
271
+ bake_dir = os.path.join(dst_dir, 'bake')
272
+ os.makedirs(bake_dir, exist_ok=True)
273
+ os.system(f'cp {obj_file} {bake_dir}')
274
+
275
+ texture_img = Image.fromarray(init_texture.astype(np.uint8))
276
+ texture_img.save(os.path.join(bake_dir, init_texture_file.split("/")[-1]))
277
+
278
+ mtl_dir = obj_file.replace('.obj', '.mtl')
279
+ if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"material.mtl")
280
+ if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"texture.mtl")
281
+ if not os.path.exists(mtl_dir): raise FileNotFoundError(f"no .mtl file found next to {obj_file}")
282
+ os.system(f'cp {mtl_dir} {bake_dir}')
283
+
284
+ # convert .obj to .glb file
285
+ new_obj_pth = os.path.join(bake_dir, obj_file.split('/')[-1])
286
+ new_glb_path = new_obj_pth.replace('.obj', '.glb')
287
+ mesh = trimesh.load_mesh(new_obj_pth)
288
+ mesh.export(new_glb_path, file_type='glb')
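The per-face blending in `back_projection` above combines three factors into one coefficient: the triangle mask, an angle term derived from the per-face `cosine` value, and a color-agreement term. The scalar sketch below shows how those factors interact for a single texel. It is a simplified restatement with assumptions: the names `clip01`, `blend_weight` and `blend_texel` are hypothetical, the defaults copy `normalThreshold=0.3` and `rgb_thresh=820` from the function signature, and the color term is clamped before the power, which gives the same result as the diff whenever `err <= rgb_thresh` (faces above that limit are skipped earlier in the loop).

```python
def clip01(value):
    """Clamp a number to the closed interval [0, 1]."""
    return max(0.0, min(1.0, value))


def blend_weight(cosine, err, normal_threshold=0.3, rgb_thresh=820.0, mask=1.0):
    """Weight given to the projected view color for one texel.

    cosine: per-face alignment score between face normal and camera direction
    err:    summed absolute RGB difference between old texture and projection
    mask:   triangle coverage of the texel (1 inside, 0 outside)
    """
    angle_term = clip01((cosine - normal_threshold) / normal_threshold)
    # Clamp first so a negative base never reaches the fractional power.
    color_term = clip01((rgb_thresh - err) / rgb_thresh) ** 0.4
    return clip01(mask) * angle_term * color_term


def blend_texel(old, new, weight):
    """Linear blend: weight 0 keeps the old texel, weight 1 takes the new one."""
    return old * (1.0 - weight) + new * weight


if __name__ == "__main__":
    w = blend_weight(cosine=0.45, err=200.0)
    print(w, blend_texel(100.0, 200.0, w))
```

A face that looks straight at the camera and agrees with the existing texture is overwritten almost entirely, while a grazing or strongly disagreeing face leaves the initial texture mostly intact.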
third_party/mesh_baker.py ADDED
@@ -0,0 +1,142 @@
1
+ import os, sys, time, traceback
2
+ print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
3
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image, ImageSequence
8
+ from einops import rearrange
9
+ import torch
10
+
11
+ from infer.utils import seed_everything, timing_decorator
12
+ from infer.utils import get_parameter_number, set_parameter_grad_false
13
+
14
+ from dust3r.inference import inference
15
+ from dust3r.model import AsymmetricCroCo3DStereo
16
+
17
+ from third_party.gen_baking import back_projection
18
+ from third_party.dust3r_utils import infer_warp_mesh_img
19
+ from svrm.ldm.vis_util import render_func
20
+
21
+
22
+ class MeshBaker:
23
+ def __init__(
24
+ self,
25
+ align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
26
+ device = "cuda:0",
27
+ align_times = 1,
28
+ iou_thresh = 0.8,
29
+ force_baking_ele_list = None,
30
+ save_memory = False
31
+ ):
32
+ self.device = device
33
+ self.save_memory = save_memory
34
+ self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
35
+ self.align_model = self.align_model if save_memory else self.align_model.to(device)
36
+ self.align_times = align_times
37
+ self.align_model.eval()
38
+ self.iou_thresh = iou_thresh
39
+ self.force_baking_ele_list = [] if force_baking_ele_list is None else force_baking_ele_list
40
+ self.force_baking_ele_list = [int(_) for _ in self.force_baking_ele_list]
41
+ set_parameter_grad_false(self.align_model)
42
+ print('baking align model', get_parameter_number(self.align_model))
43
+
44
+ def align_and_check(self, src, dst, align_times=3):
45
+ try:
46
+ st = time.time()
47
+ best_baking_flag = False
48
+ best_aligned_image = aligned_image = src
49
+ best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
50
+ for i in range(align_times):
51
+ aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
52
+ aligned_image = Image.fromarray(aligned_image)
53
+ print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
54
+ if info['mask_iou'] > best_info['mask_iou']:
55
+ best_aligned_image, best_info = aligned_image, info
56
+ if info['mask_iou'] < self.iou_thresh:
57
+ break
58
+ print(f"Best Baking Info:{best_info['mask_iou']}")
59
+ best_baking_flag = best_info['mask_iou'] > self.iou_thresh
60
+ return best_aligned_image, best_info, best_baking_flag
61
+ except Exception as e:
62
+ print(f"Error processing image: {e}")
63
+ traceback.print_exc()
64
+ return None, None, None
65
+
66
+ @timing_decorator("baking mesh")
67
+ def __call__(self, *args, **kwargs):
68
+ if self.save_memory:
69
+ self.align_model = self.align_model.to(self.device)
70
+ torch.cuda.empty_cache()
71
+ res = self.call(*args, **kwargs)
72
+ self.align_model = self.align_model.to("cpu")
73
+ else:
74
+ res = self.call(*args, **kwargs)
75
+ torch.cuda.empty_cache()
76
+ return res
77
+
78
+ def call(self, save_folder):
79
+ obj_path = os.path.join(save_folder, "mesh.obj")
80
+ raw_texture_path = os.path.join(save_folder, "texture.png")
81
+ views_pil = os.path.join(save_folder, "views.jpg")
82
+ views_gif = os.path.join(save_folder, "views.gif")
83
+ cond_pil = os.path.join(save_folder, "img_nobg.png")
84
+
85
+ if os.path.exists(views_pil):
86
+ views_pil = Image.open(views_pil)
87
+ views = rearrange(np.asarray(views_pil, dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
88
+ views = [Image.fromarray(views[idx]).convert('RGB') for idx in [0,2,4,5,3,1]]
89
+ cond_pil = Image.open(cond_pil).resize((512,512))
90
+ elif os.path.exists(views_gif):
91
+ views_gif_pil = Image.open(views_gif)
92
+ views = [img.convert('RGB') for img in ImageSequence.Iterator(views_gif_pil)]
93
+ cond_pil, views = views[0], views[1:]
94
+ else:
95
+ raise FileNotFoundError("views file not found")
96
+
+         rendered_views = render_func(obj_path, elev=0, n_views=2)
+
+         for ele_idx, ele in enumerate([0, 180]):
+
+             if ele == 0:
+                 aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_views[0], align_times=self.align_times)
+                 aligned_cond.save(save_folder + f'/aligned_cond.jpg')
+
+                 aligned_img, info, _ = self.align_and_check(views[0], rendered_views[0], align_times=self.align_times)
+                 aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
+
+                 # Prefer the condition image for the front view when it aligns better
+                 if info['mask_iou'] < cond_info['mask_iou']:
+                     print("Using Cond Image to bake front view")
+                     aligned_img = aligned_cond
+                     info = cond_info
+                 need_baking = info['mask_iou'] > self.iou_thresh
+             else:
+                 aligned_img, info, need_baking = self.align_and_check(views[ele//60], rendered_views[ele_idx])
+                 aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
+
+             if need_baking or (ele in self.force_baking_ele_list):
+                 st = time.time()
+                 view1_res = back_projection(
+                     obj_file = obj_path,
+                     init_texture_file = raw_texture_path,
+                     front_view_file = aligned_img,
+                     dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
+                     render_resolution = aligned_img.size[0],
+                     uv_resolution = 1024,
+                     views = [[0, ele]],
+                     device = self.device
+                 )
+                 # time.time() - st is the elapsed time in seconds, not a timestamp
+                 print(f"view_{ele_idx} elevation_{ele} baking finished in {time.time() - st:.2f}s")
+                 # Later views bake on top of the mesh and texture produced by this pass
+                 obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
+                 raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
+             else:
+                 print(f"Skip view_{ele_idx} elevation_{ele} baking")
+
+         print("Baking Finished")
+         return obj_path
+
+
+ if __name__ == "__main__":
+     baker = MeshBaker()
+     obj_path = baker("./outputs/test")
+     print(obj_path)
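The `rearrange` pattern `'(n h) (m w) c -> (n m) h w c'` used in `call` splits a 3x2 image grid into six tiles in row-major order. A minimal NumPy-only sketch of the same index pattern follows; the helper name `split_grid` and the toy tile size are hypothetical and not part of the repository.

```python
import numpy as np

def split_grid(grid, n=3, m=2):
    """Split an (n*h, m*w, c) grid into n*m tiles of shape (h, w, c), row-major."""
    H, W, c = grid.shape
    h, w = H // n, W // m
    # Separate the row/column tile indices from the pixel indices, then group them
    tiles = grid.reshape(n, h, m, w, c).transpose(0, 2, 1, 3, 4)
    return tiles.reshape(n * m, h, w, c)

# Toy grid: tile (r, col) is filled with the constant value r * 2 + col
h, w = 2, 2
grid = np.zeros((3 * h, 2 * w, 3), dtype=np.uint8)
for r in range(3):
    for col in range(2):
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = r * 2 + col

tiles = split_grid(grid)
print(tiles.shape)
```

Each `tiles[i]` then holds the tile with value `i`, which is the ordering that the subsequent `[0,2,4,5,3,1]` index list in `call` permutes.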
third_party/utils/camera_utils.py ADDED
@@ -0,0 +1,90 @@
+ import math
+ import numpy as np
+ from scipy.spatial.transform import Rotation  # required by the quaternion helpers below
+
+ def compute_extrinsic_matrix(elevation, azimuth, camera_distance):
+     # Convert angles to radians
+     elevation_rad = np.radians(elevation)
+     azimuth_rad = np.radians(azimuth)
+
+     R = np.array([
+         [np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)],
+         [0, 1, 0],
+         [np.sin(azimuth_rad), 0, np.cos(azimuth_rad)],
+     ], dtype=np.float32)
+
+     R = R @ np.array([
+         [1, 0, 0],
+         [0, np.cos(elevation_rad), -np.sin(elevation_rad)],
+         [0, np.sin(elevation_rad), np.cos(elevation_rad)]
+     ], dtype=np.float32)
+
+     # Construct translation matrix T (3x1)
+     T = np.array([[camera_distance], [0], [0]], dtype=np.float32)
+     T = R @ T
+
+     # Combined into a 4x4 transformation matrix
+     extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32)))
+
+     return extrinsic_matrix
+
+
+ def transform_camera_pose(im_pose, ori_pose, new_pose):
+     # Note: the transpose equals the inverse only for a pure rotation;
+     # for a 4x4 pose with non-zero translation the two differ.
+     T = new_pose @ ori_pose.T
+     transformed_poses = []
+
+     for pose in im_pose:
+         transformed_pose = T @ pose
+         transformed_poses.append(transformed_pose)
+
+     return transformed_poses
+
+ def compute_fov(intrinsic_matrix):
+     # Focal lengths from the intrinsic matrix
+     fx = intrinsic_matrix[0, 0]
+     fy = intrinsic_matrix[1, 1]
+
+     # The principal point (cx, cy) is taken as the image center:
+     # cx = intrinsic_matrix[0, 2] gives the width, cy = intrinsic_matrix[1, 2] the height
+     w, h = intrinsic_matrix[0, 2] * 2, intrinsic_matrix[1, 2] * 2
+
+     # Calculate horizontal and vertical FOV values in degrees
+     fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi
+     fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi
+
+     return fov_x, fov_y
+
+
+
+ def rotation_matrix_to_quaternion(rotation_matrix):
+     rot = Rotation.from_matrix(rotation_matrix)
+     quaternion = rot.as_quat()
+     return quaternion
+
+ def quaternion_to_rotation_matrix(quaternion):
+     rot = Rotation.from_quat(quaternion)
+     rotation_matrix = rot.as_matrix()
+     return rotation_matrix
+
+ def remap_points(img_size, match, size=512):
+     H, W, _ = img_size
+
+     S = max(W, H)
+     new_W = int(round(W * size / S))
+     new_H = int(round(H * size / S))
+     cx, cy = new_W // 2, new_H // 2
+
+     # Calculate the coordinates of the transformed image center point
+     halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
+
+     dw, dh = cx - halfw, cy - halfh
+
+     # store point coordinates mapped back to the original image
+     new_match = np.zeros_like(match)
+
+     # Map the transformed point coordinates back to the original image
+     new_match[:, 0] = (match[:, 0] + dw) / new_W * W
+     new_match[:, 1] = (match[:, 1] + dh) / new_H * H
+
+     #print(dw,new_W,W,dh,new_H,H)
+
+     return new_match
+
+
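As a worked check of the field-of-view formula in `compute_fov`, the sketch below evaluates fov = 2 * atan(size / (2 * f)) on a toy intrinsic matrix. It assumes the principal point sits at the image center, so that width = 2 * cx and height = 2 * cy; the function name `fov_from_intrinsics` and the numeric values are illustrative only.

```python
import math
import numpy as np

def fov_from_intrinsics(K):
    """Return (fov_x, fov_y) in degrees, assuming a centered principal point."""
    fx, fy = K[0, 0], K[1, 1]
    w, h = 2 * K[0, 2], 2 * K[1, 2]  # cx -> width, cy -> height
    fov_x = math.degrees(2 * math.atan(w / (2 * fx)))
    fov_y = math.degrees(2 * math.atan(h / (2 * fy)))
    return fov_x, fov_y

# Toy camera: 512x256 image, focal length 256 px on both axes
K = np.array([
    [256.0, 0.0, 256.0],
    [0.0, 256.0, 128.0],
    [0.0, 0.0, 1.0],
])
fov_x, fov_y = fov_from_intrinsics(K)
print(fov_x, fov_y)
```

With a 512-pixel width and a 256-pixel focal length, the half-angle is atan(1) = 45 degrees, so the horizontal FOV is 90 degrees; the vertical FOV is narrower because the image is only 256 pixels tall.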
third_party/utils/img_utils.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import cv2
+ import numpy as np
+ from skimage.metrics import hausdorff_distance
+ from matplotlib import pyplot as plt
+
+
+ def get_input_imgs_path(input_data_dir):
+     path = {}
+     names = ['000', 'ori_000']
+     for name in names:
+         jpg_path = os.path.join(input_data_dir, f"{name}.jpg")
+         png_path = os.path.join(input_data_dir, f"{name}.png")
+         if os.path.exists(jpg_path):
+             path[name] = jpg_path
+         elif os.path.exists(png_path):
+             path[name] = png_path
+     return path
+
+
+ def rgba_to_rgb(image, bg_color=[255, 255, 255]):
+     # Already RGB: nothing to composite
+     if image.shape[-1] == 3: return image
+
+     rgba = image.astype(float)
+     rgb = rgba[:, :, :3].copy()
+     alpha = rgba[:, :, 3] / 255.0
+
+     bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32)
+     bg = bg * np.array(bg_color, dtype=np.float32)
+
+     # Alpha-composite the foreground over the solid background color
+     rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis])
+     rgb = rgb.astype(np.uint8)
+
+     return rgb
+
+
+ def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]):
+     # Pads image1 so its aspect ratio matches image2; image2 is only read for its shape
+     aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0])
+     aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0])
+
+     top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0
+
+     if aspect_ratio1 < aspect_ratio2:
+         new_width = (aspect_ratio2 * image1.shape[0])
+         right_pad = left_pad = int((new_width - image1.shape[1]) / 2)
+     else:
+         new_height = (image1.shape[1] / aspect_ratio2)
+         bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2)
+
+     image1_padded = cv2.copyMakeBorder(
+         image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value
+     )
+     return image1_padded
+
+
+ def estimate_img_mask(image):
+     # to gray
+     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+     # segment
+     # _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+     # mask_otsu = thresh.astype(bool)
+     # thresh_gray = 240
+
+     edges = cv2.Canny(gray, 20, 50)
+
+     kernel = np.ones((3, 3), np.uint8)
+     edges_dilated = cv2.dilate(edges, kernel, iterations=1)
+
+     contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     mask = np.zeros_like(gray, dtype=np.uint8)
+
+     cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
+     mask = mask.astype(bool)
+
+     return mask
+
+
+ def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False):
+     scale = 0.125
+     gray_trunc_thres = 25 / 255.0
+
+     # Match
+     if matches1.shape[0] > 0:
+         match_scale = np.max(np.ptp(matches1, axis=-1))
+         match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1))
+         dist_threshold = match_scale * 0.01
+         match_num = np.sum(match_dists <= dist_threshold)
+         match_rate = np.mean(match_dists <= dist_threshold)
+     else:
+         match_num = 0
+         match_rate = 0
+
+     # IOU
+     img1_mask = estimate_img_mask(img1)
+     img2_mask = estimate_img_mask(img2)
+     img_intersection = (img1_mask == 1) & (img2_mask == 1)
+     img_union = (img1_mask == 1) | (img2_mask == 1)
+     intersection = np.sum(img_intersection == 1)
+     union = np.sum(img_union == 1)
+     mask_iou = intersection / union if union != 0 else 0
+
+     # Gray
+     height, width = img1.shape[:2]
+     img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
+     img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+     img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0)
+     img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0)
+
+     # Gray Diff
+     img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)),
+                                  interpolation=cv2.INTER_LINEAR) / 255.0
+     img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)),
+                                  interpolation=cv2.INTER_LINEAR) / 255.0
+     img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small)
+     gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1
+
+     img_gray_small_diff_trunc = img_gray_small_diff.copy()
+     img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0
+     gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1
+
+     # Edge
+     img1_edge = cv2.Canny(img1_gray, 100, 200)
+     img2_edge = cv2.Canny(img2_gray, 100, 200)
+     bw_edges1 = (img1_edge > 0).astype(bool)
+     bw_edges2 = (img2_edge > 0).astype(bool)
+     hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2)
+     if vis:
+         fig, axs = plt.subplots(1, 4, figsize=(15, 5))
+         axs[0].imshow(img1_gray, cmap='gray')
+         axs[0].set_title('Img1')
+         axs[1].imshow(img2_gray, cmap='gray')
+         axs[1].set_title('Img2')
+         axs[2].imshow(img1_mask)
+         axs[2].set_title('Mask1')
+         axs[3].imshow(img2_mask)
+         axs[3].set_title('Mask2')
+         plt.show()
+         plt.figure()
+         mask_cmp = np.zeros((height, width, 3))
+         mask_cmp[img_intersection, 1] = 1
+         mask_cmp[img_union, 0] = 1
+         plt.imshow(mask_cmp)
+         plt.show()
+         fig, axs = plt.subplots(1, 4, figsize=(15, 5))
+         axs[0].imshow(img1_gray_small, cmap='gray')
+         axs[0].set_title('Img1 Gray')
+         axs[1].imshow(img2_gray_small, cmap='gray')
+         axs[1].set_title('Img2 Gray')
+         axs[2].imshow(img_gray_small_diff, cmap='gray')
+         axs[2].set_title('diff')
+         axs[3].imshow(img_gray_small_diff_trunc, cmap='gray')
+         axs[3].set_title('diff_trunc')
+         plt.show()
+         fig, axs = plt.subplots(1, 2, figsize=(15, 5))
+         axs[0].imshow(img1_edge, cmap='gray')
+         axs[0].set_title('img1_edge')
+         axs[1].imshow(img2_edge, cmap='gray')
+         axs[1].set_title('img2_edge')
+         plt.show()
+
+     info = {}
+     info['match_num'] = match_num
+     info['match_rate'] = match_rate
+     info['mask_iou'] = mask_iou
+     info['gray_diff'] = gray_diff
+     info['gray_diff_trunc'] = gray_diff_trunc
+     info['hausdorff_dist'] = hausdorff_dist
+     return info
+
+
+ def predict_match_success_human(info):
+     match_num = info['match_num']
+     match_rate = info['match_rate']
+     mask_iou = info['mask_iou']
+     gray_diff = info['gray_diff']
+     gray_diff_trunc = info['gray_diff_trunc']
+     hausdorff_dist = info['hausdorff_dist']
+
+     if mask_iou > 0.95:
+         return True
+
+     if match_num < 20 or match_rate < 0.7:
+         return False
+
+     if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010:
+         return True
+
+     if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008:
+         return True
+
+     '''
+     if match_rate<0.70 or match_num<3000:
+         return False
+
+     if (mask_iou>0.85 and hausdorff_dist<20)or (gray_diff<0.015 and gray_diff_trunc<0.01) or match_rate>=0.90:
+         return True
+     '''
+
+     return False
+
+
+ def predict_match_success(info, model=None):
+     if model is None:
+         return predict_match_success_human(info)
+     else:
+         feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist']
+         features = [info[f] for f in feat_name]
+         pred = model.predict([features])[0]
+         return pred >= 0.5
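The `mask_iou` value that drives both `compute_img_diff` and the baking threshold is an intersection-over-union of two boolean masks. A minimal sketch on toy masks follows, assuming masks of equal shape; the standalone helper `mask_iou` is hypothetical and simply mirrors the intersection/union arithmetic above, including the zero-union fallback.

```python
import numpy as np

def mask_iou(mask_a, mask_b):
    """Intersection-over-union of two boolean masks; 0.0 when both are empty."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum() / union)

# Two 4x4 masks, each covering two rows, overlapping in exactly one row
a = np.zeros((4, 4), dtype=bool)
a[:2, :] = True
b = np.zeros((4, 4), dtype=bool)
b[1:3, :] = True

iou = mask_iou(a, b)
empty_iou = mask_iou(np.zeros((4, 4), dtype=bool), np.zeros((4, 4), dtype=bool))
print(iou, empty_iou)
```

Here the intersection is 4 pixels and the union is 12, so the IoU is 1/3, which would fall below any of the thresholds used in `predict_match_success_human`.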