Mithun12345 committed on
Commit 5a33903 · verified · 1 Parent(s): ae3c97e

Upload 7 files

Files changed (7)
  1. .gitignore +43 -0
  2. LICENSE +201 -0
  3. README.md +146 -12
  4. app.py +348 -0
  5. requirements.txt +19 -0
  6. run.py +262 -0
  7. train.py +286 -0
.gitignore ADDED
@@ -0,0 +1,43 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ eggs/
15
+ .eggs/
16
+ .vscode/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ .DS_Store
29
+
30
+ tools/objaverse_rendering/blender-3.2.2-linux-x64/
31
+ tools/objaverse_rendering/output/
32
+ ckpts/
33
+ lightning_logs/
34
+ logs/
35
+ .trash/
36
+ .env/
37
+ outputs/
38
+ figures*/
39
+
40
+ # Useless Files
41
+ *.sh
42
+ blender/
43
+ .restore/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,146 @@
1
- ---
2
- title: 3D Model Demo
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.44.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+
3
+ # InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models
4
+
5
+ <a href="https://arxiv.org/abs/2404.07191"><img src="https://img.shields.io/badge/ArXiv-2404.07191-brightgreen"></a>
6
+ <a href="https://huggingface.co/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
7
+ <a href="https://huggingface.co/spaces/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <br>
8
+ <a href="https://replicate.com/camenduru/instantmesh"><img src="https://img.shields.io/badge/Demo-Replicate-blue"></a>
9
+ <a href="https://colab.research.google.com/github/camenduru/InstantMesh-jupyter/blob/main/InstantMesh_jupyter.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg"></a>
10
+ <a href="https://github.com/jtydhr88/ComfyUI-InstantMesh"><img src="https://img.shields.io/badge/Demo-ComfyUI-8A2BE2"></a>
11
+
12
+ </div>
13
+
14
+ ---
15
+
16
+ This repo is the official implementation of InstantMesh, a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
17
+
18
+ https://github.com/TencentARC/InstantMesh/assets/20635237/dab3511e-e7c6-4c0b-bab7-15772045c47d
19
+
20
+ # 🚩 Features and Todo List
21
+ - [x] 🔥🔥 Release Zero123++ fine-tuning code.
22
+ - [x] 🔥🔥 Support for running the gradio demo on two GPUs to save memory.
23
+ - [x] 🔥🔥 Support for running the demo with Docker. Please refer to the [docker](docker/) directory.
24
+ - [x] Release inference and training code.
25
+ - [x] Release model weights.
26
+ - [x] Release Hugging Face gradio demo. Please try it at the [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link.
27
+ - [ ] Add support for more multi-view diffusion models.
28
+
29
+ # ⚙️ Dependencies and Installation
30
+
31
+ We recommend using `Python>=3.10`, `PyTorch>=2.1.0`, and `CUDA>=12.1`.
32
+ ```bash
33
+ conda create --name instantmesh python=3.10
34
+ conda activate instantmesh
35
+ pip install -U pip
36
+
37
+ # Ensure Ninja is installed
38
+ conda install Ninja
39
+
40
+ # Install the correct version of CUDA
41
+ conda install cuda -c nvidia/label/cuda-12.1.0
42
+
43
+ # Install PyTorch and xformers
44
+ # You may need to install another xformers version if you use a different PyTorch version
45
+ pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
46
+ pip install xformers==0.0.22.post7
47
+
48
+ # For Linux users: Install Triton
49
+ pip install triton
50
+
51
+ # For Windows users: Use the prebuilt version of Triton provided here:
52
+ pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl
53
+
54
+ # Install other requirements
55
+ pip install -r requirements.txt
56
+ ```
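+
+ As an optional sanity check, you can verify that the CUDA build of PyTorch and xformers import correctly before moving on:
+ ```python
+ # quick environment check; assumes the versions installed above
+ import torch
+ import xformers
+
+ print("torch:", torch.__version__, "| cuda:", torch.version.cuda, "| available:", torch.cuda.is_available())
+ print("xformers:", xformers.__version__)
+ ```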
57
+
58
+ # 💫 How to Use
59
+
60
+ ## Download the models
61
+
62
+ We provide 4 sparse-view reconstruction model variants and a customized Zero123++ UNet for white-background image generation in the [model card](https://huggingface.co/TencentARC/InstantMesh).
63
+
64
+ Our inference script will download the models automatically. Alternatively, you can manually download the models and put them under the `ckpts/` directory.
65
+
66
+ By default, we use the `instant-mesh-large` reconstruction model variant.
67
+
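+ If you prefer to fetch the weights yourself, a minimal sketch using `huggingface_hub` (the same API our scripts call; the filenames below are the ones `app.py` and `run.py` request, and `local_dir` requires a reasonably recent `huggingface_hub`) would look like this:
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # download the default reconstruction model and the custom white-background UNet into ./ckpts/
+ hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", local_dir="ckpts")
+ hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", local_dir="ckpts")
+ ```
+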
68
+ ## Start a local gradio demo
69
+
70
+ To start a gradio demo on your local machine, simply run:
71
+ ```bash
72
+ python app.py
73
+ ```
74
+
75
+ If you have multiple GPUs in your machine, the demo app will run on two GPUs automatically to save memory. You can also force it to run on a single GPU:
76
+ ```bash
77
+ CUDA_VISIBLE_DEVICES=0 python app.py
78
+ ```
79
+
80
+ Alternatively, you can run the demo with Docker. Please follow the instructions in the [docker](docker/) directory.
81
+
82
+ ## Running with command line
83
+
84
+ To generate 3D meshes from images via the command line, simply run:
85
+ ```bash
86
+ python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video
87
+ ```
88
+
89
+ We use [rembg](https://github.com/danielgatis/rembg) to segment the foreground object. If the input image already has an alpha mask, please specify the `--no_rembg` flag:
90
+ ```bash
91
+ python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video --no_rembg
92
+ ```
93
+
94
+ By default, our script exports a `.obj` mesh with vertex colors. Please specify the `--export_texmap` flag if you want to export a mesh with a texture map instead (this takes longer):
95
+ ```bash
96
+ python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video --export_texmap
97
+ ```
98
+
99
+ Please use a different `.yaml` config file in the [configs](./configs) directory if you want to use another reconstruction model variant. For example, to use the `instant-nerf-large` model for generation:
100
+ ```bash
101
+ python run.py configs/instant-nerf-large.yaml examples/hatsune_miku.png --save_video
102
+ ```
103
+ **Note:** When using the `NeRF` model variants for image-to-3D generation, exporting a mesh with a texture map by specifying `--export_texmap` may take a long time in the UV unwrapping step, since the default iso-surface extraction resolution is `256`. You can set a lower iso-surface extraction resolution in the config file.
104
+
105
+ # 💻 Training
106
+
107
+ We provide our training code to facilitate future research, but we cannot provide the training dataset due to its size. Please refer to our [dataloader](src/data/objaverse.py) for more details.
108
+
109
+ To train the sparse-view reconstruction models, please run:
110
+ ```bash
111
+ # Training on NeRF representation
112
+ python train.py --base configs/instant-nerf-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
113
+
114
+ # Training on Mesh representation
115
+ python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
116
+ ```
117
+
118
+ We also provide our Zero123++ fine-tuning code since it is frequently requested. The running command is:
119
+ ```bash
120
+ python train.py --base configs/zero123plus-finetune.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
121
+ ```
122
+
123
+ # :books: Citation
124
+
125
+ If you find our work useful for your research or applications, please cite using this BibTeX:
126
+
127
+ ```BibTeX
128
+ @article{xu2024instantmesh,
129
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
130
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
131
+ journal={arXiv preprint arXiv:2404.07191},
132
+ year={2024}
133
+ }
134
+ ```
135
+
136
+ # 🤗 Acknowledgements
137
+
138
+ We thank the authors of the following projects for their excellent contributions to 3D generative AI!
139
+
140
+ - [Zero123++](https://github.com/SUDO-AI-3D/zero123plus)
141
+ - [OpenLRM](https://github.com/3DTopia/OpenLRM)
142
+ - [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes)
143
+ - [Instant3D](https://instant-3d.github.io/)
144
+
145
+ Thanks to [@camenduru](https://github.com/camenduru) for implementing the [Replicate Demo](https://replicate.com/camenduru/instantmesh) and [Colab Demo](https://colab.research.google.com/github/camenduru/InstantMesh-jupyter/blob/main/InstantMesh_jupyter.ipynb)!
146
+ Thanks to [@jtydhr88](https://github.com/jtydhr88) for implementing [ComfyUI support](https://github.com/jtydhr88/ComfyUI-InstantMesh)!
app.py ADDED
@@ -0,0 +1,348 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
13
+
14
+ from src.utils.train_util import instantiate_from_config
15
+ from src.utils.camera_util import (
16
+ FOV_to_intrinsics,
17
+ get_zero123plus_input_cameras,
18
+ get_circular_camera_poses,
19
+ )
20
+ from src.utils.mesh_util import save_obj, save_glb
21
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
22
+
23
+ import tempfile
24
+ from huggingface_hub import hf_hub_download
25
+
26
+
27
+ if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
28
+ device0 = torch.device('cuda:0')
29
+ device1 = torch.device('cuda:1')
30
+ else:
31
+ device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ device1 = device0
33
+
34
+ # Define the cache directory for model files
35
+ model_cache_dir = './ckpts/'
36
+ os.makedirs(model_cache_dir, exist_ok=True)
37
+
38
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
39
+ """
40
+ Get the rendering camera parameters.
41
+ """
42
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
43
+ if is_flexicubes:
44
+ cameras = torch.linalg.inv(c2ws)
45
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
46
+ else:
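+ # NeRF-style cameras: each is a 25-dim vector of flattened 4x4 extrinsics (16) plus flattened 3x3 intrinsics (9)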
47
+ extrinsics = c2ws.flatten(-2)
48
+ intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
49
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
50
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
51
+ return cameras
52
+
53
+
54
+ def images_to_video(images, output_path, fps=30):
55
+ # images: (N, C, H, W)
56
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
57
+ frames = []
58
+ for i in range(images.shape[0]):
59
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)  # clip before the uint8 cast so out-of-range values do not wrap
60
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
61
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
62
+ assert frame.min() >= 0 and frame.max() <= 255, \
63
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
64
+ frames.append(frame)
65
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
66
+
67
+
68
+ ###############################################################################
69
+ # Configuration.
70
+ ###############################################################################
71
+
72
+ seed_everything(0)
73
+
74
+ config_path = 'configs/instant-mesh-large.yaml'
75
+ config = OmegaConf.load(config_path)
76
+ config_name = os.path.basename(config_path).replace('.yaml', '')
77
+ model_config = config.model_config
78
+ infer_config = config.infer_config
79
+
80
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
81
+
82
+ device = torch.device('cuda')
83
+
84
+ # load diffusion model
85
+ print('Loading diffusion model ...')
86
+ pipeline = DiffusionPipeline.from_pretrained(
87
+ "sudo-ai/zero123plus-v1.2",
88
+ custom_pipeline="zero123plus",
89
+ torch_dtype=torch.float16,
90
+ cache_dir=model_cache_dir
91
+ )
92
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
93
+ pipeline.scheduler.config, timestep_spacing='trailing'
94
+ )
95
+
96
+ # load custom white-background UNet
97
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir)
98
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
99
+ pipeline.unet.load_state_dict(state_dict, strict=True)
100
+
101
+ pipeline = pipeline.to(device0)
102
+
103
+ # load reconstruction model
104
+ print('Loading reconstruction model ...')
105
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir)
106
+ model = instantiate_from_config(model_config)
107
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
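+ # keep only the 'lrm_generator.' weights, strip that 14-character prefix, and drop 'source_camera' entries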
108
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
109
+ model.load_state_dict(state_dict, strict=True)
110
+
111
+ model = model.to(device1)
112
+ if IS_FLEXICUBES:
113
+ model.init_flexicubes_geometry(device1, fovy=30.0)
114
+ model = model.eval()
115
+
116
+ print('Loading Finished!')
117
+
118
+
119
+ def check_input_image(input_image):
120
+ if input_image is None:
121
+ raise gr.Error("No image uploaded!")
122
+
123
+
124
+ def preprocess(input_image, do_remove_background):
125
+
126
+ rembg_session = rembg.new_session() if do_remove_background else None
127
+ if do_remove_background:
128
+ input_image = remove_background(input_image, rembg_session)
129
+ input_image = resize_foreground(input_image, 0.85)
130
+
131
+ return input_image
132
+
133
+
134
+ def generate_mvs(input_image, sample_steps, sample_seed):
135
+
136
+ seed_everything(sample_seed)
137
+
138
+ # sampling
139
+ generator = torch.Generator(device=device0)
140
+ z123_image = pipeline(
141
+ input_image,
142
+ num_inference_steps=sample_steps,
143
+ generator=generator,
144
+ ).images[0]
145
+
146
+ show_image = np.asarray(z123_image, dtype=np.uint8)
147
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
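+ # the 960x640 Zero123++ output is a 3x2 grid of 320x320 views; regroup it into a 2x3 grid for display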
148
+ show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
149
+ show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
150
+ show_image = Image.fromarray(show_image.numpy())
151
+
152
+ return z123_image, show_image
153
+
154
+
155
+ def make_mesh(mesh_fpath, planes):
156
+
157
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
158
+ mesh_dirname = os.path.dirname(mesh_fpath)
159
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
160
+
161
+ with torch.no_grad():
162
+ # get mesh
163
+
164
+ mesh_out = model.extract_mesh(
165
+ planes,
166
+ use_texture_map=False,
167
+ **infer_config,
168
+ )
169
+
170
+ vertices, faces, vertex_colors = mesh_out
171
+ vertices = vertices[:, [1, 2, 0]]
172
+
173
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
174
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
175
+
176
+ print(f"Mesh saved to {mesh_fpath}")
177
+
178
+ return mesh_fpath, mesh_glb_fpath
179
+
180
+
181
+ def make3d(images):
182
+
183
+ images = np.asarray(images, dtype=np.float32) / 255.0
184
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
185
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
186
+
187
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device1)
188
+ render_cameras = get_render_cameras(
189
+ batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device1)
190
+
191
+ images = images.unsqueeze(0).to(device1)
192
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
193
+
194
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
195
+ print(mesh_fpath)
196
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
197
+ mesh_dirname = os.path.dirname(mesh_fpath)
198
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
199
+
200
+ with torch.no_grad():
201
+ # get triplane
202
+ planes = model.forward_planes(images, input_cameras)
203
+
204
+ # get video
205
+ chunk_size = 20 if IS_FLEXICUBES else 1
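+ # render the circular camera trajectory in chunks to keep peak GPU memory bounded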
206
+ render_size = 384
207
+
208
+ frames = []
209
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
210
+ if IS_FLEXICUBES:
211
+ frame = model.forward_geometry(
212
+ planes,
213
+ render_cameras[:, i:i+chunk_size],
214
+ render_size=render_size,
215
+ )['img']
216
+ else:
217
+ frame = model.synthesizer(
218
+ planes,
219
+ cameras=render_cameras[:, i:i+chunk_size],
220
+ render_size=render_size,
221
+ )['images_rgb']
222
+ frames.append(frame)
223
+ frames = torch.cat(frames, dim=1)
224
+
225
+ images_to_video(
226
+ frames[0],
227
+ video_fpath,
228
+ fps=30,
229
+ )
230
+
231
+ print(f"Video saved to {video_fpath}")
232
+
233
+ mesh_fpath, mesh_glb_fpath = make_mesh(mesh_fpath, planes)
234
+
235
+ return video_fpath, mesh_fpath, mesh_glb_fpath
236
+
237
+
238
+ import gradio as gr
239
+
240
+
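+ # NOTE: _HEADER_ and _CITE_ are referenced by the UI below but are not defined
+ # anywhere in this file as uploaded; the strings here are assumed placeholders
+ # (not the original text) so the demo can launch without a NameError.
+ _HEADER_ = "# InstantMesh: Efficient 3D Mesh Generation from a Single Image"
+ _CITE_ = "Paper: https://arxiv.org/abs/2404.07191"
+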
241
+ with gr.Blocks() as demo:
242
+ gr.Markdown(_HEADER_)
243
+ with gr.Row(variant="panel"):
244
+ with gr.Column():
245
+ with gr.Row():
246
+ input_image = gr.Image(
247
+ label="Input Image",
248
+ image_mode="RGBA",
249
+ sources="upload",
250
+ width=256,
251
+ height=256,
252
+ type="pil",
253
+ elem_id="content_image",
254
+ )
255
+ processed_image = gr.Image(
256
+ label="Processed Image",
257
+ image_mode="RGBA",
258
+ width=256,
259
+ height=256,
260
+ type="pil",
261
+ interactive=False
262
+ )
263
+ with gr.Row():
264
+ with gr.Group():
265
+ do_remove_background = gr.Checkbox(
266
+ label="Remove Background", value=True
267
+ )
268
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
269
+
270
+ sample_steps = gr.Slider(
271
+ label="Sample Steps",
272
+ minimum=30,
273
+ maximum=75,
274
+ value=75,
275
+ step=5
276
+ )
277
+
278
+ with gr.Row():
279
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
280
+
281
+ with gr.Row(variant="panel"):
282
+ gr.Examples(
283
+ examples=[
284
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
285
+ ],
286
+ inputs=[input_image],
287
+ label="Examples",
288
+ examples_per_page=20
289
+ )
290
+
291
+ with gr.Column():
292
+
293
+ with gr.Row():
294
+
295
+ with gr.Column():
296
+ mv_show_images = gr.Image(
297
+ label="Generated Multi-views",
298
+ type="pil",
299
+ width=379,
300
+ interactive=False
301
+ )
302
+
303
+ with gr.Column():
304
+ output_video = gr.Video(
305
+ label="video", format="mp4",
306
+ width=379,
307
+ autoplay=True,
308
+ interactive=False
309
+ )
310
+
311
+ with gr.Row():
312
+ with gr.Tab("OBJ"):
313
+ output_model_obj = gr.Model3D(
314
+ label="Output Model (OBJ Format)",
315
+ #width=768,
316
+ interactive=False,
317
+ )
318
+ gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
319
+ with gr.Tab("GLB"):
320
+ output_model_glb = gr.Model3D(
321
+ label="Output Model (GLB Format)",
322
+ #width=768,
323
+ interactive=False,
324
+ )
325
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
326
+
327
+ with gr.Row():
328
+ gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
329
+
330
+ gr.Markdown(_CITE_)
331
+ mv_images = gr.State()
332
+
333
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
334
+ fn=preprocess,
335
+ inputs=[input_image, do_remove_background],
336
+ outputs=[processed_image],
337
+ ).success(
338
+ fn=generate_mvs,
339
+ inputs=[processed_image, sample_steps, sample_seed],
340
+ outputs=[mv_images, mv_show_images],
341
+ ).success(
342
+ fn=make3d,
343
+ inputs=[mv_images],
344
+ outputs=[output_video, output_model_obj, output_model_glb]
345
+ )
346
+
347
+ demo.queue(max_size=10)
348
+ demo.launch(server_name="0.0.0.0", server_port=43839)
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ pytorch-lightning==2.1.2
2
+ gradio==3.41.2
3
+ huggingface-hub
4
+ einops
5
+ omegaconf
6
+ torchmetrics
7
+ webdataset
8
+ accelerate
9
+ tensorboard
10
+ PyMCubes
11
+ trimesh
12
+ rembg
13
+ transformers==4.34.1
14
+ diffusers==0.20.2
15
+ bitsandbytes
16
+ imageio[ffmpeg]
17
+ xatlas
18
+ plyfile
19
+ git+https://github.com/NVlabs/nvdiffrast/
run.py ADDED
@@ -0,0 +1,262 @@
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ from huggingface_hub import hf_hub_download
13
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
14
+
15
+ from src.utils.train_util import instantiate_from_config
16
+ from src.utils.camera_util import (
17
+ FOV_to_intrinsics,
18
+ get_zero123plus_input_cameras,
19
+ get_circular_camera_poses,
20
+ )
21
+ from src.utils.mesh_util import save_obj, save_obj_with_mtl
22
+ from src.utils.infer_util import remove_background, resize_foreground, save_video
23
+
24
+
25
+ def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False):
26
+ """
27
+ Get the rendering camera parameters.
28
+ """
29
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
30
+ if is_flexicubes:
31
+ cameras = torch.linalg.inv(c2ws)
32
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
33
+ else:
34
+ extrinsics = c2ws.flatten(-2)
35
+ intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
36
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
37
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
38
+ return cameras
39
+
40
+
41
+ def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False):
42
+ """
43
+ Render frames from triplanes.
44
+ """
45
+ frames = []
46
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
47
+ if is_flexicubes:
48
+ frame = model.forward_geometry(
49
+ planes,
50
+ render_cameras[:, i:i+chunk_size],
51
+ render_size=render_size,
52
+ )['img']
53
+ else:
54
+ frame = model.forward_synthesizer(
55
+ planes,
56
+ render_cameras[:, i:i+chunk_size],
57
+ render_size=render_size,
58
+ )['images_rgb']
59
+ frames.append(frame)
60
+
61
+ frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
62
+ return frames
63
+
64
+
65
+ ###############################################################################
66
+ # Arguments.
67
+ ###############################################################################
68
+
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument('config', type=str, help='Path to config file.')
71
+ parser.add_argument('input_path', type=str, help='Path to input image or directory.')
72
+ parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
73
+ parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising Sampling steps.')
74
+ parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
75
+ parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
76
+ parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
77
+ parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
78
+ parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
79
+ parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
80
+ parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
81
+ args = parser.parse_args()
82
+ seed_everything(args.seed)
83
+
84
+ ###############################################################################
85
+ # Stage 0: Configuration.
86
+ ###############################################################################
87
+
88
+ config = OmegaConf.load(args.config)
89
+ config_name = os.path.basename(args.config).replace('.yaml', '')
90
+ model_config = config.model_config
91
+ infer_config = config.infer_config
92
+
93
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
94
+
95
+ device = torch.device('cuda')
96
+
97
+ # load diffusion model
98
+ print('Loading diffusion model ...')
99
+ pipeline = DiffusionPipeline.from_pretrained(
100
+ "sudo-ai/zero123plus-v1.2",
101
+ custom_pipeline="zero123plus",
102
+ torch_dtype=torch.float16,
103
+ )
104
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
105
+ pipeline.scheduler.config, timestep_spacing='trailing'
106
+ )
107
+
108
+ # load custom white-background UNet
109
+ print('Loading custom white-background unet ...')
110
+ if os.path.exists(infer_config.unet_path):
111
+ unet_ckpt_path = infer_config.unet_path
112
+ else:
113
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
114
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
115
+ pipeline.unet.load_state_dict(state_dict, strict=True)
116
+
117
+ pipeline = pipeline.to(device)
118
+
119
+ # load reconstruction model
120
+ print('Loading reconstruction model ...')
121
+ model = instantiate_from_config(model_config)
122
+ if os.path.exists(infer_config.model_path):
123
+ model_ckpt_path = infer_config.model_path
124
+ else:
125
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
126
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
127
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
128
+ model.load_state_dict(state_dict, strict=True)
129
+
130
+ model = model.to(device)
131
+ if IS_FLEXICUBES:
132
+ model.init_flexicubes_geometry(device, fovy=30.0)
133
+ model = model.eval()
134
+
135
+ # make output directories
136
+ image_path = os.path.join(args.output_path, config_name, 'images')
137
+ mesh_path = os.path.join(args.output_path, config_name, 'meshes')
138
+ video_path = os.path.join(args.output_path, config_name, 'videos')
139
+ os.makedirs(image_path, exist_ok=True)
140
+ os.makedirs(mesh_path, exist_ok=True)
141
+ os.makedirs(video_path, exist_ok=True)
142
+
143
+ # process input files
144
+ if os.path.isdir(args.input_path):
145
+ input_files = [
146
+ os.path.join(args.input_path, file)
147
+ for file in os.listdir(args.input_path)
148
+ if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
149
+ ]
150
+ else:
151
+ input_files = [args.input_path]
152
+ print(f'Total number of input images: {len(input_files)}')
153
+
154
+
155
+ ###############################################################################
156
+ # Stage 1: Multiview generation.
157
+ ###############################################################################
158
+
159
+ rembg_session = None if args.no_rembg else rembg.new_session()
160
+
161
+ outputs = []
162
+ for idx, image_file in enumerate(input_files):
163
+ name = os.path.basename(image_file).split('.')[0]
164
+ print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')
165
+
166
+ # remove background optionally
167
+ input_image = Image.open(image_file)
168
+ if not args.no_rembg:
169
+ input_image = remove_background(input_image, rembg_session)
170
+ input_image = resize_foreground(input_image, 0.85)
171
+
172
+ # sampling
173
+ output_image = pipeline(
174
+ input_image,
175
+ num_inference_steps=args.diffusion_steps,
176
+ ).images[0]
177
+
178
+ output_image.save(os.path.join(image_path, f'{name}.png'))
179
+ print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")
180
+
181
+ images = np.asarray(output_image, dtype=np.float32) / 255.0
182
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
183
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
184
+
185
+ outputs.append({'name': name, 'images': images})
186
+
187
+ # delete pipeline to save memory
188
+ del pipeline
189
+
190
+ ###############################################################################
191
+ # Stage 2: Reconstruction.
192
+ ###############################################################################
193
+
194
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*args.scale).to(device)
195
+ chunk_size = 20 if IS_FLEXICUBES else 1
196
+
197
+ for idx, sample in enumerate(outputs):
198
+ name = sample['name']
199
+ print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')
200
+
201
+ images = sample['images'].unsqueeze(0).to(device)
202
+ images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)
203
+
204
+ if args.view == 4:
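+ # keep only 4 of the 6 generated views (and the matching input cameras)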
205
+ indices = torch.tensor([0, 2, 4, 5]).long().to(device)
206
+ images = images[:, indices]
207
+ input_cameras = input_cameras[:, indices]
208
+
209
+ with torch.no_grad():
210
+ # get triplane
211
+ planes = model.forward_planes(images, input_cameras)
212
+
213
+ # get mesh
214
+ mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')
215
+
216
+ mesh_out = model.extract_mesh(
217
+ planes,
218
+ use_texture_map=args.export_texmap,
219
+ **infer_config,
220
+ )
221
+ if args.export_texmap:
222
+ vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
223
+ save_obj_with_mtl(
224
+ vertices.data.cpu().numpy(),
225
+ uvs.data.cpu().numpy(),
226
+ faces.data.cpu().numpy(),
227
+ mesh_tex_idx.data.cpu().numpy(),
228
+ tex_map.permute(1, 2, 0).data.cpu().numpy(),
229
+ mesh_path_idx,
230
+ )
231
+ else:
232
+ vertices, faces, vertex_colors = mesh_out
233
+ save_obj(vertices, faces, vertex_colors, mesh_path_idx)
234
+ print(f"Mesh saved to {mesh_path_idx}")
235
+
236
+ # get video
237
+ if args.save_video:
238
+ video_path_idx = os.path.join(video_path, f'{name}.mp4')
239
+ render_size = infer_config.render_resolution
240
+ render_cameras = get_render_cameras(
241
+ batch_size=1,
242
+ M=120,
243
+ radius=args.distance,
244
+ elevation=20.0,
245
+ is_flexicubes=IS_FLEXICUBES,
246
+ ).to(device)
247
+
248
+ frames = render_frames(
249
+ model,
250
+ planes,
251
+ render_cameras=render_cameras,
252
+ render_size=render_size,
253
+ chunk_size=chunk_size,
254
+ is_flexicubes=IS_FLEXICUBES,
255
+ )
256
+
257
+ save_video(
258
+ frames,
259
+ video_path_idx,
260
+ fps=30,
261
+ )
262
+ print(f"Video saved to {video_path_idx}")
train.py ADDED
@@ -0,0 +1,286 @@
1
+ import os, sys
2
+ import argparse
3
+ import shutil
4
+ import subprocess
5
+ from omegaconf import OmegaConf
6
+
7
+ from pytorch_lightning import seed_everything
8
+ from pytorch_lightning.trainer import Trainer
9
+ from pytorch_lightning.strategies import DDPStrategy
10
+ from pytorch_lightning.callbacks import Callback
11
+ from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
12
+
13
+ from src.utils.train_util import instantiate_from_config
14
+
15
+
16
+ @rank_zero_only
17
+ def rank_zero_print(*args):
18
+ print(*args)
19
+
20
+
21
+ def get_parser(**parser_kwargs):
22
+ def str2bool(v):
23
+ if isinstance(v, bool):
24
+ return v
25
+ if v.lower() in ("yes", "true", "t", "y", "1"):
26
+ return True
27
+ elif v.lower() in ("no", "false", "f", "n", "0"):
28
+ return False
29
+ else:
30
+ raise argparse.ArgumentTypeError("Boolean value expected.")
31
+
32
+ parser = argparse.ArgumentParser(**parser_kwargs)
33
+ parser.add_argument(
34
+ "-r",
35
+ "--resume",
36
+ type=str,
37
+ default=None,
38
+ help="resume from checkpoint",
39
+ )
40
+ parser.add_argument(
41
+ "--resume_weights_only",
42
+ action="store_true",
43
+ help="only resume model weights",
44
+ )
45
+ parser.add_argument(
46
+ "-b",
47
+ "--base",
48
+ type=str,
49
+ default="base_config.yaml",
50
+ help="path to base configs",
51
+ )
52
+ parser.add_argument(
53
+ "-n",
54
+ "--name",
55
+ type=str,
56
+ default="",
57
+ help="experiment name",
58
+ )
59
+ parser.add_argument(
60
+ "--num_nodes",
61
+ type=int,
62
+ default=1,
63
+ help="number of nodes to use",
64
+ )
65
+ parser.add_argument(
66
+ "--gpus",
67
+ type=str,
68
+ default="0,",
69
+ help="gpu ids to use",
70
+ )
71
+ parser.add_argument(
72
+ "-s",
73
+ "--seed",
74
+ type=int,
75
+ default=42,
76
+ help="seed for seed_everything",
77
+ )
78
+ parser.add_argument(
79
+ "-l",
80
+ "--logdir",
81
+ type=str,
82
+ default="logs",
83
+ help="directory for logging data",
84
+ )
85
+ return parser
86
+
87
+
88
+ class SetupCallback(Callback):
89
+ def __init__(self, resume, logdir, ckptdir, cfgdir, config):
90
+ super().__init__()
91
+ self.resume = resume
92
+ self.logdir = logdir
93
+ self.ckptdir = ckptdir
94
+ self.cfgdir = cfgdir
95
+ self.config = config
96
+
97
+ def on_fit_start(self, trainer, pl_module):
98
+ if trainer.global_rank == 0:
99
+ # Create logdirs and save configs
100
+ os.makedirs(self.logdir, exist_ok=True)
101
+ os.makedirs(self.ckptdir, exist_ok=True)
102
+ os.makedirs(self.cfgdir, exist_ok=True)
103
+
104
+ rank_zero_print("Project config")
105
+ rank_zero_print(OmegaConf.to_yaml(self.config))
106
+ OmegaConf.save(self.config,
107
+ os.path.join(self.cfgdir, "project.yaml"))
108
+
109
+
110
+ class CodeSnapshot(Callback):
111
+ """
112
+ Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60
113
+ """
114
+ def __init__(self, savedir):
115
+ self.savedir = savedir
116
+
117
+ def get_file_list(self):
118
+ return [
119
+ b.decode()
120
+ for b in set(
121
+ subprocess.check_output(
122
+ 'git ls-files -- ":!:configs/*"', shell=True
123
+ ).splitlines()
124
+ )
125
+ | set( # hard code, TODO: use config to exclude folders or files
126
+ subprocess.check_output(
127
+ "git ls-files --others --exclude-standard", shell=True
128
+ ).splitlines()
129
+ )
130
+ ]
131
+
132
+ @rank_zero_only
133
+ def save_code_snapshot(self):
134
+ os.makedirs(self.savedir, exist_ok=True)
135
+ for f in self.get_file_list():
136
+ if not os.path.exists(f) or os.path.isdir(f):
137
+ continue
138
+ os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
139
+ shutil.copyfile(f, os.path.join(self.savedir, f))
140
+
141
+ def on_fit_start(self, trainer, pl_module):
142
+ try:
143
+ self.save_code_snapshot()
144
+ except:
145
+ rank_zero_warn(
146
+ "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
147
+ )
148
+
149
+
150
+ if __name__ == "__main__":
151
+ # add cwd for convenience and to make classes in this file available when
152
+ # running as `python main.py`
153
+ sys.path.append(os.getcwd())
154
+
155
+ parser = get_parser()
156
+ opt, unknown = parser.parse_known_args()
157
+
158
+ cfg_fname = os.path.split(opt.base)[-1]
159
+ cfg_name = os.path.splitext(cfg_fname)[0]
160
+ exp_name = "-" + opt.name if opt.name != "" else ""
161
+ logdir = os.path.join(opt.logdir, cfg_name+exp_name)
162
+
163
+ ckptdir = os.path.join(logdir, "checkpoints")
164
+ cfgdir = os.path.join(logdir, "configs")
165
+ codedir = os.path.join(logdir, "code")
166
+ seed_everything(opt.seed)
167
+
168
+ # init configs
169
+ config = OmegaConf.load(opt.base)
170
+ lightning_config = config.lightning
171
+ trainer_config = lightning_config.trainer
172
+
173
+ trainer_config["accelerator"] = "gpu"
174
+ rank_zero_print(f"Running on GPUs {opt.gpus}")
175
+ ngpu = len(opt.gpus.strip(",").split(','))
176
+ trainer_config['devices'] = ngpu
177
+
178
+ trainer_opt = argparse.Namespace(**trainer_config)
179
+ lightning_config.trainer = trainer_config
180
+
181
+ # model
182
+ model = instantiate_from_config(config.model)
183
+ if opt.resume and opt.resume_weights_only:
184
+ model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params)
185
+
186
+ model.logdir = logdir
187
+
188
+ # trainer and callbacks
189
+ trainer_kwargs = dict()
190
+
191
+ # logger
192
+ default_logger_cfg = {
193
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
194
+ "params": {
195
+ "name": "tensorboard",
196
+ "save_dir": logdir,
197
+ "version": "0",
198
+ }
199
+ }
200
+ logger_cfg = OmegaConf.merge(default_logger_cfg)
201
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
202
+
203
+ # model checkpoint
204
+ default_modelckpt_cfg = {
205
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
206
+ "params": {
207
+ "dirpath": ckptdir,
208
+ "filename": "{step:08}",
209
+ "verbose": True,
210
+ "save_last": True,
211
+ "every_n_train_steps": 5000,
212
+ "save_top_k": -1, # save all checkpoints
213
+ }
214
+ }
215
+
216
+ if "modelcheckpoint" in lightning_config:
217
+ modelckpt_cfg = lightning_config.modelcheckpoint
218
+ else:
219
+ modelckpt_cfg = OmegaConf.create()
220
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
221
+
222
+ # callbacks
223
+ default_callbacks_cfg = {
224
+ "setup_callback": {
225
+ "target": "train.SetupCallback",
226
+ "params": {
227
+ "resume": opt.resume,
228
+ "logdir": logdir,
229
+ "ckptdir": ckptdir,
230
+ "cfgdir": cfgdir,
231
+ "config": config,
232
+ }
233
+ },
234
+ "learning_rate_logger": {
235
+ "target": "pytorch_lightning.callbacks.LearningRateMonitor",
236
+ "params": {
237
+ "logging_interval": "step",
238
+ }
239
+ },
240
+ "code_snapshot": {
241
+ "target": "train.CodeSnapshot",
242
+ "params": {
243
+ "savedir": codedir,
244
+ }
245
+ },
246
+ }
247
+ default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg
248
+
249
+ if "callbacks" in lightning_config:
250
+ callbacks_cfg = lightning_config.callbacks
251
+ else:
252
+ callbacks_cfg = OmegaConf.create()
253
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
254
+
255
+ trainer_kwargs["callbacks"] = [
256
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
257
+
258
+ trainer_kwargs['precision'] = '32-true'
259
+ trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=True)
260
+
261
+ # trainer
262
+ trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes)
263
+ trainer.logdir = logdir
264
+
265
+ # data
266
+ data = instantiate_from_config(config.data)
267
+ data.prepare_data()
268
+ data.setup("fit")
269
+
270
+ # configure learning rate
271
+ base_lr = config.model.base_learning_rate
272
+ if 'accumulate_grad_batches' in lightning_config.trainer:
273
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
274
+ else:
275
+ accumulate_grad_batches = 1
276
+ rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
277
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
278
+ model.learning_rate = base_lr
279
+ rank_zero_print("++++ NOT USING LR SCALING ++++")
280
+ rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")
281
+
282
+ # run training loop
283
+ if opt.resume and not opt.resume_weights_only:
284
+ trainer.fit(model, data, ckpt_path=opt.resume)
285
+ else:
286
+ trainer.fit(model, data)