Mithun12345
committed on
Upload 7 files
- .gitignore +43 -0
- LICENSE +201 -0
- README.md +146 -12
- app.py +348 -0
- requirements.txt +19 -0
- run.py +262 -0
- train.py +286 -0
.gitignore
ADDED
@@ -0,0 +1,43 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+eggs/
+.eggs/
+.vscode/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+.DS_Store
+
+tools/objaverse_rendering/blender-3.2.2-linux-x64/
+tools/objaverse_rendering/output/
+ckpts/
+lightning_logs/
+logs/
+.trash/
+.env/
+outputs/
+figures*/
+
+# Useless Files
+*.sh
+blender/
+.restore/
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
CHANGED
@@ -1,12 +1,146 @@
- [12 lines of the original README removed; their content is not shown in this view]
+<div align="center">
+
+# InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models
+
+<a href="https://arxiv.org/abs/2404.07191"><img src="https://img.shields.io/badge/ArXiv-2404.07191-brightgreen"></a>
+<a href="https://huggingface.co/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
+<a href="https://huggingface.co/spaces/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <br>
+<a href="https://replicate.com/camenduru/instantmesh"><img src="https://img.shields.io/badge/Demo-Replicate-blue"></a>
+<a href="https://colab.research.google.com/github/camenduru/InstantMesh-jupyter/blob/main/InstantMesh_jupyter.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg"></a>
+<a href="https://github.com/jtydhr88/ComfyUI-InstantMesh"><img src="https://img.shields.io/badge/Demo-ComfyUI-8A2BE2"></a>
+
+</div>
+
+---
+
+This repo is the official implementation of InstantMesh, a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
+
+https://github.com/TencentARC/InstantMesh/assets/20635237/dab3511e-e7c6-4c0b-bab7-15772045c47d
+
+# 🚩 Features and Todo List
+- [x] 🔥🔥 Release Zero123++ fine-tuning code.
+- [x] 🔥🔥 Support for running the gradio demo on two GPUs to save memory.
+- [x] 🔥🔥 Support for running the demo with docker. Please refer to the [docker](docker/) directory.
+- [x] Release inference and training code.
+- [x] Release model weights.
+- [x] Release huggingface gradio demo. Please try it at the [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link.
+- [ ] Add support for more multi-view diffusion models.
+
+# ⚙️ Dependencies and Installation
+
+We recommend using `Python>=3.10`, `PyTorch>=2.1.0`, and `CUDA>=12.1`.
+```bash
+conda create --name instantmesh python=3.10
+conda activate instantmesh
+pip install -U pip
+
+# Ensure Ninja is installed
+conda install Ninja
+
+# Install the correct version of CUDA
+conda install cuda -c nvidia/label/cuda-12.1.0
+
+# Install PyTorch and xformers
+# You may need to install another xformers version if you use a different PyTorch version
+pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+pip install xformers==0.0.22.post7
+
+# For Linux users: Install Triton
+pip install triton
+
+# For Windows users: Use the prebuilt version of Triton provided here:
+pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl
+
+# Install other requirements
+pip install -r requirements.txt
+```
+
+# 💫 How to Use
+
+## Download the models
+
+We provide 4 sparse-view reconstruction model variants and a customized Zero123++ UNet for white-background image generation in the [model card](https://huggingface.co/TencentARC/InstantMesh).
+
+Our inference script will download the models automatically. Alternatively, you can manually download the models and put them under the `ckpts/` directory.
+
+By default, we use the `instant-mesh-large` reconstruction model variant.
+
+## Start a local gradio demo
+
+To start a gradio demo on your local machine, simply run:
+```bash
+python app.py
+```
+
+If you have multiple GPUs in your machine, the demo app will run on two GPUs automatically to save memory. You can also force it to run on a single GPU:
+```bash
+CUDA_VISIBLE_DEVICES=0 python app.py
+```
+
+Alternatively, you can run the demo with docker. Please follow the instructions in the [docker](docker/) directory.
+
+## Running with command line
+
+To generate 3D meshes from images via the command line, simply run:
+```bash
+python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video
+```
+
+We use [rembg](https://github.com/danielgatis/rembg) to segment the foreground object. If the input image already has an alpha mask, please specify the `--no_rembg` flag:
+```bash
+python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video --no_rembg
+```
+
+By default, our script exports a `.obj` mesh with vertex colors. Please specify the `--export_texmap` flag if you want to export a mesh with a texture map instead (this takes longer):
+```bash
+python run.py configs/instant-mesh-large.yaml examples/hatsune_miku.png --save_video --export_texmap
+```
+
+Please use a different `.yaml` config file from the [configs](./configs) directory if you want to use another reconstruction model variant. For example, to use the `instant-nerf-large` model for generation:
+```bash
+python run.py configs/instant-nerf-large.yaml examples/hatsune_miku.png --save_video
+```
+**Note:** When using the `NeRF` model variants for image-to-3D generation, exporting a mesh with a texture map via `--export_texmap` may take a long time in the UV unwrapping step, since the default iso-surface extraction resolution is `256`. You can set a lower iso-surface extraction resolution in the config file.
+
+# 💻 Training
+
+We provide our training code to facilitate future research, but we cannot provide the training dataset due to its size. Please refer to our [dataloader](src/data/objaverse.py) for more details.
+
+To train the sparse-view reconstruction models, please run:
+```bash
+# Training on NeRF representation
+python train.py --base configs/instant-nerf-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
+
+# Training on Mesh representation
+python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
+```
+
+We also provide our Zero123++ fine-tuning code since it is frequently requested. The running command is:
+```bash
+python train.py --base configs/zero123plus-finetune.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
+```
+
+# :books: Citation
+
+If you find our work useful for your research or applications, please cite using this BibTeX:
+
+```BibTeX
+@article{xu2024instantmesh,
+  title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
+  author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
+  journal={arXiv preprint arXiv:2404.07191},
+  year={2024}
+}
+```
+
+# 🤗 Acknowledgements
+
+We thank the authors of the following projects for their excellent contributions to 3D generative AI!
+
+- [Zero123++](https://github.com/SUDO-AI-3D/zero123plus)
+- [OpenLRM](https://github.com/3DTopia/OpenLRM)
+- [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes)
+- [Instant3D](https://instant-3d.github.io/)
+
+Thanks to [@camenduru](https://github.com/camenduru) for implementing the [Replicate Demo](https://replicate.com/camenduru/instantmesh) and [Colab Demo](https://colab.research.google.com/github/camenduru/InstantMesh-jupyter/blob/main/InstantMesh_jupyter.ipynb)!
+Thanks to [@jtydhr88](https://github.com/jtydhr88) for implementing [ComfyUI support](https://github.com/jtydhr88/ComfyUI-InstantMesh)!
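
A note on the manual-download option above: since `app.py` and `run.py` (below) fetch the checkpoints from the model card via `hf_hub_download`, the same call can pre-populate `ckpts/` ahead of time. A minimal sketch, assuming a recent `huggingface_hub` release that supports the `local_dir` argument (the two filenames are the ones the scripts in this commit request):

```python
# Sketch: pre-fetch the InstantMesh weights into ckpts/ so the scripts
# find them locally instead of downloading on first run.
from huggingface_hub import hf_hub_download

for filename in ["instant_mesh_large.ckpt", "diffusion_pytorch_model.bin"]:
    hf_hub_download(
        repo_id="TencentARC/InstantMesh",
        filename=filename,
        repo_type="model",
        local_dir="ckpts",  # assumes a huggingface_hub version with local_dir support
    )
```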
app.py
ADDED
@@ -0,0 +1,348 @@
+import os
+import tempfile
+import imageio
+import numpy as np
+import torch
+import rembg
+import gradio as gr
+from PIL import Image
+from torchvision.transforms import v2
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from tqdm import tqdm
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+from huggingface_hub import hf_hub_download
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    get_zero123plus_input_cameras,
+    get_circular_camera_poses,
+)
+from src.utils.mesh_util import save_obj, save_glb
+from src.utils.infer_util import remove_background, resize_foreground
+
+
+# Use two GPUs (diffusion on cuda:0, reconstruction on cuda:1) when available.
+if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
+    device0 = torch.device('cuda:0')
+    device1 = torch.device('cuda:1')
+else:
+    device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device1 = device0
+
+# Define the cache directory for model files
+model_cache_dir = './ckpts/'
+os.makedirs(model_cache_dir, exist_ok=True)
+
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    """
+    Get the rendering camera parameters.
+    """
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+def images_to_video(images, output_path, fps=30):
+    # images: (N, C, H, W)
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    frames = []
+    for i in range(images.shape[0]):
+        # clip to [0, 255] *before* casting to uint8 to avoid wrap-around overflow
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        assert frame.min() >= 0 and frame.max() <= 255, \
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
+
+
+###############################################################################
+# Configuration.
+###############################################################################
+
+seed_everything(0)
+
+config_path = 'configs/instant-mesh-large.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = config_name.startswith('instant-mesh')
+
+# load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+    cache_dir=model_cache_dir
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+# load custom white-background UNet
+unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir)
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device0)
+
+# load reconstruction model
+print('Loading reconstruction model ...')
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir)
+model = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+# keep only the LRM generator weights, stripping the 'lrm_generator.' prefix
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device1)
+if IS_FLEXICUBES:
+    model.init_flexicubes_geometry(device1, fovy=30.0)
+model = model.eval()
+
+print('Loading Finished!')
+
+# NOTE: _HEADER_ and _CITE_ are referenced in the UI below but were not defined
+# in this upload, which would raise a NameError when the Blocks UI is built.
+# Minimal placeholders are defined here so the app can launch; replace them
+# with the full markdown strings as needed.
+_HEADER_ = "# InstantMesh: Efficient 3D Mesh Generation from a Single Image"
+_CITE_ = "See the README for citation details."
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background):
+    rembg_session = rembg.new_session() if do_remove_background else None
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+    return input_image
+
+
+def generate_mvs(input_image, sample_steps, sample_seed):
+    seed_everything(sample_seed)
+
+    # sampling
+    generator = torch.Generator(device=device0)
+    z123_image = pipeline(
+        input_image,
+        num_inference_steps=sample_steps,
+        generator=generator,
+    ).images[0]
+
+    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = torch.from_numpy(show_image)  # (960, 640, 3)
+    # re-tile the 3x2 grid of views into a 2x3 grid for display
+    show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+    show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+    show_image = Image.fromarray(show_image.numpy())
+
+    return z123_image, show_image
+
+
+def make_mesh(mesh_fpath, planes):
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
+
+    with torch.no_grad():
+        # get mesh
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            **infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+        vertices = vertices[:, [1, 2, 0]]
+
+        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        print(f"Mesh saved to {mesh_fpath}")
+
+    return mesh_fpath, mesh_glb_fpath
+
+
+def make3d(images):
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)     # (6, 3, 320, 320)
+
+    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device1)
+    render_cameras = get_render_cameras(
+        batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device1)
+
+    images = images.unsqueeze(0).to(device1)
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        # get video
+        chunk_size = 20 if IS_FLEXICUBES else 1
+        render_size = 384
+
+        frames = []
+        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
+            if IS_FLEXICUBES:
+                frame = model.forward_geometry(
+                    planes,
+                    render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['img']
+            else:
+                frame = model.synthesizer(
+                    planes,
+                    cameras=render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['images_rgb']
+            frames.append(frame)
+        frames = torch.cat(frames, dim=1)
+
+        images_to_video(
+            frames[0],
+            video_fpath,
+            fps=30,
+        )
+
+        print(f"Video saved to {video_fpath}")
+
+    mesh_fpath, mesh_glb_fpath = make_mesh(mesh_fpath, planes)
+
+    return video_fpath, mesh_fpath, mesh_glb_fpath
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(_HEADER_)
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    width=256,
+                    height=256,
+                    type="pil",
+                    elem_id="content_image",
+                )
+                processed_image = gr.Image(
+                    label="Processed Image",
+                    image_mode="RGBA",
+                    width=256,
+                    height=256,
+                    type="pil",
+                    interactive=False
+                )
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(
+                        label="Remove Background", value=True
+                    )
+                    sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+
+                    sample_steps = gr.Slider(
+                        label="Sample Steps",
+                        minimum=30,
+                        maximum=75,
+                        value=75,
+                        step=5
+                    )
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=[input_image],
+                    label="Examples",
+                    examples_per_page=20
+                )
+
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    mv_show_images = gr.Image(
+                        label="Generated Multi-views",
+                        type="pil",
+                        width=379,
+                        interactive=False
+                    )
+                with gr.Column():
+                    output_video = gr.Video(
+                        label="video", format="mp4",
+                        width=379,
+                        autoplay=True,
+                        interactive=False
+                    )
+
+            with gr.Row():
+                with gr.Tab("OBJ"):
+                    output_model_obj = gr.Model3D(
+                        label="Output Model (OBJ Format)",
+                        interactive=False,
+                    )
+                    gr.Markdown("Note: The downloaded .obj model will be flipped. Export .glb instead, or manually flip it before use.")
+                with gr.Tab("GLB"):
+                    output_model_glb = gr.Model3D(
+                        label="Output Model (GLB Format)",
+                        interactive=False,
+                    )
+                    gr.Markdown("Note: The model shown here has a darker appearance. Download it to get the correct results.")
+
+            with gr.Row():
+                gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
+
+    gr.Markdown(_CITE_)
+    mv_images = gr.State()
+
+    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=preprocess,
+        inputs=[input_image, do_remove_background],
+        outputs=[processed_image],
+    ).success(
+        fn=generate_mvs,
+        inputs=[processed_image, sample_steps, sample_seed],
+        outputs=[mv_images, mv_show_images],
+    ).success(
+        fn=make3d,
+        inputs=[mv_images],
+        outputs=[output_video, output_model_obj, output_model_glb]
+    )
+
+demo.queue(max_size=10)
+demo.launch(server_name="0.0.0.0", server_port=43839)
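
One detail of `app.py` worth calling out: Zero123++ returns its six views as a single 960×640 image tiled in a 3×2 grid, and both `generate_mvs` and `make3d` rely on a single `einops.rearrange` to split (or re-tile) that grid. A minimal, self-contained illustration of the split used in `make3d`, with a random tensor standing in for the real image:

```python
# The 960x640 Zero123++ output is a 3x2 grid of 320x320 views;
# this rearrange splits it into six separate view tensors.
import torch
from einops import rearrange

grid = torch.rand(3, 960, 640)  # stand-in for the (C, H, W) multi-view image
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
print(views.shape)  # torch.Size([6, 3, 320, 320])
```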
requirements.txt
ADDED
@@ -0,0 +1,19 @@
+pytorch-lightning==2.1.2
+gradio==3.41.2
+huggingface-hub
+einops
+omegaconf
+torchmetrics
+webdataset
+accelerate
+tensorboard
+PyMCubes
+trimesh
+rembg
+transformers==4.34.1
+diffusers==0.20.2
+bitsandbytes
+imageio[ffmpeg]
+xatlas
+plyfile
+git+https://github.com/NVlabs/nvdiffrast/
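
Two notes on the pins above: `nvdiffrast` is installed straight from GitHub and builds its CUDA extensions when first used, so a working CUDA toolchain is needed; and the `diffusers`/`transformers` pins matter because `app.py` and `run.py` load Zero123++ through a custom pipeline. A quick post-install sanity check (a sketch; it only prints versions):

```python
# Verify the pinned stack resolved as expected after `pip install -r requirements.txt`.
import torch
import diffusers
import transformers
import pytorch_lightning

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__)                   # pinned to 0.20.2 above
print("transformers:", transformers.__version__)             # pinned to 4.34.1 above
print("pytorch-lightning:", pytorch_lightning.__version__)   # pinned to 2.1.2 above
```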
run.py
ADDED
@@ -0,0 +1,262 @@
+import os
+import argparse
+import numpy as np
+import torch
+import rembg
+from PIL import Image
+from torchvision.transforms import v2
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from tqdm import tqdm
+from huggingface_hub import hf_hub_download
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    get_zero123plus_input_cameras,
+    get_circular_camera_poses,
+)
+from src.utils.mesh_util import save_obj, save_obj_with_mtl
+from src.utils.infer_util import remove_background, resize_foreground, save_video
+
+
+def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False):
+    """
+    Get the rendering camera parameters.
+    """
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False):
+    """
+    Render frames from triplanes.
+    """
+    frames = []
+    for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
+        if is_flexicubes:
+            frame = model.forward_geometry(
+                planes,
+                render_cameras[:, i:i+chunk_size],
+                render_size=render_size,
+            )['img']
+        else:
+            frame = model.forward_synthesizer(
+                planes,
+                render_cameras[:, i:i+chunk_size],
+                render_size=render_size,
+            )['images_rgb']
+        frames.append(frame)
+
+    frames = torch.cat(frames, dim=1)[0]    # we assume batch size is always 1
+    return frames
+
+
+###############################################################################
+# Arguments.
+###############################################################################
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config', type=str, help='Path to config file.')
+parser.add_argument('input_path', type=str, help='Path to input image or directory.')
+parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
+parser.add_argument('--diffusion_steps', type=int, default=75, help='Denoising sampling steps.')
+parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
+parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
+parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
+parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
+parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
+parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
+parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
+args = parser.parse_args()
+seed_everything(args.seed)
+
+###############################################################################
+# Stage 0: Configuration.
+###############################################################################
+
+config = OmegaConf.load(args.config)
+config_name = os.path.basename(args.config).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = config_name.startswith('instant-mesh')
+
+device = torch.device('cuda')
+
+# load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+# load custom white-background UNet
+print('Loading custom white-background unet ...')
+if os.path.exists(infer_config.unet_path):
+    unet_ckpt_path = infer_config.unet_path
+else:
+    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+
+# load reconstruction model
+print('Loading reconstruction model ...')
+model = instantiate_from_config(model_config)
+if os.path.exists(infer_config.model_path):
+    model_ckpt_path = infer_config.model_path
+else:
+    model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+# keep only the LRM generator weights, stripping the 'lrm_generator.' prefix
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+if IS_FLEXICUBES:
+    model.init_flexicubes_geometry(device, fovy=30.0)
+model = model.eval()
+
+# make output directories
+image_path = os.path.join(args.output_path, config_name, 'images')
+mesh_path = os.path.join(args.output_path, config_name, 'meshes')
+video_path = os.path.join(args.output_path, config_name, 'videos')
+os.makedirs(image_path, exist_ok=True)
+os.makedirs(mesh_path, exist_ok=True)
+os.makedirs(video_path, exist_ok=True)
+
+# process input files
+if os.path.isdir(args.input_path):
+    input_files = [
+        os.path.join(args.input_path, file)
+        for file in os.listdir(args.input_path)
+        if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
+    ]
+else:
+    input_files = [args.input_path]
+print(f'Total number of input images: {len(input_files)}')
+
+
+###############################################################################
+# Stage 1: Multiview generation.
+###############################################################################
+
+rembg_session = None if args.no_rembg else rembg.new_session()
+
+outputs = []
+for idx, image_file in enumerate(input_files):
+    name = os.path.basename(image_file).split('.')[0]
+    print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')
+
+    # remove background optionally
+    input_image = Image.open(image_file)
+    if not args.no_rembg:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+
+    # sampling
+    output_image = pipeline(
+        input_image,
+        num_inference_steps=args.diffusion_steps,
+    ).images[0]
+
+    output_image.save(os.path.join(image_path, f'{name}.png'))
+    print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")
+
+    images = np.asarray(output_image, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)     # (6, 3, 320, 320)
+
+    outputs.append({'name': name, 'images': images})
+
+# delete pipeline to save memory
+del pipeline
+
+###############################################################################
+# Stage 2: Reconstruction.
+###############################################################################
+
+input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0*args.scale).to(device)
+chunk_size = 20 if IS_FLEXICUBES else 1
+
+for idx, sample in enumerate(outputs):
+    name = sample['name']
+    print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')
+
+    images = sample['images'].unsqueeze(0).to(device)
+    images = v2.functional.resize(images, 320, interpolation=3, antialias=True).clamp(0, 1)
+
+    # select a 4-view subset without mutating input_cameras across iterations
+    sample_cameras = input_cameras
+    if args.view == 4:
+        indices = torch.tensor([0, 2, 4, 5]).long().to(device)
+        images = images[:, indices]
+        sample_cameras = input_cameras[:, indices]
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, sample_cameras)
+
+        # get mesh
+        mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')
+
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=args.export_texmap,
+            **infer_config,
+        )
+        if args.export_texmap:
+            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
+            save_obj_with_mtl(
+                vertices.data.cpu().numpy(),
+                uvs.data.cpu().numpy(),
+                faces.data.cpu().numpy(),
+                mesh_tex_idx.data.cpu().numpy(),
+                tex_map.permute(1, 2, 0).data.cpu().numpy(),
+                mesh_path_idx,
+            )
+        else:
+            vertices, faces, vertex_colors = mesh_out
+            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
+        print(f"Mesh saved to {mesh_path_idx}")
+
+        # get video
+        if args.save_video:
+            video_path_idx = os.path.join(video_path, f'{name}.mp4')
+            render_size = infer_config.render_resolution
+            render_cameras = get_render_cameras(
+                batch_size=1,
+                M=120,
+                radius=args.distance,
+                elevation=20.0,
+                is_flexicubes=IS_FLEXICUBES,
+            ).to(device)
+
+            frames = render_frames(
+                model,
+                planes,
+                render_cameras=render_cameras,
+                render_size=render_size,
+                chunk_size=chunk_size,
+                is_flexicubes=IS_FLEXICUBES,
+            )
+
+            save_video(
+                frames,
+                video_path_idx,
+                fps=30,
+            )
+            print(f"Video saved to {video_path_idx}")
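
Both `app.py` and `run.py` rebuild the reconstruction model from a Lightning checkpoint whose weights are stored under an `lrm_generator.` prefix; the `k[14:]` slice strips exactly `len('lrm_generator.') == 14` characters. A tiny illustration of that filter, using made-up keys:

```python
# Illustration of the checkpoint key filtering in app.py/run.py:
# keep only 'lrm_generator.*' entries and strip the 14-character prefix.
state_dict = {
    'lrm_generator.encoder.weight': 'w0',  # kept, renamed to 'encoder.weight'
    'lrm_generator.source_camera': 'w1',   # app.py additionally drops this one
    'other_module.bias': 'w2',             # filtered out
}
stripped = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
print(stripped)  # {'encoder.weight': 'w0', 'source_camera': 'w1'}
```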
train.py
ADDED
@@ -0,0 +1,286 @@
+import os, sys
+import argparse
+import shutil
+import subprocess
+from omegaconf import OmegaConf
+
+from pytorch_lightning import seed_everything
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+
+from src.utils.train_util import instantiate_from_config
+
+
+@rank_zero_only
+def rank_zero_print(*args):
+    print(*args)
+
+
+def get_parser(**parser_kwargs):
+    def str2bool(v):
+        if isinstance(v, bool):
+            return v
+        if v.lower() in ("yes", "true", "t", "y", "1"):
+            return True
+        elif v.lower() in ("no", "false", "f", "n", "0"):
+            return False
+        else:
+            raise argparse.ArgumentTypeError("Boolean value expected.")
+
+    parser = argparse.ArgumentParser(**parser_kwargs)
+    parser.add_argument(
+        "-r",
+        "--resume",
+        type=str,
+        default=None,
+        help="resume from checkpoint",
+    )
+    parser.add_argument(
+        "--resume_weights_only",
+        action="store_true",
+        help="only resume model weights",
+    )
+    parser.add_argument(
+        "-b",
+        "--base",
+        type=str,
+        default="base_config.yaml",
+        help="path to base configs",
+    )
+    parser.add_argument(
+        "-n",
+        "--name",
+        type=str,
+        default="",
+        help="experiment name",
+    )
+    parser.add_argument(
+        "--num_nodes",
+        type=int,
+        default=1,
+        help="number of nodes to use",
+    )
+    parser.add_argument(
+        "--gpus",
+        type=str,
+        default="0,",
+        help="gpu ids to use",
+    )
+    parser.add_argument(
+        "-s",
+        "--seed",
+        type=int,
+        default=42,
+        help="seed for seed_everything",
+    )
+    parser.add_argument(
+        "-l",
+        "--logdir",
+        type=str,
+        default="logs",
+        help="directory for logging data",
+    )
+    return parser
+
+
+class SetupCallback(Callback):
+    def __init__(self, resume, logdir, ckptdir, cfgdir, config):
+        super().__init__()
+        self.resume = resume
+        self.logdir = logdir
+        self.ckptdir = ckptdir
+        self.cfgdir = cfgdir
+        self.config = config
+
+    def on_fit_start(self, trainer, pl_module):
+        if trainer.global_rank == 0:
+            # Create logdirs and save configs
+            os.makedirs(self.logdir, exist_ok=True)
+            os.makedirs(self.ckptdir, exist_ok=True)
+            os.makedirs(self.cfgdir, exist_ok=True)
+
+            rank_zero_print("Project config")
+            rank_zero_print(OmegaConf.to_yaml(self.config))
+            OmegaConf.save(self.config,
+                           os.path.join(self.cfgdir, "project.yaml"))
+
+
+class CodeSnapshot(Callback):
+    """
+    Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60
+    """
+    def __init__(self, savedir):
+        self.savedir = savedir
+
+    def get_file_list(self):
+        return [
+            b.decode()
+            for b in set(
+                subprocess.check_output(
+                    'git ls-files -- ":!:configs/*"', shell=True
+                ).splitlines()
+            )
+            | set(  # hard code, TODO: use config to exclude folders or files
+                subprocess.check_output(
+                    "git ls-files --others --exclude-standard", shell=True
+                ).splitlines()
+            )
+        ]
+
+    @rank_zero_only
+    def save_code_snapshot(self):
+        os.makedirs(self.savedir, exist_ok=True)
+        for f in self.get_file_list():
+            if not os.path.exists(f) or os.path.isdir(f):
+                continue
+            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
+            shutil.copyfile(f, os.path.join(self.savedir, f))
+
+    def on_fit_start(self, trainer, pl_module):
+        try:
+            self.save_code_snapshot()
+        except Exception:
+            rank_zero_warn(
+                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
+            )
+
+
+if __name__ == "__main__":
+    # add cwd for convenience and to make classes in this file available when
+    # running as `python train.py`
+    sys.path.append(os.getcwd())
+
+    parser = get_parser()
+    opt, unknown = parser.parse_known_args()
+
+    cfg_fname = os.path.split(opt.base)[-1]
+    cfg_name = os.path.splitext(cfg_fname)[0]
+    exp_name = "-" + opt.name if opt.name != "" else ""
+    logdir = os.path.join(opt.logdir, cfg_name+exp_name)
+
+    ckptdir = os.path.join(logdir, "checkpoints")
+    cfgdir = os.path.join(logdir, "configs")
+    codedir = os.path.join(logdir, "code")
+    seed_everything(opt.seed)
+
+    # init configs
+    config = OmegaConf.load(opt.base)
+    lightning_config = config.lightning
+    trainer_config = lightning_config.trainer
+
+    trainer_config["accelerator"] = "gpu"
+    rank_zero_print(f"Running on GPUs {opt.gpus}")
+    ngpu = len(opt.gpus.strip(",").split(','))
+    trainer_config['devices'] = ngpu
+
+    trainer_opt = argparse.Namespace(**trainer_config)
+    lightning_config.trainer = trainer_config
+
+    # model
+    model = instantiate_from_config(config.model)
+    if opt.resume and opt.resume_weights_only:
+        model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params)
+
+    model.logdir = logdir
+
+    # trainer and callbacks
+    trainer_kwargs = dict()
+
+    # logger
+    default_logger_cfg = {
+        "target": "pytorch_lightning.loggers.TensorBoardLogger",
+        "params": {
+            "name": "tensorboard",
+            "save_dir": logdir,
+            "version": "0",
+        }
+    }
+    logger_cfg = OmegaConf.merge(default_logger_cfg)
+    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+    # model checkpoint
+    default_modelckpt_cfg = {
+        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+        "params": {
+            "dirpath": ckptdir,
+            "filename": "{step:08}",
+            "verbose": True,
+            "save_last": True,
+            "every_n_train_steps": 5000,
+            "save_top_k": -1,  # save all checkpoints
+        }
+    }
+
+    if "modelcheckpoint" in lightning_config:
+        modelckpt_cfg = lightning_config.modelcheckpoint
+    else:
+        modelckpt_cfg = OmegaConf.create()
+    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
+
+    # callbacks
+    default_callbacks_cfg = {
+        "setup_callback": {
+            "target": "train.SetupCallback",
+            "params": {
+                "resume": opt.resume,
+                "logdir": logdir,
+                "ckptdir": ckptdir,
+                "cfgdir": cfgdir,
+                "config": config,
+            }
+        },
+        "learning_rate_logger": {
+            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
+            "params": {
+                "logging_interval": "step",
+            }
+        },
+        "code_snapshot": {
+            "target": "train.CodeSnapshot",
+            "params": {
+                "savedir": codedir,
+            }
+        },
+    }
+    default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg
+
+    if "callbacks" in lightning_config:
+        callbacks_cfg = lightning_config.callbacks
+    else:
+        callbacks_cfg = OmegaConf.create()
+    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+
+    trainer_kwargs["callbacks"] = [
+        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+
+    trainer_kwargs['precision'] = '32-true'
+    trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=True)
+
+    # trainer
+    trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes)
+    trainer.logdir = logdir
+
+    # data
+    data = instantiate_from_config(config.data)
+    data.prepare_data()
+    data.setup("fit")
+
+    # configure learning rate
+    base_lr = config.model.base_learning_rate
+    if 'accumulate_grad_batches' in lightning_config.trainer:
+        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
+    else:
+        accumulate_grad_batches = 1
+    rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+    model.learning_rate = base_lr
+    rank_zero_print("++++ NOT USING LR SCALING ++++")
+    rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")
+
+    # run training loop
+    if opt.resume and not opt.resume_weights_only:
+        trainer.fit(model, data, ckpt_path=opt.resume)
+    else:
+        trainer.fit(model, data)
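
Finally, `train.py` composes its logger, checkpoint, and callback configurations by layering the user's YAML over hard-coded defaults with `OmegaConf.merge`, where later arguments take precedence. A minimal demonstration of the pattern, with illustrative keys taken from the checkpoint defaults above:

```python
# OmegaConf.merge gives right-most precedence, which is how train.py lets a
# `lightning.modelcheckpoint` section in the YAML override its defaults.
from omegaconf import OmegaConf

defaults = OmegaConf.create({"every_n_train_steps": 5000, "save_top_k": -1})
user_cfg = OmegaConf.create({"every_n_train_steps": 1000})
merged = OmegaConf.merge(defaults, user_cfg)
print(merged.every_n_train_steps, merged.save_top_k)  # 1000 -1
```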