CazC commited on
Commit
9c98fd2
·
verified ·
1 Parent(s): 689b2fa

Upload 49 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. TripoSR +0 -1
  3. TripoSR/.gitignore +164 -0
  4. TripoSR/LICENSE +21 -0
  5. TripoSR/README.md +80 -0
  6. TripoSR/__pycache__/obj_gen.cpython-310.pyc +0 -0
  7. TripoSR/examples/captured.jpeg +3 -0
  8. TripoSR/examples/captured_p.png +3 -0
  9. TripoSR/examples/chair.png +0 -0
  10. TripoSR/examples/flamingo.png +0 -0
  11. TripoSR/examples/hamburger.png +0 -0
  12. TripoSR/examples/horse.png +0 -0
  13. TripoSR/examples/iso_house.png +3 -0
  14. TripoSR/examples/marble.png +0 -0
  15. TripoSR/examples/police_woman.png +0 -0
  16. TripoSR/examples/poly_fox.png +0 -0
  17. TripoSR/examples/robot.png +0 -0
  18. TripoSR/examples/stripes.png +0 -0
  19. TripoSR/examples/teapot.png +0 -0
  20. TripoSR/examples/tiger_girl.png +0 -0
  21. TripoSR/examples/unicorn.png +0 -0
  22. TripoSR/figures/comparison800.gif +3 -0
  23. TripoSR/figures/scatter-comparison.png +0 -0
  24. TripoSR/figures/teaser800.gif +3 -0
  25. TripoSR/figures/visual_comparisons.jpg +3 -0
  26. TripoSR/gradio_app.py +187 -0
  27. TripoSR/obj_gen.py +92 -0
  28. TripoSR/output/0/input.png +0 -0
  29. TripoSR/output/0/mesh.obj +0 -0
  30. TripoSR/requirements.txt +9 -0
  31. TripoSR/run.py +162 -0
  32. TripoSR/tsr/__pycache__/system.cpython-310.pyc +0 -0
  33. TripoSR/tsr/__pycache__/utils.cpython-310.pyc +0 -0
  34. TripoSR/tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
  35. TripoSR/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
  36. TripoSR/tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
  37. TripoSR/tsr/models/isosurface.py +52 -0
  38. TripoSR/tsr/models/nerf_renderer.py +180 -0
  39. TripoSR/tsr/models/network_utils.py +124 -0
  40. TripoSR/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
  41. TripoSR/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
  42. TripoSR/tsr/models/tokenizers/image.py +66 -0
  43. TripoSR/tsr/models/tokenizers/triplane.py +45 -0
  44. TripoSR/tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
  45. TripoSR/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
  46. TripoSR/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
  47. TripoSR/tsr/models/transformer/attention.py +653 -0
  48. TripoSR/tsr/models/transformer/basic_transformer_block.py +334 -0
  49. TripoSR/tsr/models/transformer/transformer_1d.py +219 -0
  50. TripoSR/tsr/system.py +203 -0
.gitattributes CHANGED
@@ -34,3 +34,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
+ TripoSR/examples/captured_p.png filter=lfs diff=lfs merge=lfs -text
38
+ TripoSR/examples/captured.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ TripoSR/examples/iso_house.png filter=lfs diff=lfs merge=lfs -text
40
+ TripoSR/figures/comparison800.gif filter=lfs diff=lfs merge=lfs -text
41
+ TripoSR/figures/teaser800.gif filter=lfs diff=lfs merge=lfs -text
42
+ TripoSR/figures/visual_comparisons.jpg filter=lfs diff=lfs merge=lfs -text
TripoSR DELETED
@@ -1 +0,0 @@
1
- Subproject commit 8e51fec8095c9eae20e6ea7c9aef6368c5631a21
 
 
TripoSR/.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # default output directory
163
+ output/
164
+ outputs/
TripoSR/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Tripo AI & Stability AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
TripoSR/README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TripoSR <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <a href="https://huggingface.co/spaces/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a> <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/Arxiv-2403.02151-B31B1B.svg"></a>
2
+
3
+ <div align="center">
4
+ <img src="figures/teaser800.gif" alt="Teaser Video">
5
+ </div>
6
+
7
+ This is the official codebase for **TripoSR**, a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
8
+ <br><br>
9
+ Leveraging the principles of the [Large Reconstruction Model (LRM)](https://yiconghong.me/LRM/), TripoSR brings to the table key advancements that significantly boost both the speed and quality of 3D reconstruction. Our model is distinguished by its ability to rapidly process inputs, generating high-quality 3D models in less than 0.5 seconds on an NVIDIA A100 GPU. TripoSR has exhibited superior performance in both qualitative and quantitative evaluations, outperforming other open-source alternatives across multiple public datasets. The figures below illustrate visual comparisons and metrics showcasing TripoSR's performance relative to other leading models. Details about the model architecture, training process, and comparisons can be found in this [technical report](https://arxiv.org/abs/2403.02151).
10
+
11
+ <!--
12
+ <div align="center">
13
+ <img src="figures/comparison800.gif" alt="Teaser Video">
14
+ </div>
15
+ -->
16
+ <p align="center">
17
+ <img width="800" src="figures/visual_comparisons.jpg"/>
18
+ </p>
19
+
20
+ <p align="center">
21
+ <img width="450" src="figures/scatter-comparison.png"/>
22
+ </p>
23
+
24
+
25
+ The model is released under the MIT license, which includes the source code, pretrained models, and an interactive online demo. Our goal is to empower researchers, developers, and creatives to push the boundaries of what's possible in 3D generative AI and 3D content creation.
26
+
27
+ ## Getting Started
28
+ ### Installation
29
+ - Python >= 3.8
30
+ - Install CUDA if available
31
+ - Install PyTorch according to your platform: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) **[Please make sure that the locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.]**
32
+ - Update setuptools by `pip install --upgrade setuptools`
33
+ - Install other dependencies by `pip install -r requirements.txt`
34
+
35
+ ### Manual Inference
36
+ ```sh
37
+ python run.py examples/chair.png --output-dir output/
38
+ ```
39
+ This will save the reconstructed 3D model to `output/`. You can also specify more than one image path separated by spaces. The default options takes about **6GB VRAM** for a single image input.
40
+
41
+ For detailed usage of this script, use `python run.py --help`.
42
+
43
+ ### Local Gradio App
44
+ Install Gradio:
45
+ ```sh
46
+ pip install gradio
47
+ ```
48
+ Start the Gradio App:
49
+ ```sh
50
+ python gradio_app.py
51
+ ```
52
+
53
+ ## Troubleshooting
54
+ > AttributeError: module 'torchmcubes_module' has no attribute 'mcubes_cuda'
55
+
56
+ or
57
+
58
+ > torchmcubes was not compiled with CUDA support, use CPU version instead.
59
+
60
+ This is because `torchmcubes` is compiled without CUDA support. Please make sure that
61
+
62
+ - The locally-installed CUDA major version matches the PyTorch-shipped CUDA major version. For example if you have CUDA 11.x installed, make sure to install PyTorch compiled with CUDA 11.x.
63
+ - `setuptools>=49.6.0`. If not, upgrade by `pip install --upgrade setuptools`.
64
+
65
+ Then re-install `torchmcubes` by:
66
+
67
+ ```sh
68
+ pip uninstall torchmcubes
69
+ pip install git+https://github.com/tatsy/torchmcubes.git
70
+ ```
71
+
72
+ ## Citation
73
+ ```BibTeX
74
+ @article{TripoSR2024,
75
+ title={TripoSR: Fast 3D Object Reconstruction from a Single Image},
76
+ author={Tochilkin, Dmitry and Pankratz, David and Liu, Zexiang and Huang, Zixuan and and Letts, Adam and Li, Yangguang and Liang, Ding and Laforte, Christian and Jampani, Varun and Cao, Yan-Pei},
77
+ journal={arXiv preprint arXiv:2403.02151},
78
+ year={2024}
79
+ }
80
+ ```
TripoSR/__pycache__/obj_gen.cpython-310.pyc ADDED
Binary file (2.39 kB). View file
 
TripoSR/examples/captured.jpeg ADDED

Git LFS Details

  • SHA256: c6eb2768703a0e3d6034daa7fd5e0b286450b1077a90f36da8110749bb1cb8a8
  • Pointer size: 132 Bytes
  • Size of remote file: 5.94 MB
TripoSR/examples/captured_p.png ADDED

Git LFS Details

  • SHA256: 32d34c2f3a7d0ff7a5002ed98e0866679918cd6e2acf08ffbf6c50173a9f1c6e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.35 MB
TripoSR/examples/chair.png ADDED
TripoSR/examples/flamingo.png ADDED
TripoSR/examples/hamburger.png ADDED
TripoSR/examples/horse.png ADDED
TripoSR/examples/iso_house.png ADDED

Git LFS Details

  • SHA256: e48818cf2d5edc15d4da67467c6e82cf08b0d4e24591b60ca0572e16f09f98c1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
TripoSR/examples/marble.png ADDED
TripoSR/examples/police_woman.png ADDED
TripoSR/examples/poly_fox.png ADDED
TripoSR/examples/robot.png ADDED
TripoSR/examples/stripes.png ADDED
TripoSR/examples/teapot.png ADDED
TripoSR/examples/tiger_girl.png ADDED
TripoSR/examples/unicorn.png ADDED
TripoSR/figures/comparison800.gif ADDED

Git LFS Details

  • SHA256: 887e69297e4446f122801ff2cc39962eda0933906d7ed7be7abf659e721914be
  • Pointer size: 132 Bytes
  • Size of remote file: 8.87 MB
TripoSR/figures/scatter-comparison.png ADDED
TripoSR/figures/teaser800.gif ADDED

Git LFS Details

  • SHA256: 52ecc6ff24e008b0d28236425a1b59718931841f6fb9f5e6f8471829fc9bc292
  • Pointer size: 132 Bytes
  • Size of remote file: 3.84 MB
TripoSR/figures/visual_comparisons.jpg ADDED

Git LFS Details

  • SHA256: 019235d716d8832aaa659acd31cf17267af94df6b5a9beca2a7002b41d59c8db
  • Pointer size: 133 Bytes
  • Size of remote file: 10.3 MB
TripoSR/gradio_app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rembg
9
+ import torch
10
+ from PIL import Image
11
+ from functools import partial
12
+
13
+ from tsr.system import TSR
14
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
+
16
+ import argparse
17
+
18
+
19
+ if torch.cuda.is_available():
20
+ device = "cuda:0"
21
+ else:
22
+ device = "cpu"
23
+
24
+ model = TSR.from_pretrained(
25
+ "stabilityai/TripoSR",
26
+ config_name="config.yaml",
27
+ weight_name="model.ckpt",
28
+ )
29
+
30
+ # adjust the chunk size to balance between speed and memory usage
31
+ model.renderer.set_chunk_size(8192)
32
+ model.to(device)
33
+
34
+ rembg_session = rembg.new_session()
35
+
36
+
37
+ def check_input_image(input_image):
38
+ if input_image is None:
39
+ raise gr.Error("No image uploaded!")
40
+
41
+
42
+ def preprocess(input_image, do_remove_background, foreground_ratio):
43
+ def fill_background(image):
44
+ image = np.array(image).astype(np.float32) / 255.0
45
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
46
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
47
+ return image
48
+
49
+ if do_remove_background:
50
+ image = input_image.convert("RGB")
51
+ image = remove_background(image, rembg_session)
52
+ image = resize_foreground(image, foreground_ratio)
53
+ image = fill_background(image)
54
+ else:
55
+ image = input_image
56
+ if image.mode == "RGBA":
57
+ image = fill_background(image)
58
+ return image
59
+
60
+
61
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
62
+ scene_codes = model(image, device=device)
63
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
64
+ mesh = to_gradio_3d_orientation(mesh)
65
+ rv = []
66
+ for format in formats:
67
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
68
+ mesh.export(mesh_path.name)
69
+ rv.append(mesh_path.name)
70
+ return rv
71
+
72
+
73
+ def run_example(image_pil):
74
+ preprocessed = preprocess(image_pil, False, 0.9)
75
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
76
+ return preprocessed, mesh_name_obj, mesh_name_glb
77
+
78
+
79
+ with gr.Blocks(title="TripoSR") as interface:
80
+ gr.Markdown(
81
+ """
82
+ # TripoSR Demo
83
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
84
+
85
+ **Tips:**
86
+ 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
87
+ 2. You can disable "Remove Background" for the provided examples since they have been already preprocessed.
88
+ 3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
89
+ """
90
+ )
91
+ with gr.Row(variant="panel"):
92
+ with gr.Column():
93
+ with gr.Row():
94
+ input_image = gr.Image(
95
+ label="Input Image",
96
+ image_mode="RGBA",
97
+ sources="upload",
98
+ type="pil",
99
+ elem_id="content_image",
100
+ )
101
+ processed_image = gr.Image(label="Processed Image", interactive=False)
102
+ with gr.Row():
103
+ with gr.Group():
104
+ do_remove_background = gr.Checkbox(
105
+ label="Remove Background", value=True
106
+ )
107
+ foreground_ratio = gr.Slider(
108
+ label="Foreground Ratio",
109
+ minimum=0.5,
110
+ maximum=1.0,
111
+ value=0.85,
112
+ step=0.05,
113
+ )
114
+ mc_resolution = gr.Slider(
115
+ label="Marching Cubes Resolution",
116
+ minimum=32,
117
+ maximum=320,
118
+ value=256,
119
+ step=32
120
+ )
121
+ with gr.Row():
122
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
123
+ with gr.Column():
124
+ with gr.Tab("OBJ"):
125
+ output_model_obj = gr.Model3D(
126
+ label="Output Model (OBJ Format)",
127
+ interactive=False,
128
+ )
129
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
130
+ with gr.Tab("GLB"):
131
+ output_model_glb = gr.Model3D(
132
+ label="Output Model (GLB Format)",
133
+ interactive=False,
134
+ )
135
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
136
+ with gr.Row(variant="panel"):
137
+ gr.Examples(
138
+ examples=[
139
+ "examples/hamburger.png",
140
+ "examples/poly_fox.png",
141
+ "examples/robot.png",
142
+ "examples/teapot.png",
143
+ "examples/tiger_girl.png",
144
+ "examples/horse.png",
145
+ "examples/flamingo.png",
146
+ "examples/unicorn.png",
147
+ "examples/chair.png",
148
+ "examples/iso_house.png",
149
+ "examples/marble.png",
150
+ "examples/police_woman.png",
151
+ "examples/captured_p.png",
152
+ ],
153
+ inputs=[input_image],
154
+ outputs=[processed_image, output_model_obj, output_model_glb],
155
+ cache_examples=False,
156
+ fn=partial(run_example),
157
+ label="Examples",
158
+ examples_per_page=20,
159
+ )
160
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
161
+ fn=preprocess,
162
+ inputs=[input_image, do_remove_background, foreground_ratio],
163
+ outputs=[processed_image],
164
+ ).success(
165
+ fn=generate,
166
+ inputs=[processed_image, mc_resolution],
167
+ outputs=[output_model_obj, output_model_glb],
168
+ )
169
+
170
+
171
+
172
+ if __name__ == '__main__':
173
+ parser = argparse.ArgumentParser()
174
+ parser.add_argument('--username', type=str, default=None, help='Username for authentication')
175
+ parser.add_argument('--password', type=str, default=None, help='Password for authentication')
176
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
177
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
178
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
179
+ parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
180
+ args = parser.parse_args()
181
+ interface.queue(max_size=args.queuesize)
182
+ interface.launch(
183
+ auth=(args.username, args.password) if (args.username and args.password) else None,
184
+ share=args.share,
185
+ server_name="0.0.0.0" if args.listen else None,
186
+ server_port=args.port
187
+ )
TripoSR/obj_gen.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+ from functools import partial
11
+
12
+ from tsr.system import TSR
13
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
14
+
15
+ import argparse
16
+ from dotenv import load_dotenv
17
+
18
+ load_dotenv()
19
+
20
+
21
+
22
+ device = "cpu"
23
+
24
+ model = TSR.from_pretrained(
25
+ "stabilityai/TripoSR",
26
+ config_name="config.yaml",
27
+ weight_name="model.ckpt",
28
+ )
29
+
30
+ # adjust the chunk size to balance between speed and memory usage
31
+ model.renderer.set_chunk_size(8192)
32
+ model.to(device)
33
+
34
+ rembg_session = rembg.new_session()
35
+
36
+
37
+
38
+
39
+ def preprocess(input_image, do_remove_background, foreground_ratio):
40
+ def fill_background(image):
41
+ image = np.array(image).astype(np.float32) / 255.0
42
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
43
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
44
+ return image
45
+
46
+ if do_remove_background:
47
+ image = input_image.convert("RGB")
48
+ image = remove_background(image, rembg_session)
49
+ image = resize_foreground(image, foreground_ratio)
50
+ image = fill_background(image)
51
+ else:
52
+ image = input_image
53
+ if image.mode == "RGBA":
54
+ image = fill_background(image)
55
+ return image
56
+
57
+
58
+ def generate(image, mc_resolution, formats=["obj", "glb"], path="output.obj"):
59
+ scene_codes = model(image, device=device)
60
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
61
+ mesh = to_gradio_3d_orientation(mesh)
62
+ rv = []
63
+ for format in formats:
64
+ mesh_path = path.replace(".obj", f".{format}")
65
+ mesh.export(mesh_path)
66
+ rv.append(mesh_path)
67
+ return rv
68
+
69
+
70
+ def run_example(image_pil):
71
+ preprocessed = preprocess(image_pil, False, 0.9)
72
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
73
+ return preprocessed, mesh_name_obj, mesh_name_glb
74
+
75
+ def generate_obj_from_image(image_pil, path="output.obj"):
76
+ # Preprocess the image without removing the background and with a foreground ratio of 0.9
77
+ preprocessed = preprocess(image_pil, True, 0.9)
78
+
79
+ # Generate the mesh and get the paths to the .obj and .glb files
80
+ mesh_paths = generate(preprocessed, 256, ["obj"], path)
81
+
82
+ # Return the path to the .obj file
83
+ return mesh_paths[0]
84
+
85
+ if __name__ == "__main__":
86
+ # run a test
87
+ image_path = "output.png"
88
+ image = Image.open(image_path)
89
+ generate_obj_from_image(image, "output.obj")
90
+ # move the .obj file to the output directory
91
+
92
+
TripoSR/output/0/input.png ADDED
TripoSR/output/0/mesh.obj ADDED
The diff for this file is too large to render. See raw diff
 
TripoSR/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ Pillow==10.1.0
3
+ einops==0.7.0
4
+ git+https://github.com/tatsy/torchmcubes.git
5
+ transformers==4.35.0
6
+ trimesh==4.0.5
7
+ rembg
8
+ huggingface-hub
9
+ imageio[ffmpeg]
TripoSR/run.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from tsr.system import TSR
12
+ from tsr.utils import remove_background, resize_foreground, save_video
13
+
14
+
15
+ class Timer:
16
+ def __init__(self):
17
+ self.items = {}
18
+ self.time_scale = 1000.0 # ms
19
+ self.time_unit = "ms"
20
+
21
+ def start(self, name: str) -> None:
22
+ if torch.cuda.is_available():
23
+ torch.cuda.synchronize()
24
+ self.items[name] = time.time()
25
+ logging.info(f"{name} ...")
26
+
27
+ def end(self, name: str) -> float:
28
+ if name not in self.items:
29
+ return
30
+ if torch.cuda.is_available():
31
+ torch.cuda.synchronize()
32
+ start_time = self.items.pop(name)
33
+ delta = time.time() - start_time
34
+ t = delta * self.time_scale
35
+ logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
36
+
37
+
38
+ timer = Timer()
39
+
40
+
41
+ logging.basicConfig(
42
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
43
+ )
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
46
+ parser.add_argument(
47
+ "--device",
48
+ default="cuda:0",
49
+ type=str,
50
+ help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
51
+ )
52
+ parser.add_argument(
53
+ "--pretrained-model-name-or-path",
54
+ default="stabilityai/TripoSR",
55
+ type=str,
56
+ help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/TripoSR'",
57
+ )
58
+ parser.add_argument(
59
+ "--chunk-size",
60
+ default=8192,
61
+ type=int,
62
+ help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
63
+ )
64
+ parser.add_argument(
65
+ "--mc-resolution",
66
+ default=256,
67
+ type=int,
68
+ help="Marching cubes grid resolution. Default: 256"
69
+ )
70
+ parser.add_argument(
71
+ "--no-remove-bg",
72
+ action="store_true",
73
+ help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
74
+ )
75
+ parser.add_argument(
76
+ "--foreground-ratio",
77
+ default=0.85,
78
+ type=float,
79
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
80
+ )
81
+ parser.add_argument(
82
+ "--output-dir",
83
+ default="output/",
84
+ type=str,
85
+ help="Output directory to save the results. Default: 'output/'",
86
+ )
87
+ parser.add_argument(
88
+ "--model-save-format",
89
+ default="obj",
90
+ type=str,
91
+ choices=["obj", "glb"],
92
+ help="Format to save the extracted mesh. Default: 'obj'",
93
+ )
94
+ parser.add_argument(
95
+ "--render",
96
+ action="store_true",
97
+ help="If specified, save a NeRF-rendered video. Default: false",
98
+ )
99
+ args = parser.parse_args()
100
+
101
+ output_dir = args.output_dir
102
+ os.makedirs(output_dir, exist_ok=True)
103
+
104
+ device = args.device
105
+ if not torch.cuda.is_available():
106
+ device = "cpu"
107
+
108
+ timer.start("Initializing model")
109
+ model = TSR.from_pretrained(
110
+ args.pretrained_model_name_or_path,
111
+ config_name="config.yaml",
112
+ weight_name="model.ckpt",
113
+ )
114
+ model.renderer.set_chunk_size(args.chunk_size)
115
+ model.to(device)
116
+ timer.end("Initializing model")
117
+
118
+ timer.start("Processing images")
119
+ images = []
120
+
121
+ if args.no_remove_bg:
122
+ rembg_session = None
123
+ else:
124
+ rembg_session = rembg.new_session()
125
+
126
+ for i, image_path in enumerate(args.image):
127
+ if args.no_remove_bg:
128
+ image = np.array(Image.open(image_path).convert("RGB"))
129
+ else:
130
+ image = remove_background(Image.open(image_path), rembg_session)
131
+ image = resize_foreground(image, args.foreground_ratio)
132
+ image = np.array(image).astype(np.float32) / 255.0
133
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
134
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
135
+ if not os.path.exists(os.path.join(output_dir, str(i))):
136
+ os.makedirs(os.path.join(output_dir, str(i)))
137
+ image.save(os.path.join(output_dir, str(i), f"input.png"))
138
+ images.append(image)
139
+ timer.end("Processing images")
140
+
141
+ for i, image in enumerate(images):
142
+ logging.info(f"Running image {i + 1}/{len(images)} ...")
143
+
144
+ timer.start("Running model")
145
+ with torch.no_grad():
146
+ scene_codes = model([image], device=device)
147
+ timer.end("Running model")
148
+
149
+ if args.render:
150
+ timer.start("Rendering")
151
+ render_images = model.render(scene_codes, n_views=30, return_type="pil")
152
+ for ri, render_image in enumerate(render_images[0]):
153
+ render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
154
+ save_video(
155
+ render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
156
+ )
157
+ timer.end("Rendering")
158
+
159
+ timer.start("Exporting mesh")
160
+ meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
161
+ meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
162
+ timer.end("Exporting mesh")
TripoSR/tsr/__pycache__/system.cpython-310.pyc ADDED
Binary file (5.15 kB). View file
 
TripoSR/tsr/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
TripoSR/tsr/models/__pycache__/isosurface.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
TripoSR/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc ADDED
Binary file (5.28 kB). View file
 
TripoSR/tsr/models/__pycache__/network_utils.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
TripoSR/tsr/models/isosurface.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchmcubes import marching_cubes
7
+
8
+
9
+ class IsosurfaceHelper(nn.Module):
10
+ points_range: Tuple[float, float] = (0, 1)
11
+
12
+ @property
13
+ def grid_vertices(self) -> torch.FloatTensor:
14
+ raise NotImplementedError
15
+
16
+
17
+ class MarchingCubeHelper(IsosurfaceHelper):
18
+ def __init__(self, resolution: int) -> None:
19
+ super().__init__()
20
+ self.resolution = resolution
21
+ self.mc_func: Callable = marching_cubes
22
+ self._grid_vertices: Optional[torch.FloatTensor] = None
23
+
24
+ @property
25
+ def grid_vertices(self) -> torch.FloatTensor:
26
+ if self._grid_vertices is None:
27
+ # keep the vertices on CPU so that we can support very large resolution
28
+ x, y, z = (
29
+ torch.linspace(*self.points_range, self.resolution),
30
+ torch.linspace(*self.points_range, self.resolution),
31
+ torch.linspace(*self.points_range, self.resolution),
32
+ )
33
+ x, y, z = torch.meshgrid(x, y, z, indexing="ij")
34
+ verts = torch.cat(
35
+ [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
36
+ ).reshape(-1, 3)
37
+ self._grid_vertices = verts
38
+ return self._grid_vertices
39
+
40
+ def forward(
41
+ self,
42
+ level: torch.FloatTensor,
43
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44
+ level = -level.view(self.resolution, self.resolution, self.resolution)
45
+ try:
46
+ v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
47
+ except AttributeError:
48
+ print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
49
+ v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
50
+ v_pos = v_pos[..., [2, 1, 0]]
51
+ v_pos = v_pos / (self.resolution - 1.0)
52
+ return v_pos.to(level.device), t_pos_idx.to(level.device)
TripoSR/tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
+ class TriplaneNeRFRenderer(BaseModule):
18
+ @dataclass
19
+ class Config(BaseModule.Config):
20
+ radius: float
21
+
22
+ feature_reduction: str = "concat"
23
+ density_activation: str = "trunc_exp"
24
+ density_bias: float = -1.0
25
+ color_activation: str = "sigmoid"
26
+ num_samples_per_ray: int = 128
27
+ randomized: bool = False
28
+
29
+ cfg: Config
30
+
31
+ def configure(self) -> None:
32
+ assert self.cfg.feature_reduction in ["concat", "mean"]
33
+ self.chunk_size = 0
34
+
35
+ def set_chunk_size(self, chunk_size: int):
36
+ assert (
37
+ chunk_size >= 0
38
+ ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
+ self.chunk_size = chunk_size
40
+
41
+ def query_triplane(
42
+ self,
43
+ decoder: torch.nn.Module,
44
+ positions: torch.Tensor,
45
+ triplane: torch.Tensor,
46
+ ) -> Dict[str, torch.Tensor]:
47
+ input_shape = positions.shape[:-1]
48
+ positions = positions.view(-1, 3)
49
+
50
+ # positions in (-radius, radius)
51
+ # normalized to (-1, 1) for grid sample
52
+ positions = scale_tensor(
53
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
+ )
55
+
56
+ def _query_chunk(x):
57
+ indices2D: torch.Tensor = torch.stack(
58
+ (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59
+ dim=-3,
60
+ )
61
+ out: torch.Tensor = F.grid_sample(
62
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
63
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64
+ align_corners=False,
65
+ mode="bilinear",
66
+ )
67
+ if self.cfg.feature_reduction == "concat":
68
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69
+ elif self.cfg.feature_reduction == "mean":
70
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ net_out: Dict[str, torch.Tensor] = decoder(out)
75
+ return net_out
76
+
77
+ if self.chunk_size > 0:
78
+ net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79
+ else:
80
+ net_out = _query_chunk(positions)
81
+
82
+ net_out["density_act"] = get_activation(self.cfg.density_activation)(
83
+ net_out["density"] + self.cfg.density_bias
84
+ )
85
+ net_out["color"] = get_activation(self.cfg.color_activation)(
86
+ net_out["features"]
87
+ )
88
+
89
+ net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90
+
91
+ return net_out
92
+
93
+ def _forward(
94
+ self,
95
+ decoder: torch.nn.Module,
96
+ triplane: torch.Tensor,
97
+ rays_o: torch.Tensor,
98
+ rays_d: torch.Tensor,
99
+ **kwargs,
100
+ ):
101
+ rays_shape = rays_o.shape[:-1]
102
+ rays_o = rays_o.view(-1, 3)
103
+ rays_d = rays_d.view(-1, 3)
104
+ n_rays = rays_o.shape[0]
105
+
106
+ t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107
+ t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108
+
109
+ t_vals = torch.linspace(
110
+ 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111
+ )
112
+ t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113
+ z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114
+
115
+ xyz = (
116
+ rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117
+ ) # (N_rays, N_sample, 3)
118
+
119
+ mlp_out = self.query_triplane(
120
+ decoder=decoder,
121
+ positions=xyz,
122
+ triplane=triplane,
123
+ )
124
+
125
+ eps = 1e-10
126
+ # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127
+ deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
128
+ alpha = 1 - torch.exp(
129
+ -deltas * mlp_out["density_act"][..., 0]
130
+ ) # (N_rays, N_samples)
131
+ accum_prod = torch.cat(
132
+ [
133
+ torch.ones_like(alpha[:, :1]),
134
+ torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135
+ ],
136
+ dim=-1,
137
+ )
138
+ weights = alpha * accum_prod # (N_rays, N_samples)
139
+ comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140
+ opacity_ = weights.sum(dim=-1) # (N_rays)
141
+
142
+ comp_rgb = torch.zeros(
143
+ n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144
+ )
145
+ opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146
+ comp_rgb[rays_valid] = comp_rgb_
147
+ opacity[rays_valid] = opacity_
148
+
149
+ comp_rgb += 1 - opacity[..., None]
150
+ comp_rgb = comp_rgb.view(*rays_shape, 3)
151
+
152
+ return comp_rgb
153
+
154
+ def forward(
155
+ self,
156
+ decoder: torch.nn.Module,
157
+ triplane: torch.Tensor,
158
+ rays_o: torch.Tensor,
159
+ rays_d: torch.Tensor,
160
+ ) -> Dict[str, torch.Tensor]:
161
+ if triplane.ndim == 4:
162
+ comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163
+ else:
164
+ comp_rgb = torch.stack(
165
+ [
166
+ self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167
+ for i in range(triplane.shape[0])
168
+ ],
169
+ dim=0,
170
+ )
171
+
172
+ return comp_rgb
173
+
174
+ def train(self, mode=True):
175
+ self.randomized = mode and self.cfg.randomized
176
+ return super().train(mode=mode)
177
+
178
+ def eval(self):
179
+ self.randomized = False
180
+ return super().eval()
TripoSR/tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
+ class TriplaneUpsampleNetwork(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ in_channels: int
15
+ out_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.upsample = nn.ConvTranspose2d(
21
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22
+ )
23
+
24
+ def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25
+ triplanes_up = rearrange(
26
+ self.upsample(
27
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28
+ ),
29
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30
+ Np=3,
31
+ )
32
+ return triplanes_up
33
+
34
+
35
+ class NeRFMLP(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ in_channels: int
39
+ n_neurons: int
40
+ n_hidden_layers: int
41
+ activation: str = "relu"
42
+ bias: bool = True
43
+ weight_init: Optional[str] = "kaiming_uniform"
44
+ bias_init: Optional[str] = None
45
+
46
+ cfg: Config
47
+
48
+ def configure(self) -> None:
49
+ layers = [
50
+ self.make_linear(
51
+ self.cfg.in_channels,
52
+ self.cfg.n_neurons,
53
+ bias=self.cfg.bias,
54
+ weight_init=self.cfg.weight_init,
55
+ bias_init=self.cfg.bias_init,
56
+ ),
57
+ self.make_activation(self.cfg.activation),
58
+ ]
59
+ for i in range(self.cfg.n_hidden_layers - 1):
60
+ layers += [
61
+ self.make_linear(
62
+ self.cfg.n_neurons,
63
+ self.cfg.n_neurons,
64
+ bias=self.cfg.bias,
65
+ weight_init=self.cfg.weight_init,
66
+ bias_init=self.cfg.bias_init,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ layers += [
71
+ self.make_linear(
72
+ self.cfg.n_neurons,
73
+ 4, # density 1 + features 3
74
+ bias=self.cfg.bias,
75
+ weight_init=self.cfg.weight_init,
76
+ bias_init=self.cfg.bias_init,
77
+ )
78
+ ]
79
+ self.layers = nn.Sequential(*layers)
80
+
81
+ def make_linear(
82
+ self,
83
+ dim_in,
84
+ dim_out,
85
+ bias=True,
86
+ weight_init=None,
87
+ bias_init=None,
88
+ ):
89
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
90
+
91
+ if weight_init is None:
92
+ pass
93
+ elif weight_init == "kaiming_uniform":
94
+ torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ if bias:
99
+ if bias_init is None:
100
+ pass
101
+ elif bias_init == "zero":
102
+ torch.nn.init.zeros_(layer.bias)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return layer
107
+
108
+ def make_activation(self, activation):
109
+ if activation == "relu":
110
+ return nn.ReLU(inplace=True)
111
+ elif activation == "silu":
112
+ return nn.SiLU(inplace=True)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ def forward(self, x):
117
+ inp_shape = x.shape[:-1]
118
+ x = x.reshape(-1, x.shape[-1])
119
+
120
+ features = self.layers(x)
121
+ features = features.reshape(*inp_shape, -1)
122
+ out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123
+
124
+ return out
TripoSR/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
TripoSR/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc ADDED
Binary file (1.76 kB). View file
 
TripoSR/tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers.models.vit.modeling_vit import ViTModel
8
+
9
+ from ...utils import BaseModule
10
+
11
+
12
+ class DINOSingleImageTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16
+ enable_gradient_checkpointing: bool = False
17
+
18
+ cfg: Config
19
+
20
+ def configure(self) -> None:
21
+ self.model: ViTModel = ViTModel(
22
+ ViTModel.config_class.from_pretrained(
23
+ hf_hub_download(
24
+ repo_id=self.cfg.pretrained_model_name_or_path,
25
+ filename="config.json",
26
+ )
27
+ )
28
+ )
29
+
30
+ if self.cfg.enable_gradient_checkpointing:
31
+ self.model.encoder.gradient_checkpointing = True
32
+
33
+ self.register_buffer(
34
+ "image_mean",
35
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36
+ persistent=False,
37
+ )
38
+ self.register_buffer(
39
+ "image_std",
40
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41
+ persistent=False,
42
+ )
43
+
44
+ def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45
+ packed = False
46
+ if images.ndim == 4:
47
+ packed = True
48
+ images = images.unsqueeze(1)
49
+
50
+ batch_size, n_input_views = images.shape[:2]
51
+ images = (images - self.image_mean) / self.image_std
52
+ out = self.model(
53
+ rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54
+ )
55
+ local_features, global_features = out.last_hidden_state, out.pooler_output
56
+ local_features = local_features.permute(0, 2, 1)
57
+ local_features = rearrange(
58
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59
+ )
60
+ if packed:
61
+ local_features = local_features.squeeze(1)
62
+
63
+ return local_features
64
+
65
+ def detokenize(self, *args, **kwargs):
66
+ raise NotImplementedError
TripoSR/tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
+ class Triplane1DTokenizer(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ plane_size: int
15
+ num_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.embeddings = nn.Parameter(
21
+ torch.randn(
22
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23
+ dtype=torch.float32,
24
+ )
25
+ * 1
26
+ / math.sqrt(self.cfg.num_channels)
27
+ )
28
+
29
+ def forward(self, batch_size: int) -> torch.Tensor:
30
+ return rearrange(
31
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33
+ )
34
+
35
+ def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ batch_size, Ct, Nt = tokens.shape
37
+ assert Nt == self.cfg.plane_size**2 * 3
38
+ assert Ct == self.cfg.num_channels
39
+ return rearrange(
40
+ tokens,
41
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42
+ Np=3,
43
+ Hp=self.cfg.plane_size,
44
+ Wp=self.cfg.plane_size,
45
+ )
TripoSR/tsr/models/transformer/__pycache__/attention.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
TripoSR/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc ADDED
Binary file (9.61 kB). View file
 
TripoSR/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
TripoSR/tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+
46
+ class Attention(nn.Module):
47
+ r"""
48
+ A cross attention layer.
49
+
50
+ Parameters:
51
+ query_dim (`int`):
52
+ The number of channels in the query.
53
+ cross_attention_dim (`int`, *optional*):
54
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
55
+ heads (`int`, *optional*, defaults to 8):
56
+ The number of heads to use for multi-head attention.
57
+ dim_head (`int`, *optional*, defaults to 64):
58
+ The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probability to use.
61
+ bias (`bool`, *optional*, defaults to False):
62
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
63
+ upcast_attention (`bool`, *optional*, defaults to False):
64
+ Set to `True` to upcast the attention computation to `float32`.
65
+ upcast_softmax (`bool`, *optional*, defaults to False):
66
+ Set to `True` to upcast the softmax computation to `float32`.
67
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
68
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
69
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
70
+ The number of groups to use for the group norm in the cross attention.
71
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
72
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
73
+ norm_num_groups (`int`, *optional*, defaults to `None`):
74
+ The number of groups to use for the group norm in the attention.
75
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
76
+ The number of channels to use for the spatial normalization.
77
+ out_bias (`bool`, *optional*, defaults to `True`):
78
+ Set to `True` to use a bias in the output linear layer.
79
+ scale_qk (`bool`, *optional*, defaults to `True`):
80
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
81
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
82
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
83
+ `added_kv_proj_dim` is not `None`.
84
+ eps (`float`, *optional*, defaults to 1e-5):
85
+ An additional value added to the denominator in group normalization that is used for numerical stability.
86
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
87
+ A factor to rescale the output by dividing it with this value.
88
+ residual_connection (`bool`, *optional*, defaults to `False`):
89
+ Set to `True` to add the residual connection to the output.
90
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
91
+ Set to `True` if the attention block is loaded from a deprecated state dict.
92
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
93
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
94
+ `AttnProcessor` otherwise.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ query_dim: int,
100
+ cross_attention_dim: Optional[int] = None,
101
+ heads: int = 8,
102
+ dim_head: int = 64,
103
+ dropout: float = 0.0,
104
+ bias: bool = False,
105
+ upcast_attention: bool = False,
106
+ upcast_softmax: bool = False,
107
+ cross_attention_norm: Optional[str] = None,
108
+ cross_attention_norm_num_groups: int = 32,
109
+ added_kv_proj_dim: Optional[int] = None,
110
+ norm_num_groups: Optional[int] = None,
111
+ out_bias: bool = True,
112
+ scale_qk: bool = True,
113
+ only_cross_attention: bool = False,
114
+ eps: float = 1e-5,
115
+ rescale_output_factor: float = 1.0,
116
+ residual_connection: bool = False,
117
+ _from_deprecated_attn_block: bool = False,
118
+ processor: Optional["AttnProcessor"] = None,
119
+ out_dim: int = None,
120
+ ):
121
+ super().__init__()
122
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
123
+ self.query_dim = query_dim
124
+ self.cross_attention_dim = (
125
+ cross_attention_dim if cross_attention_dim is not None else query_dim
126
+ )
127
+ self.upcast_attention = upcast_attention
128
+ self.upcast_softmax = upcast_softmax
129
+ self.rescale_output_factor = rescale_output_factor
130
+ self.residual_connection = residual_connection
131
+ self.dropout = dropout
132
+ self.fused_projections = False
133
+ self.out_dim = out_dim if out_dim is not None else query_dim
134
+
135
+ # we make use of this private variable to know whether this class is loaded
136
+ # with an deprecated state dict so that we can convert it on the fly
137
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
138
+
139
+ self.scale_qk = scale_qk
140
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
141
+
142
+ self.heads = out_dim // dim_head if out_dim is not None else heads
143
+ # for slice_size > 0 the attention score computation
144
+ # is split across the batch axis to save memory
145
+ # You can set slice_size with `set_attention_slice`
146
+ self.sliceable_head_dim = heads
147
+
148
+ self.added_kv_proj_dim = added_kv_proj_dim
149
+ self.only_cross_attention = only_cross_attention
150
+
151
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
152
+ raise ValueError(
153
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
154
+ )
155
+
156
+ if norm_num_groups is not None:
157
+ self.group_norm = nn.GroupNorm(
158
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.group_norm = None
162
+
163
+ self.spatial_norm = None
164
+
165
+ if cross_attention_norm is None:
166
+ self.norm_cross = None
167
+ elif cross_attention_norm == "layer_norm":
168
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
169
+ elif cross_attention_norm == "group_norm":
170
+ if self.added_kv_proj_dim is not None:
171
+ # The given `encoder_hidden_states` are initially of shape
172
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
173
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
174
+ # before the projection, so we need to use `added_kv_proj_dim` as
175
+ # the number of channels for the group norm.
176
+ norm_cross_num_channels = added_kv_proj_dim
177
+ else:
178
+ norm_cross_num_channels = self.cross_attention_dim
179
+
180
+ self.norm_cross = nn.GroupNorm(
181
+ num_channels=norm_cross_num_channels,
182
+ num_groups=cross_attention_norm_num_groups,
183
+ eps=1e-5,
184
+ affine=True,
185
+ )
186
+ else:
187
+ raise ValueError(
188
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
189
+ )
190
+
191
+ linear_cls = nn.Linear
192
+
193
+ self.linear_cls = linear_cls
194
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
195
+
196
+ if not self.only_cross_attention:
197
+ # only relevant for the `AddedKVProcessor` classes
198
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
199
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
200
+ else:
201
+ self.to_k = None
202
+ self.to_v = None
203
+
204
+ if self.added_kv_proj_dim is not None:
205
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
206
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
207
+
208
+ self.to_out = nn.ModuleList([])
209
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
210
+ self.to_out.append(nn.Dropout(dropout))
211
+
212
+ # set attention processor
213
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
214
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
215
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
216
+ if processor is None:
217
+ processor = (
218
+ AttnProcessor2_0()
219
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
220
+ else AttnProcessor()
221
+ )
222
+ self.set_processor(processor)
223
+
224
+ def set_processor(self, processor: "AttnProcessor") -> None:
225
+ self.processor = processor
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.FloatTensor,
230
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
231
+ attention_mask: Optional[torch.FloatTensor] = None,
232
+ **cross_attention_kwargs,
233
+ ) -> torch.Tensor:
234
+ r"""
235
+ The forward method of the `Attention` class.
236
+
237
+ Args:
238
+ hidden_states (`torch.Tensor`):
239
+ The hidden states of the query.
240
+ encoder_hidden_states (`torch.Tensor`, *optional*):
241
+ The hidden states of the encoder.
242
+ attention_mask (`torch.Tensor`, *optional*):
243
+ The attention mask to use. If `None`, no mask is applied.
244
+ **cross_attention_kwargs:
245
+ Additional keyword arguments to pass along to the cross attention.
246
+
247
+ Returns:
248
+ `torch.Tensor`: The output of the attention layer.
249
+ """
250
+ # The `Attention` class can call different attention processors / attention functions
251
+ # here we simply pass along all tensors to the selected processor class
252
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
253
+ return self.processor(
254
+ self,
255
+ hidden_states,
256
+ encoder_hidden_states=encoder_hidden_states,
257
+ attention_mask=attention_mask,
258
+ **cross_attention_kwargs,
259
+ )
260
+
261
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
262
+ r"""
263
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
264
+ is the number of heads initialized while constructing the `Attention` class.
265
+
266
+ Args:
267
+ tensor (`torch.Tensor`): The tensor to reshape.
268
+
269
+ Returns:
270
+ `torch.Tensor`: The reshaped tensor.
271
+ """
272
+ head_size = self.heads
273
+ batch_size, seq_len, dim = tensor.shape
274
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
275
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
276
+ batch_size // head_size, seq_len, dim * head_size
277
+ )
278
+ return tensor
279
+
280
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
281
+ r"""
282
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
283
+ the number of heads initialized while constructing the `Attention` class.
284
+
285
+ Args:
286
+ tensor (`torch.Tensor`): The tensor to reshape.
287
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
288
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
289
+
290
+ Returns:
291
+ `torch.Tensor`: The reshaped tensor.
292
+ """
293
+ head_size = self.heads
294
+ batch_size, seq_len, dim = tensor.shape
295
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
296
+ tensor = tensor.permute(0, 2, 1, 3)
297
+
298
+ if out_dim == 3:
299
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
300
+
301
+ return tensor
302
+
303
+ def get_attention_scores(
304
+ self,
305
+ query: torch.Tensor,
306
+ key: torch.Tensor,
307
+ attention_mask: torch.Tensor = None,
308
+ ) -> torch.Tensor:
309
+ r"""
310
+ Compute the attention scores.
311
+
312
+ Args:
313
+ query (`torch.Tensor`): The query tensor.
314
+ key (`torch.Tensor`): The key tensor.
315
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
316
+
317
+ Returns:
318
+ `torch.Tensor`: The attention probabilities/scores.
319
+ """
320
+ dtype = query.dtype
321
+ if self.upcast_attention:
322
+ query = query.float()
323
+ key = key.float()
324
+
325
+ if attention_mask is None:
326
+ baddbmm_input = torch.empty(
327
+ query.shape[0],
328
+ query.shape[1],
329
+ key.shape[1],
330
+ dtype=query.dtype,
331
+ device=query.device,
332
+ )
333
+ beta = 0
334
+ else:
335
+ baddbmm_input = attention_mask
336
+ beta = 1
337
+
338
+ attention_scores = torch.baddbmm(
339
+ baddbmm_input,
340
+ query,
341
+ key.transpose(-1, -2),
342
+ beta=beta,
343
+ alpha=self.scale,
344
+ )
345
+ del baddbmm_input
346
+
347
+ if self.upcast_softmax:
348
+ attention_scores = attention_scores.float()
349
+
350
+ attention_probs = attention_scores.softmax(dim=-1)
351
+ del attention_scores
352
+
353
+ attention_probs = attention_probs.to(dtype)
354
+
355
+ return attention_probs
356
+
357
+ def prepare_attention_mask(
358
+ self,
359
+ attention_mask: torch.Tensor,
360
+ target_length: int,
361
+ batch_size: int,
362
+ out_dim: int = 3,
363
+ ) -> torch.Tensor:
364
+ r"""
365
+ Prepare the attention mask for the attention computation.
366
+
367
+ Args:
368
+ attention_mask (`torch.Tensor`):
369
+ The attention mask to prepare.
370
+ target_length (`int`):
371
+ The target length of the attention mask. This is the length of the attention mask after padding.
372
+ batch_size (`int`):
373
+ The batch size, which is used to repeat the attention mask.
374
+ out_dim (`int`, *optional*, defaults to `3`):
375
+ The output dimension of the attention mask. Can be either `3` or `4`.
376
+
377
+ Returns:
378
+ `torch.Tensor`: The prepared attention mask.
379
+ """
380
+ head_size = self.heads
381
+ if attention_mask is None:
382
+ return attention_mask
383
+
384
+ current_length: int = attention_mask.shape[-1]
385
+ if current_length != target_length:
386
+ if attention_mask.device.type == "mps":
387
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
388
+ # Instead, we can manually construct the padding tensor.
389
+ padding_shape = (
390
+ attention_mask.shape[0],
391
+ attention_mask.shape[1],
392
+ target_length,
393
+ )
394
+ padding = torch.zeros(
395
+ padding_shape,
396
+ dtype=attention_mask.dtype,
397
+ device=attention_mask.device,
398
+ )
399
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
400
+ else:
401
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
402
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
403
+ # remaining_length: int = target_length - current_length
404
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
405
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
406
+
407
+ if out_dim == 3:
408
+ if attention_mask.shape[0] < batch_size * head_size:
409
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
410
+ elif out_dim == 4:
411
+ attention_mask = attention_mask.unsqueeze(1)
412
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
413
+
414
+ return attention_mask
415
+
416
+ def norm_encoder_hidden_states(
417
+ self, encoder_hidden_states: torch.Tensor
418
+ ) -> torch.Tensor:
419
+ r"""
420
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
421
+ `Attention` class.
422
+
423
+ Args:
424
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
425
+
426
+ Returns:
427
+ `torch.Tensor`: The normalized encoder hidden states.
428
+ """
429
+ assert (
430
+ self.norm_cross is not None
431
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432
+
433
+ if isinstance(self.norm_cross, nn.LayerNorm):
434
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435
+ elif isinstance(self.norm_cross, nn.GroupNorm):
436
+ # Group norm norms along the channels dimension and expects
437
+ # input to be in the shape of (N, C, *). In this case, we want
438
+ # to norm along the hidden dimension, so we need to move
439
+ # (batch_size, sequence_length, hidden_size) ->
440
+ # (batch_size, hidden_size, sequence_length)
441
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444
+ else:
445
+ assert False
446
+
447
+ return encoder_hidden_states
448
+
449
+ @torch.no_grad()
450
+ def fuse_projections(self, fuse=True):
451
+ is_cross_attention = self.cross_attention_dim != self.query_dim
452
+ device = self.to_q.weight.data.device
453
+ dtype = self.to_q.weight.data.dtype
454
+
455
+ if not is_cross_attention:
456
+ # fetch weight matrices.
457
+ concatenated_weights = torch.cat(
458
+ [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
459
+ )
460
+ in_features = concatenated_weights.shape[1]
461
+ out_features = concatenated_weights.shape[0]
462
+
463
+ # create a new single projection layer and copy over the weights.
464
+ self.to_qkv = self.linear_cls(
465
+ in_features, out_features, bias=False, device=device, dtype=dtype
466
+ )
467
+ self.to_qkv.weight.copy_(concatenated_weights)
468
+
469
+ else:
470
+ concatenated_weights = torch.cat(
471
+ [self.to_k.weight.data, self.to_v.weight.data]
472
+ )
473
+ in_features = concatenated_weights.shape[1]
474
+ out_features = concatenated_weights.shape[0]
475
+
476
+ self.to_kv = self.linear_cls(
477
+ in_features, out_features, bias=False, device=device, dtype=dtype
478
+ )
479
+ self.to_kv.weight.copy_(concatenated_weights)
480
+
481
+ self.fused_projections = fuse
482
+
483
+
484
+ class AttnProcessor:
485
+ r"""
486
+ Default processor for performing attention-related computations.
487
+ """
488
+
489
+ def __call__(
490
+ self,
491
+ attn: Attention,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
494
+ attention_mask: Optional[torch.FloatTensor] = None,
495
+ ) -> torch.Tensor:
496
+ residual = hidden_states
497
+
498
+ input_ndim = hidden_states.ndim
499
+
500
+ if input_ndim == 4:
501
+ batch_size, channel, height, width = hidden_states.shape
502
+ hidden_states = hidden_states.view(
503
+ batch_size, channel, height * width
504
+ ).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape
508
+ if encoder_hidden_states is None
509
+ else encoder_hidden_states.shape
510
+ )
511
+ attention_mask = attn.prepare_attention_mask(
512
+ attention_mask, sequence_length, batch_size
513
+ )
514
+
515
+ if attn.group_norm is not None:
516
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
517
+ 1, 2
518
+ )
519
+
520
+ query = attn.to_q(hidden_states)
521
+
522
+ if encoder_hidden_states is None:
523
+ encoder_hidden_states = hidden_states
524
+ elif attn.norm_cross:
525
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
526
+ encoder_hidden_states
527
+ )
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ query = attn.head_to_batch_dim(query)
533
+ key = attn.head_to_batch_dim(key)
534
+ value = attn.head_to_batch_dim(value)
535
+
536
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
537
+ hidden_states = torch.bmm(attention_probs, value)
538
+ hidden_states = attn.batch_to_head_dim(hidden_states)
539
+
540
+ # linear proj
541
+ hidden_states = attn.to_out[0](hidden_states)
542
+ # dropout
543
+ hidden_states = attn.to_out[1](hidden_states)
544
+
545
+ if input_ndim == 4:
546
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
547
+ batch_size, channel, height, width
548
+ )
549
+
550
+ if attn.residual_connection:
551
+ hidden_states = hidden_states + residual
552
+
553
+ hidden_states = hidden_states / attn.rescale_output_factor
554
+
555
+ return hidden_states
556
+
557
+
558
+ class AttnProcessor2_0:
559
+ r"""
560
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
561
+ """
562
+
563
+ def __init__(self):
564
+ if not hasattr(F, "scaled_dot_product_attention"):
565
+ raise ImportError(
566
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
567
+ )
568
+
569
+ def __call__(
570
+ self,
571
+ attn: Attention,
572
+ hidden_states: torch.FloatTensor,
573
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
574
+ attention_mask: Optional[torch.FloatTensor] = None,
575
+ ) -> torch.FloatTensor:
576
+ residual = hidden_states
577
+
578
+ input_ndim = hidden_states.ndim
579
+
580
+ if input_ndim == 4:
581
+ batch_size, channel, height, width = hidden_states.shape
582
+ hidden_states = hidden_states.view(
583
+ batch_size, channel, height * width
584
+ ).transpose(1, 2)
585
+
586
+ batch_size, sequence_length, _ = (
587
+ hidden_states.shape
588
+ if encoder_hidden_states is None
589
+ else encoder_hidden_states.shape
590
+ )
591
+
592
+ if attention_mask is not None:
593
+ attention_mask = attn.prepare_attention_mask(
594
+ attention_mask, sequence_length, batch_size
595
+ )
596
+ # scaled_dot_product_attention expects attention_mask shape to be
597
+ # (batch, heads, source_length, target_length)
598
+ attention_mask = attention_mask.view(
599
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
600
+ )
601
+
602
+ if attn.group_norm is not None:
603
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
604
+ 1, 2
605
+ )
606
+
607
+ query = attn.to_q(hidden_states)
608
+
609
+ if encoder_hidden_states is None:
610
+ encoder_hidden_states = hidden_states
611
+ elif attn.norm_cross:
612
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
613
+ encoder_hidden_states
614
+ )
615
+
616
+ key = attn.to_k(encoder_hidden_states)
617
+ value = attn.to_v(encoder_hidden_states)
618
+
619
+ inner_dim = key.shape[-1]
620
+ head_dim = inner_dim // attn.heads
621
+
622
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623
+
624
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626
+
627
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
628
+ # TODO: add support for attn.scale when we move to Torch 2.1
629
+ hidden_states = F.scaled_dot_product_attention(
630
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
631
+ )
632
+
633
+ hidden_states = hidden_states.transpose(1, 2).reshape(
634
+ batch_size, -1, attn.heads * head_dim
635
+ )
636
+ hidden_states = hidden_states.to(query.dtype)
637
+
638
+ # linear proj
639
+ hidden_states = attn.to_out[0](hidden_states)
640
+ # dropout
641
+ hidden_states = attn.to_out[1](hidden_states)
642
+
643
+ if input_ndim == 4:
644
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
645
+ batch_size, channel, height, width
646
+ )
647
+
648
+ if attn.residual_connection:
649
+ hidden_states = hidden_states + residual
650
+
651
+ hidden_states = hidden_states / attn.rescale_output_factor
652
+
653
+ return hidden_states
TripoSR/tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+ from .attention import Attention
46
+
47
+
48
+ class BasicTransformerBlock(nn.Module):
49
+ r"""
50
+ A basic Transformer block.
51
+
52
+ Parameters:
53
+ dim (`int`): The number of channels in the input and output.
54
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`): The number of channels in each head.
56
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
59
+ attention_bias (:
60
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
61
+ only_cross_attention (`bool`, *optional*):
62
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
63
+ double_self_attention (`bool`, *optional*):
64
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
65
+ upcast_attention (`bool`, *optional*):
66
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
67
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
68
+ Whether to use learnable elementwise affine parameters for normalization.
69
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
70
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71
+ final_dropout (`bool` *optional*, defaults to False):
72
+ Whether to apply a final dropout after the last feed-forward layer.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_attention_heads: int,
79
+ attention_head_dim: int,
80
+ dropout=0.0,
81
+ cross_attention_dim: Optional[int] = None,
82
+ activation_fn: str = "geglu",
83
+ attention_bias: bool = False,
84
+ only_cross_attention: bool = False,
85
+ double_self_attention: bool = False,
86
+ upcast_attention: bool = False,
87
+ norm_elementwise_affine: bool = True,
88
+ norm_type: str = "layer_norm",
89
+ final_dropout: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.only_cross_attention = only_cross_attention
93
+
94
+ assert norm_type == "layer_norm"
95
+
96
+ # Define 3 blocks. Each block has its own normalization layer.
97
+ # 1. Self-Attn
98
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
99
+ self.attn1 = Attention(
100
+ query_dim=dim,
101
+ heads=num_attention_heads,
102
+ dim_head=attention_head_dim,
103
+ dropout=dropout,
104
+ bias=attention_bias,
105
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
106
+ upcast_attention=upcast_attention,
107
+ )
108
+
109
+ # 2. Cross-Attn
110
+ if cross_attention_dim is not None or double_self_attention:
111
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
112
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
113
+ # the second cross attention block.
114
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115
+
116
+ self.attn2 = Attention(
117
+ query_dim=dim,
118
+ cross_attention_dim=(
119
+ cross_attention_dim if not double_self_attention else None
120
+ ),
121
+ heads=num_attention_heads,
122
+ dim_head=attention_head_dim,
123
+ dropout=dropout,
124
+ bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ ) # is self-attn if encoder_hidden_states is none
127
+ else:
128
+ self.norm2 = None
129
+ self.attn2 = None
130
+
131
+ # 3. Feed-forward
132
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
133
+ self.ff = FeedForward(
134
+ dim,
135
+ dropout=dropout,
136
+ activation_fn=activation_fn,
137
+ final_dropout=final_dropout,
138
+ )
139
+
140
+ # let chunk size default to None
141
+ self._chunk_size = None
142
+ self._chunk_dim = 0
143
+
144
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
145
+ # Sets chunk feed-forward
146
+ self._chunk_size = chunk_size
147
+ self._chunk_dim = dim
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.FloatTensor,
152
+ attention_mask: Optional[torch.FloatTensor] = None,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ) -> torch.FloatTensor:
156
+ # Notice that normalization is always applied before the real computation in the following blocks.
157
+ # 0. Self-Attention
158
+ norm_hidden_states = self.norm1(hidden_states)
159
+
160
+ attn_output = self.attn1(
161
+ norm_hidden_states,
162
+ encoder_hidden_states=(
163
+ encoder_hidden_states if self.only_cross_attention else None
164
+ ),
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ hidden_states = attn_output + hidden_states
169
+
170
+ # 3. Cross-Attention
171
+ if self.attn2 is not None:
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+
174
+ attn_output = self.attn2(
175
+ norm_hidden_states,
176
+ encoder_hidden_states=encoder_hidden_states,
177
+ attention_mask=encoder_attention_mask,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 4. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self._chunk_size is not None:
185
+ # "feed_forward_chunk_size" can be used to save memory
186
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
187
+ raise ValueError(
188
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
189
+ )
190
+
191
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
192
+ ff_output = torch.cat(
193
+ [
194
+ self.ff(hid_slice)
195
+ for hid_slice in norm_hidden_states.chunk(
196
+ num_chunks, dim=self._chunk_dim
197
+ )
198
+ ],
199
+ dim=self._chunk_dim,
200
+ )
201
+ else:
202
+ ff_output = self.ff(norm_hidden_states)
203
+
204
+ hidden_states = ff_output + hidden_states
205
+
206
+ return hidden_states
207
+
208
+
209
+ class FeedForward(nn.Module):
210
+ r"""
211
+ A feed-forward layer.
212
+
213
+ Parameters:
214
+ dim (`int`): The number of channels in the input.
215
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim: int,
225
+ dim_out: Optional[int] = None,
226
+ mult: int = 4,
227
+ dropout: float = 0.0,
228
+ activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
+ ):
231
+ super().__init__()
232
+ inner_dim = int(dim * mult)
233
+ dim_out = dim_out if dim_out is not None else dim
234
+ linear_cls = nn.Linear
235
+
236
+ if activation_fn == "gelu":
237
+ act_fn = GELU(dim, inner_dim)
238
+ if activation_fn == "gelu-approximate":
239
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
240
+ elif activation_fn == "geglu":
241
+ act_fn = GEGLU(dim, inner_dim)
242
+ elif activation_fn == "geglu-approximate":
243
+ act_fn = ApproximateGELU(dim, inner_dim)
244
+
245
+ self.net = nn.ModuleList([])
246
+ # project in
247
+ self.net.append(act_fn)
248
+ # project dropout
249
+ self.net.append(nn.Dropout(dropout))
250
+ # project out
251
+ self.net.append(linear_cls(inner_dim, dim_out))
252
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
+ if final_dropout:
254
+ self.net.append(nn.Dropout(dropout))
255
+
256
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
257
+ for module in self.net:
258
+ hidden_states = module(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class GELU(nn.Module):
263
+ r"""
264
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
265
+
266
+ Parameters:
267
+ dim_in (`int`): The number of channels in the input.
268
+ dim_out (`int`): The number of channels in the output.
269
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
270
+ """
271
+
272
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
273
+ super().__init__()
274
+ self.proj = nn.Linear(dim_in, dim_out)
275
+ self.approximate = approximate
276
+
277
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
278
+ if gate.device.type != "mps":
279
+ return F.gelu(gate, approximate=self.approximate)
280
+ # mps: gelu is not implemented for float16
281
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
282
+ dtype=gate.dtype
283
+ )
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = self.proj(hidden_states)
287
+ hidden_states = self.gelu(hidden_states)
288
+ return hidden_states
289
+
290
+
291
+ class GEGLU(nn.Module):
292
+ r"""
293
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
294
+
295
+ Parameters:
296
+ dim_in (`int`): The number of channels in the input.
297
+ dim_out (`int`): The number of channels in the output.
298
+ """
299
+
300
+ def __init__(self, dim_in: int, dim_out: int):
301
+ super().__init__()
302
+ linear_cls = nn.Linear
303
+
304
+ self.proj = linear_cls(dim_in, dim_out * 2)
305
+
306
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
307
+ if gate.device.type != "mps":
308
+ return F.gelu(gate)
309
+ # mps: gelu is not implemented for float16
310
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
311
+
312
+ def forward(self, hidden_states, scale: float = 1.0):
313
+ args = ()
314
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
315
+ return hidden_states * self.gelu(gate)
316
+
317
+
318
+ class ApproximateGELU(nn.Module):
319
+ r"""
320
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
321
+ https://arxiv.org/abs/1606.08415.
322
+
323
+ Parameters:
324
+ dim_in (`int`): The number of channels in the input.
325
+ dim_out (`int`): The number of channels in the output.
326
+ """
327
+
328
+ def __init__(self, dim_in: int, dim_out: int):
329
+ super().__init__()
330
+ self.proj = nn.Linear(dim_in, dim_out)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ x = self.proj(x)
334
+ return x * torch.sigmoid(1.702 * x)
TripoSR/tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from dataclasses import dataclass
40
+ from typing import Optional
41
+
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from torch import nn
45
+
46
+ from ...utils import BaseModule
47
+ from .basic_transformer_block import BasicTransformerBlock
48
+
49
+
50
+ class Transformer1D(BaseModule):
51
+ @dataclass
52
+ class Config(BaseModule.Config):
53
+ num_attention_heads: int = 16
54
+ attention_head_dim: int = 88
55
+ in_channels: Optional[int] = None
56
+ out_channels: Optional[int] = None
57
+ num_layers: int = 1
58
+ dropout: float = 0.0
59
+ norm_num_groups: int = 32
60
+ cross_attention_dim: Optional[int] = None
61
+ attention_bias: bool = False
62
+ activation_fn: str = "geglu"
63
+ only_cross_attention: bool = False
64
+ double_self_attention: bool = False
65
+ upcast_attention: bool = False
66
+ norm_type: str = "layer_norm"
67
+ norm_elementwise_affine: bool = True
68
+ gradient_checkpointing: bool = False
69
+
70
+ cfg: Config
71
+
72
+ def configure(self) -> None:
73
+ self.num_attention_heads = self.cfg.num_attention_heads
74
+ self.attention_head_dim = self.cfg.attention_head_dim
75
+ inner_dim = self.num_attention_heads * self.attention_head_dim
76
+
77
+ linear_cls = nn.Linear
78
+
79
+ # 2. Define input layers
80
+ self.in_channels = self.cfg.in_channels
81
+
82
+ self.norm = torch.nn.GroupNorm(
83
+ num_groups=self.cfg.norm_num_groups,
84
+ num_channels=self.cfg.in_channels,
85
+ eps=1e-6,
86
+ affine=True,
87
+ )
88
+ self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
89
+
90
+ # 3. Define transformers blocks
91
+ self.transformer_blocks = nn.ModuleList(
92
+ [
93
+ BasicTransformerBlock(
94
+ inner_dim,
95
+ self.num_attention_heads,
96
+ self.attention_head_dim,
97
+ dropout=self.cfg.dropout,
98
+ cross_attention_dim=self.cfg.cross_attention_dim,
99
+ activation_fn=self.cfg.activation_fn,
100
+ attention_bias=self.cfg.attention_bias,
101
+ only_cross_attention=self.cfg.only_cross_attention,
102
+ double_self_attention=self.cfg.double_self_attention,
103
+ upcast_attention=self.cfg.upcast_attention,
104
+ norm_type=self.cfg.norm_type,
105
+ norm_elementwise_affine=self.cfg.norm_elementwise_affine,
106
+ )
107
+ for d in range(self.cfg.num_layers)
108
+ ]
109
+ )
110
+
111
+ # 4. Define output layers
112
+ self.out_channels = (
113
+ self.cfg.in_channels
114
+ if self.cfg.out_channels is None
115
+ else self.cfg.out_channels
116
+ )
117
+
118
+ self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
119
+
120
+ self.gradient_checkpointing = self.cfg.gradient_checkpointing
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ encoder_hidden_states: Optional[torch.Tensor] = None,
126
+ attention_mask: Optional[torch.Tensor] = None,
127
+ encoder_attention_mask: Optional[torch.Tensor] = None,
128
+ ):
129
+ """
130
+ The [`Transformer1DModel`] forward method.
131
+
132
+ Args:
133
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
134
+ Input `hidden_states`.
135
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
136
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137
+ self-attention.
138
+ attention_mask ( `torch.Tensor`, *optional*):
139
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
140
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
141
+ negative values to the attention scores corresponding to "discard" tokens.
142
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
143
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
144
+
145
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
146
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
147
+
148
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
149
+ above. This bias will be added to the cross-attention scores.
150
+
151
+ Returns:
152
+ torch.FloatTensor
153
+ """
154
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
155
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
156
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
157
+ # expects mask of shape:
158
+ # [batch, key_tokens]
159
+ # adds singleton query_tokens dimension:
160
+ # [batch, 1, key_tokens]
161
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
162
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
163
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
164
+ if attention_mask is not None and attention_mask.ndim == 2:
165
+ # assume that mask is expressed as:
166
+ # (1 = keep, 0 = discard)
167
+ # convert mask into a bias that can be added to attention scores:
168
+ # (keep = +0, discard = -10000.0)
169
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
170
+ attention_mask = attention_mask.unsqueeze(1)
171
+
172
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
173
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
174
+ encoder_attention_mask = (
175
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
176
+ ) * -10000.0
177
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
178
+
179
+ # 1. Input
180
+ batch, _, seq_len = hidden_states.shape
181
+ residual = hidden_states
182
+
183
+ hidden_states = self.norm(hidden_states)
184
+ inner_dim = hidden_states.shape[1]
185
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(
186
+ batch, seq_len, inner_dim
187
+ )
188
+ hidden_states = self.proj_in(hidden_states)
189
+
190
+ # 2. Blocks
191
+ for block in self.transformer_blocks:
192
+ if self.training and self.gradient_checkpointing:
193
+ hidden_states = torch.utils.checkpoint.checkpoint(
194
+ block,
195
+ hidden_states,
196
+ attention_mask,
197
+ encoder_hidden_states,
198
+ encoder_attention_mask,
199
+ use_reentrant=False,
200
+ )
201
+ else:
202
+ hidden_states = block(
203
+ hidden_states,
204
+ attention_mask=attention_mask,
205
+ encoder_hidden_states=encoder_hidden_states,
206
+ encoder_attention_mask=encoder_attention_mask,
207
+ )
208
+
209
+ # 3. Output
210
+ hidden_states = self.proj_out(hidden_states)
211
+ hidden_states = (
212
+ hidden_states.reshape(batch, seq_len, inner_dim)
213
+ .permute(0, 2, 1)
214
+ .contiguous()
215
+ )
216
+
217
+ output = hidden_states + residual
218
+
219
+ return output
TripoSR/tsr/system.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Union
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from .models.isosurface import MarchingCubeHelper
17
+ from .utils import (
18
+ BaseModule,
19
+ ImagePreprocessor,
20
+ find_class,
21
+ get_spherical_cameras,
22
+ scale_tensor,
23
+ )
24
+
25
+
26
+ class TSR(BaseModule):
27
+ @dataclass
28
+ class Config(BaseModule.Config):
29
+ cond_image_size: int
30
+
31
+ image_tokenizer_cls: str
32
+ image_tokenizer: dict
33
+
34
+ tokenizer_cls: str
35
+ tokenizer: dict
36
+
37
+ backbone_cls: str
38
+ backbone: dict
39
+
40
+ post_processor_cls: str
41
+ post_processor: dict
42
+
43
+ decoder_cls: str
44
+ decoder: dict
45
+
46
+ renderer_cls: str
47
+ renderer: dict
48
+
49
+ cfg: Config
50
+
51
+ @classmethod
52
+ def from_pretrained(
53
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
54
+ ):
55
+ if os.path.isdir(pretrained_model_name_or_path):
56
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
57
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
58
+ else:
59
+ config_path = hf_hub_download(
60
+ repo_id=pretrained_model_name_or_path, filename=config_name
61
+ )
62
+ weight_path = hf_hub_download(
63
+ repo_id=pretrained_model_name_or_path, filename=weight_name
64
+ )
65
+
66
+ cfg = OmegaConf.load(config_path)
67
+ OmegaConf.resolve(cfg)
68
+ model = cls(cfg)
69
+ ckpt = torch.load(weight_path, map_location="cpu")
70
+ model.load_state_dict(ckpt)
71
+ return model
72
+
73
+ def configure(self):
74
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
75
+ self.cfg.image_tokenizer
76
+ )
77
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
78
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
79
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
80
+ self.cfg.post_processor
81
+ )
82
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
83
+ self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
84
+ self.image_processor = ImagePreprocessor()
85
+ self.isosurface_helper = None
86
+
87
+ def forward(
88
+ self,
89
+ image: Union[
90
+ PIL.Image.Image,
91
+ np.ndarray,
92
+ torch.FloatTensor,
93
+ List[PIL.Image.Image],
94
+ List[np.ndarray],
95
+ List[torch.FloatTensor],
96
+ ],
97
+ device: str,
98
+ ) -> torch.FloatTensor:
99
+ rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
100
+ device
101
+ )
102
+ batch_size = rgb_cond.shape[0]
103
+
104
+ input_image_tokens: torch.Tensor = self.image_tokenizer(
105
+ rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
106
+ )
107
+
108
+ input_image_tokens = rearrange(
109
+ input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
110
+ )
111
+
112
+ tokens: torch.Tensor = self.tokenizer(batch_size)
113
+
114
+ tokens = self.backbone(
115
+ tokens,
116
+ encoder_hidden_states=input_image_tokens,
117
+ )
118
+
119
+ scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
120
+ return scene_codes
121
+
122
+ def render(
123
+ self,
124
+ scene_codes,
125
+ n_views: int,
126
+ elevation_deg: float = 0.0,
127
+ camera_distance: float = 1.9,
128
+ fovy_deg: float = 40.0,
129
+ height: int = 256,
130
+ width: int = 256,
131
+ return_type: str = "pil",
132
+ ):
133
+ rays_o, rays_d = get_spherical_cameras(
134
+ n_views, elevation_deg, camera_distance, fovy_deg, height, width
135
+ )
136
+ rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
137
+
138
+ def process_output(image: torch.FloatTensor):
139
+ if return_type == "pt":
140
+ return image
141
+ elif return_type == "np":
142
+ return image.detach().cpu().numpy()
143
+ elif return_type == "pil":
144
+ return Image.fromarray(
145
+ (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
146
+ )
147
+ else:
148
+ raise NotImplementedError
149
+
150
+ images = []
151
+ for scene_code in scene_codes:
152
+ images_ = []
153
+ for i in range(n_views):
154
+ with torch.no_grad():
155
+ image = self.renderer(
156
+ self.decoder, scene_code, rays_o[i], rays_d[i]
157
+ )
158
+ images_.append(process_output(image))
159
+ images.append(images_)
160
+
161
+ return images
162
+
163
+ def set_marching_cubes_resolution(self, resolution: int):
164
+ if (
165
+ self.isosurface_helper is not None
166
+ and self.isosurface_helper.resolution == resolution
167
+ ):
168
+ return
169
+ self.isosurface_helper = MarchingCubeHelper(resolution)
170
+
171
+ def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
172
+ self.set_marching_cubes_resolution(resolution)
173
+ meshes = []
174
+ for scene_code in scene_codes:
175
+ with torch.no_grad():
176
+ density = self.renderer.query_triplane(
177
+ self.decoder,
178
+ scale_tensor(
179
+ self.isosurface_helper.grid_vertices.to(scene_codes.device),
180
+ self.isosurface_helper.points_range,
181
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
182
+ ),
183
+ scene_code,
184
+ )["density_act"]
185
+ v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
186
+ v_pos = scale_tensor(
187
+ v_pos,
188
+ self.isosurface_helper.points_range,
189
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius),
190
+ )
191
+ with torch.no_grad():
192
+ color = self.renderer.query_triplane(
193
+ self.decoder,
194
+ v_pos,
195
+ scene_code,
196
+ )["color"]
197
+ mesh = trimesh.Trimesh(
198
+ vertices=v_pos.cpu().numpy(),
199
+ faces=t_pos_idx.cpu().numpy(),
200
+ vertex_colors=color.cpu().numpy(),
201
+ )
202
+ meshes.append(mesh)
203
+ return meshes