John6666 committed
Commit b572032 · verified · 1 Parent(s): d534d13

Upload 77 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining files.
Files changed (50)
  1. .gitattributes +21 -0
  2. .pre-commit-config.yaml +24 -0
  3. LICENSE.md +51 -0
  4. README.md +19 -12
  5. __init__.py +363 -0
  6. demo_files/comp.gif +3 -0
  7. demo_files/examples/bird.png +3 -0
  8. demo_files/examples/castle.png +3 -0
  9. demo_files/examples/chest.png +3 -0
  10. demo_files/examples/doll.png +3 -0
  11. demo_files/examples/excavator.png +3 -0
  12. demo_files/examples/fish.png +3 -0
  13. demo_files/examples/horse-statue.png +3 -0
  14. demo_files/examples/penguin.png +3 -0
  15. demo_files/examples/pot.png +3 -0
  16. demo_files/examples/raccoon_wizard.png +3 -0
  17. demo_files/examples/stylized-rocks.png +3 -0
  18. demo_files/hdri/abandoned_tiled_room_1k.hdr +3 -0
  19. demo_files/hdri/metro_noord_1k.hdr +3 -0
  20. demo_files/hdri/neon_photostudio_1k.hdr +3 -0
  21. demo_files/hdri/peppermint_powerplant_1k.hdr +3 -0
  22. demo_files/hdri/rainforest_trail_1k.hdr +3 -0
  23. demo_files/hdri/studio_small_08_1k.hdr +3 -0
  24. demo_files/hdri/urban_alley_01_1k.hdr +3 -0
  25. demo_files/turntable.gif +3 -0
  26. demo_files/workflows/spar3d_example.json +263 -0
  27. deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl +3 -0
  28. gradio_app.py +227 -0
  29. load/tets/160_tets.npz +3 -0
  30. requirements.txt +23 -0
  31. ruff.toml +3 -0
  32. run.py +190 -0
  33. spar3d/models/camera.py +32 -0
  34. spar3d/models/diffusion/gaussian_diffusion.py +524 -0
  35. spar3d/models/diffusion/sampler.py +134 -0
  36. spar3d/models/global_estimator/reni_estimator.py +116 -0
  37. spar3d/models/illumination/reni/components/film_siren.py +148 -0
  38. spar3d/models/illumination/reni/components/siren.py +118 -0
  39. spar3d/models/illumination/reni/components/transformer_decoder.py +189 -0
  40. spar3d/models/illumination/reni/components/vn_layers.py +548 -0
  41. spar3d/models/illumination/reni/env_map.py +93 -0
  42. spar3d/models/illumination/reni/field.py +736 -0
  43. spar3d/models/image_estimator/clip_based_estimator.py +184 -0
  44. spar3d/models/isosurface.py +229 -0
  45. spar3d/models/mesh.py +317 -0
  46. spar3d/models/network.py +226 -0
  47. spar3d/models/tokenizers/dinov2.py +1196 -0
  48. spar3d/models/tokenizers/image.py +99 -0
  49. spar3d/models/tokenizers/point.py +51 -0
  50. spar3d/models/tokenizers/triplane.py +49 -0
.gitattributes CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo_files/comp.gif filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/bird.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/castle.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/chest.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/doll.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/excavator.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/fish.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/horse-statue.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/penguin.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/pot.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/raccoon_wizard.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/examples/stylized-rocks.png filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/abandoned_tiled_room_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/metro_noord_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/neon_photostudio_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/peppermint_powerplant_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/rainforest_trail_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/studio_small_08_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/hdri/urban_alley_01_1k.hdr filter=lfs diff=lfs merge=lfs -text
+ demo_files/turntable.gif filter=lfs diff=lfs merge=lfs -text
+ deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,24 @@
+ default_language_version:
+   python: python3
+
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.4.0
+     hooks:
+       - id: trailing-whitespace
+       - id: check-ast
+       - id: check-merge-conflict
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: trailing-whitespace
+         args: [--markdown-linebreak-ext=md]
+
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.3.5
+     hooks:
+       # Run the linter.
+       - id: ruff
+         args: [ --fix ]
+       # Run the formatter.
+       - id: ruff-format
LICENSE.md ADDED
@@ -0,0 +1,51 @@
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+ Last Updated: July 5, 2024
3
+
4
+
5
+ I. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+
10
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
11
+
12
+
13
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
14
+
15
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
18
+
19
+ III. COMMERCIAL USE LICENSE
20
+
21
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
22
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
23
+
24
+ IV. GENERAL TERMS
25
+
26
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
27
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
28
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
29
+ c. Intellectual Property.
30
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
31
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
32
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
33
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
34
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
35
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
38
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
39
+
40
+ V. DEFINITIONS
41
+
42
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
43
+ "Agreement" means this Stability AI Community License Agreement.
44
+ "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
45
+ "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
46
+ "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
47
+ "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
48
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
49
+ "Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
50
+ "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
51
+ "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,19 @@
- ---
- title: Image To 3d Test
- emoji: 📊
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 5.16.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Image to 3D
+ emoji:
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.15.0
+ app_file: gradio_app.py
+ pinned: false
+ ---
+
+ This space is a variant of [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d), but using text as input.
+
+ * **Repository**: [https://github.com/Stability-AI/stable-point-aware-3d](https://github.com/Stability-AI/stable-point-aware-3d)
+ * **Model**: [https://huggingface.co/stabilityai/stable-point-aware-3d](https://huggingface.co/stabilityai/stable-point-aware-3d)
+ * **Tech report**: [https://arxiv.org/pdf/2501.04689](https://arxiv.org/pdf/2501.04689)
+ * **Video summary**: [https://youtu.be/mlO3Nc3Nsng](https://youtu.be/mlO3Nc3Nsng)
+ * **Project page**: [https://spar3d.github.io](https://spar3d.github.io)
+ * **arXiv page**: [https://arxiv.org/abs/2501.04689](https://arxiv.org/abs/2501.04689)
__init__.py ADDED
@@ -0,0 +1,363 @@
1
+ import base64
2
+ import logging
3
+ import os
4
+ import random
5
+ import sys
6
+
7
+ import comfy.model_management
8
+ import folder_paths
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ from PIL import Image
13
+ from trimesh.exchange import gltf
14
+
15
+ sys.path.append(os.path.dirname(__file__))
16
+ from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
17
+ from spar3d.system import SPAR3D
18
+ from spar3d.utils import foreground_crop
19
+
20
+ SPAR3D_CATEGORY = "SPAR3D"
21
+ SPAR3D_MODEL_NAME = "stabilityai/spar3d"
22
+
23
+
24
+ class SPAR3DLoader:
25
+ CATEGORY = SPAR3D_CATEGORY
26
+ FUNCTION = "load"
27
+ RETURN_NAMES = ("spar3d_model",)
28
+ RETURN_TYPES = ("SPAR3D_MODEL",)
29
+
30
+ @classmethod
31
+ def INPUT_TYPES(cls):
32
+ return {
33
+ "required": {
34
+ "low_vram_mode": ("BOOLEAN", {"default": False}),
35
+ }
36
+ }
37
+
38
+ def load(self, low_vram_mode=False):
39
+ device = comfy.model_management.get_torch_device()
40
+ model = SPAR3D.from_pretrained(
41
+ SPAR3D_MODEL_NAME,
42
+ config_name="config.yaml",
43
+ weight_name="model.safetensors",
44
+ low_vram_mode=low_vram_mode,
45
+ )
46
+ model.to(device)
47
+ model.eval()
48
+
49
+ return (model,)
50
+
51
+
52
+ class SPAR3DPreview:
53
+ CATEGORY = SPAR3D_CATEGORY
54
+ FUNCTION = "preview"
55
+ OUTPUT_NODE = True
56
+ RETURN_TYPES = ()
57
+
58
+ @classmethod
59
+ def INPUT_TYPES(s):
60
+ return {"required": {"mesh": ("MESH",)}}
61
+
62
+ def preview(self, mesh):
63
+ glbs = []
64
+ for m in mesh:
65
+ scene = trimesh.Scene(m)
66
+ glb_data = gltf.export_glb(scene, include_normals=True)
67
+ glb_base64 = base64.b64encode(glb_data).decode("utf-8")
68
+ glbs.append(glb_base64)
69
+ return {"ui": {"glbs": glbs}}
70
+
71
+
72
+ class SPAR3DSampler:
73
+ CATEGORY = SPAR3D_CATEGORY
74
+ FUNCTION = "predict"
75
+ RETURN_NAMES = ("mesh", "pointcloud")
76
+ RETURN_TYPES = ("MESH", "POINTCLOUD")
77
+
78
+ @classmethod
79
+ def INPUT_TYPES(s):
80
+ remesh_choices = ["none"]
81
+ if TRIANGLE_REMESH_AVAILABLE:
82
+ remesh_choices.append("triangle")
83
+ if QUAD_REMESH_AVAILABLE:
84
+ remesh_choices.append("quad")
85
+
86
+ opt_dict = {
87
+ "mask": ("MASK",),
88
+ "pointcloud": ("POINTCLOUD",),
89
+ "target_type": (["none", "vertex", "face"],),
90
+ "target_count": (
91
+ "INT",
92
+ {"default": 1000, "min": 3, "max": 20000, "step": 1},
93
+ ),
94
+ "guidance_scale": (
95
+ "FLOAT",
96
+ {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
97
+ ),
98
+ "seed": (
99
+ "INT",
100
+ {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
101
+ ),
102
+ }
103
+ if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
104
+ opt_dict["remesh"] = (remesh_choices,)
105
+
106
+ return {
107
+ "required": {
108
+ "model": ("SPAR3D_MODEL",),
109
+ "image": ("IMAGE",),
110
+ "foreground_ratio": (
111
+ "FLOAT",
112
+ {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
113
+ ),
114
+ "texture_resolution": (
115
+ "INT",
116
+ {"default": 1024, "min": 512, "max": 2048, "step": 256},
117
+ ),
118
+ },
119
+ "optional": opt_dict,
120
+ }
121
+
122
+ def predict(
123
+ s,
124
+ model,
125
+ image,
126
+ mask,
127
+ foreground_ratio,
128
+ texture_resolution,
129
+ pointcloud=None,
130
+ remesh="none",
131
+ target_type="none",
132
+ target_count=1000,
133
+ guidance_scale=3.0,
134
+ seed=42,
135
+ ):
136
+ if image.shape[0] != 1:
137
+ raise ValueError("Only one image can be processed at a time")
138
+
139
+ vertex_count = (
140
+ -1
141
+ if target_type == "none"
142
+ else (target_count // 2 if target_type == "face" else target_count)
143
+ )
144
+
145
+ pil_image = Image.fromarray(
146
+ torch.clamp(torch.round(255.0 * image[0]), 0, 255)
147
+ .type(torch.uint8)
148
+ .cpu()
149
+ .numpy()
150
+ )
151
+
152
+ if mask is not None:
153
+ print("Using Mask")
154
+ mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
155
+ np.uint8
156
+ )
157
+ mask_pil = Image.fromarray(mask_np, mode="L")
158
+ pil_image.putalpha(mask_pil)
159
+ else:
160
+ if image.shape[3] != 4:
161
+ print("No mask or alpha channel detected, Converting to RGBA")
162
+ pil_image = pil_image.convert("RGBA")
163
+
164
+ pil_image = foreground_crop(pil_image, foreground_ratio)
165
+
166
+ model.cfg.guidance_scale = guidance_scale
167
+ random.seed(seed)
168
+ torch.manual_seed(seed)
169
+ np.random.seed(seed)
170
+
171
+ print(remesh)
172
+ with torch.no_grad():
173
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
174
+ if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
175
+ raise ImportError(
176
+ "Triangle remeshing requires gpytoolbox to be installed"
177
+ )
178
+ if not QUAD_REMESH_AVAILABLE and remesh == "quad":
179
+ raise ImportError("Quad remeshing requires pynim to be installed")
180
+ mesh, glob_dict = model.run_image(
181
+ pil_image,
182
+ bake_resolution=texture_resolution,
183
+ pointcloud=pointcloud,
184
+ remesh=remesh,
185
+ vertex_count=vertex_count,
186
+ )
187
+
188
+ if mesh.vertices.shape[0] == 0:
189
+ raise ValueError("No subject detected in the image")
190
+
191
+ return (
192
+ [mesh],
193
+ glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
194
+ )
195
+
196
+
197
+ class SPAR3DSave:
198
+ CATEGORY = SPAR3D_CATEGORY
199
+ FUNCTION = "save"
200
+ OUTPUT_NODE = True
201
+ RETURN_TYPES = ()
202
+
203
+ @classmethod
204
+ def INPUT_TYPES(s):
205
+ return {
206
+ "required": {
207
+ "mesh": ("MESH",),
208
+ "filename_prefix": ("STRING", {"default": "SPAR3D"}),
209
+ }
210
+ }
211
+
212
+ def __init__(self):
213
+ self.type = "output"
214
+
215
+ def save(self, mesh, filename_prefix):
216
+ output_dir = folder_paths.get_output_directory()
217
+ glbs = []
218
+ for idx, m in enumerate(mesh):
219
+ scene = trimesh.Scene(m)
220
+ glb_data = gltf.export_glb(scene, include_normals=True)
221
+ logging.info(f"Generated GLB model with {len(glb_data)} bytes")
222
+
223
+ full_output_folder, filename, counter, subfolder, filename_prefix = (
224
+ folder_paths.get_save_image_path(filename_prefix, output_dir)
225
+ )
226
+ filename = filename.replace("%batch_num%", str(idx))
227
+ out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
228
+ with open(out_path, "wb") as f:
229
+ f.write(glb_data)
230
+ glbs.append(base64.b64encode(glb_data).decode("utf-8"))
231
+ return {"ui": {"glbs": glbs}}
232
+
233
+
234
+ class SPAR3DPointCloudLoader:
235
+ CATEGORY = SPAR3D_CATEGORY
236
+ FUNCTION = "load_pointcloud"
237
+ RETURN_TYPES = ("POINTCLOUD",)
238
+ RETURN_NAMES = ("pointcloud",)
239
+
240
+ @classmethod
241
+ def INPUT_TYPES(cls):
242
+ return {
243
+ "required": {
244
+ "file": ("STRING", {"default": None}),
245
+ }
246
+ }
247
+
248
+ def load_pointcloud(self, file):
249
+ if file is None or file == "":
250
+ return (None,)
251
+ # Load the mesh using trimesh
252
+ mesh = trimesh.load(file)
253
+
254
+ # Extract vertices and colors
255
+ vertices = mesh.vertices
256
+
257
+ # Get vertex colors, defaulting to white if none exist
258
+ if mesh.visual.vertex_colors is not None:
259
+ colors = (
260
+ mesh.visual.vertex_colors[:, :3] / 255.0
261
+ ) # Convert 0-255 to 0-1 range
262
+ else:
263
+ colors = np.ones((len(vertices), 3))
264
+
265
+ # Interleave XYZ and RGB values
266
+ point_cloud = []
267
+ for vertex, color in zip(vertices, colors):
268
+ point_cloud.extend(
269
+ [
270
+ float(vertex[0]),
271
+ float(vertex[1]),
272
+ float(vertex[2]),
273
+ float(color[0]),
274
+ float(color[1]),
275
+ float(color[2]),
276
+ ]
277
+ )
278
+
279
+ return (point_cloud,)
280
+
281
+
282
+ class SPAR3DPointCloudSaver:
283
+ CATEGORY = SPAR3D_CATEGORY
284
+ FUNCTION = "save_pointcloud"
285
+ OUTPUT_NODE = True
286
+ RETURN_TYPES = ()
287
+
288
+ @classmethod
289
+ def INPUT_TYPES(s):
290
+ return {
291
+ "required": {
292
+ "pointcloud": ("POINTCLOUD",),
293
+ "filename_prefix": ("STRING", {"default": "SPAR3D"}),
294
+ }
295
+ }
296
+
297
+ def save_pointcloud(self, pointcloud, filename_prefix):
298
+ if pointcloud is None:
299
+ return {"ui": {"text": "No point cloud data to save"}}
300
+
301
+ # Reshape the flat list into points with XYZ and RGB
302
+ points = np.array(pointcloud).reshape(-1, 6)
303
+
304
+ # Create vertex array for PLY
305
+ vertex_array = np.zeros(
306
+ len(points),
307
+ dtype=[
308
+ ("x", "f4"),
309
+ ("y", "f4"),
310
+ ("z", "f4"),
311
+ ("red", "u1"),
312
+ ("green", "u1"),
313
+ ("blue", "u1"),
314
+ ],
315
+ )
316
+
317
+ # Fill vertex array
318
+ vertex_array["x"] = points[:, 0]
319
+ vertex_array["y"] = points[:, 1]
320
+ vertex_array["z"] = points[:, 2]
321
+ # Convert RGB from 0-1 to 0-255 range
322
+ vertex_array["red"] = (points[:, 3] * 255).astype(np.uint8)
323
+ vertex_array["green"] = (points[:, 4] * 255).astype(np.uint8)
324
+ vertex_array["blue"] = (points[:, 5] * 255).astype(np.uint8)
325
+
326
+ # Create PLY object
327
+ ply_data = trimesh.PointCloud(
328
+ vertices=points[:, :3], colors=points[:, 3:] * 255
329
+ )
330
+
331
+ # Save to file
332
+ output_dir = folder_paths.get_output_directory()
333
+ full_output_folder, filename, counter, subfolder, filename_prefix = (
334
+ folder_paths.get_save_image_path(filename_prefix, output_dir)
335
+ )
336
+ out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.ply")
337
+
338
+ ply_data.export(out_path)
339
+
340
+ return {"ui": {"text": f"Saved point cloud to {out_path}"}}
341
+
342
+
343
+ NODE_DISPLAY_NAME_MAPPINGS = {
344
+ "SPAR3DLoader": "SPAR3D Loader",
345
+ "SPAR3DPreview": "SPAR3D Preview",
346
+ "SPAR3DSampler": "SPAR3D Sampler",
347
+ "SPAR3DSave": "SPAR3D Save",
348
+ "SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
349
+ "SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
350
+ }
351
+
352
+ NODE_CLASS_MAPPINGS = {
353
+ "SPAR3DLoader": SPAR3DLoader,
354
+ "SPAR3DPreview": SPAR3DPreview,
355
+ "SPAR3DSampler": SPAR3DSampler,
356
+ "SPAR3DSave": SPAR3DSave,
357
+ "SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
358
+ "SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
359
+ }
360
+
361
+ WEB_DIRECTORY = "./comfyui"
362
+
363
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
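A note on the custom `POINTCLOUD` type passed between `SPAR3DSampler`, `SPAR3DPointCloudLoader`, and `SPAR3DPointCloudSaver`: as the code above shows, it is a flat Python list of interleaved XYZ positions and 0-1 RGB colours. A minimal sketch of that round trip, independent of ComfyUI (the `demo_points` values are invented for illustration):

```python
import numpy as np

# Two hypothetical points in the flat [x, y, z, r, g, b, ...] layout that
# SPAR3DPointCloudLoader emits (colours already scaled to the 0-1 range).
demo_points = [
    0.0, 0.1, 0.2, 1.0, 0.0, 0.0,  # red point
    0.3, 0.4, 0.5, 0.0, 1.0, 0.0,  # green point
]

# SPAR3DPointCloudSaver reshapes the same flat list into an (N, 6) array
# before building the trimesh.PointCloud it exports as PLY.
points = np.array(demo_points).reshape(-1, 6)
xyz = points[:, :3]                            # positions
rgb = (points[:, 3:] * 255).astype(np.uint8)   # colours scaled back to 0-255

print(xyz.shape, rgb.shape)  # (2, 3) (2, 3)

# Flattening again reproduces the loader's output format exactly.
assert points.reshape(-1).tolist() == demo_points
```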
demo_files/comp.gif ADDED

Git LFS Details

  • SHA256: 6190ca0c3bd164d37152ba985abea53e642fe5e434ca0a932a3b2c4dce698f6b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
demo_files/examples/bird.png ADDED

Git LFS Details

  • SHA256: 83373e2b75ebaad76b6fe093973ea1dc96c92527c8376062cf520ed9215f3e82
  • Pointer size: 131 Bytes
  • Size of remote file: 560 kB
demo_files/examples/castle.png ADDED

Git LFS Details

  • SHA256: ededd2fe4c122cadfb4f2a485dfd82f83dc1ec6446c7a799d5fc1e1f103ae4b1
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
demo_files/examples/chest.png ADDED

Git LFS Details

  • SHA256: f1eec59b35c63aa50942edff37f0cbdea7d8360cd036a4b7eb9460afdfcbabd9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
demo_files/examples/doll.png ADDED

Git LFS Details

  • SHA256: fc5af86defd0a4fd7285e17a0eb8a108b9f33774408c194a594964d8d6e66c26
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
demo_files/examples/excavator.png ADDED

Git LFS Details

  • SHA256: 6f68c6ba4a9dc884d3786d98c4f0d835682bad02e85716d3a60fd2feedcb03d8
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
demo_files/examples/fish.png ADDED

Git LFS Details

  • SHA256: cd623d8b654de81e022e3741576a0d08dd26d6ba92ee1989605347ef26c399bb
  • Pointer size: 131 Bytes
  • Size of remote file: 838 kB
demo_files/examples/horse-statue.png ADDED

Git LFS Details

  • SHA256: c9c00f726efe9490b02d4c232293b629e0146dad6ce1ff8e22da8102345c5fe9
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
demo_files/examples/penguin.png ADDED

Git LFS Details

  • SHA256: 7a1667d874e9379a8d36e676fb80327bd7b5d3673cb77d7d4cf27bb53408fb98
  • Pointer size: 131 Bytes
  • Size of remote file: 659 kB
demo_files/examples/pot.png ADDED

Git LFS Details

  • SHA256: 32d5d8c110646a46ca24a4d6994cb848ef79cc7ad78dcc7419be0e6f02476a86
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
demo_files/examples/raccoon_wizard.png ADDED

Git LFS Details

  • SHA256: 32cc3850d9f48548882c7b148e508e8ab149bc4f363611e9739adcbd38e8b16d
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
demo_files/examples/stylized-rocks.png ADDED

Git LFS Details

  • SHA256: 386c3be3a6f24ee52e13f130c1ebc02a1bc46eb2c0ebe90d79ce6f38751f0fc6
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
demo_files/hdri/abandoned_tiled_room_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fea108d2dba9872fcf8d40883cfce96a15b3ad26caf427cf53d5bff27fef1c35
+ size 477743
demo_files/hdri/metro_noord_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b85f4e187bbe740f75d120e3798d7f7b6ac2778de1df0e9db5799124d1ad5f12
+ size 466724
demo_files/hdri/neon_photostudio_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a888974682d9ee6239539bfec68cc98e53318107705f636366813f7e85a4cec
+ size 438060
demo_files/hdri/peppermint_powerplant_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:921449733483c0ce5236503d8b6d80d6d676537bbe480c5929218f66c801fc9b
+ size 472851
demo_files/hdri/rainforest_trail_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e737fb83b627762fa2b04f527c82447d48020b2683fa57fc8b1aa29c6e75750d
+ size 512033
demo_files/hdri/studio_small_08_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a64f8ec219f9bb9f988e5311ac2eb4234ee4babe9ad6d3b34740851ab74d1e4b
+ size 411810
demo_files/hdri/urban_alley_01_1k.hdr ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:798e5664b129d6abc7c77f3b8aaaa19702d973296b785574248b5e57ad2e04b0
+ size 457913
demo_files/turntable.gif ADDED

Git LFS Details

  • SHA256: ffb5cfca3da84a569de41535781dfc6103834b99207136eb6cbf72d097799c6c
  • Pointer size: 132 Bytes
  • Size of remote file: 7.58 MB
demo_files/workflows/spar3d_example.json ADDED
@@ -0,0 +1,263 @@
1
+ {
2
+ "last_node_id": 17,
3
+ "last_link_id": 18,
4
+ "nodes": [
5
+ {
6
+ "id": 10,
7
+ "type": "SPAR3DLoader",
8
+ "pos": [
9
+ 52.92446517944336,
10
+ 394.328369140625
11
+ ],
12
+ "size": [
13
+ 210,
14
+ 26
15
+ ],
16
+ "flags": {},
17
+ "order": 0,
18
+ "mode": 0,
19
+ "inputs": [],
20
+ "outputs": [
21
+ {
22
+ "name": "spar3d_model",
23
+ "type": "SPAR3D_MODEL",
24
+ "links": [
25
+ 10
26
+ ],
27
+ "slot_index": 0
28
+ }
29
+ ],
30
+ "properties": {
31
+ "Node name for S&R": "SPAR3DLoader"
32
+ },
33
+ "widgets_values": []
34
+ },
35
+ {
36
+ "id": 13,
37
+ "type": "LoadImage",
38
+ "pos": [
39
+ -43.437347412109375,
40
+ 482.89678955078125
41
+ ],
42
+ "size": [
43
+ 315,
44
+ 314
45
+ ],
46
+ "flags": {},
47
+ "order": 1,
48
+ "mode": 0,
49
+ "inputs": [],
50
+ "outputs": [
51
+ {
52
+ "name": "IMAGE",
53
+ "type": "IMAGE",
54
+ "links": [
55
+ 11
56
+ ],
57
+ "slot_index": 0
58
+ },
59
+ {
60
+ "name": "MASK",
61
+ "type": "MASK",
62
+ "links": [
63
+ 16
64
+ ],
65
+ "slot_index": 1
66
+ }
67
+ ],
68
+ "properties": {
69
+ "Node name for S&R": "LoadImage"
70
+ },
71
+ "widgets_values": [
72
+ "cat1.png",
73
+ "image"
74
+ ]
75
+ },
76
+ {
77
+ "id": 16,
78
+ "type": "InvertMask",
79
+ "pos": [
80
+ 377.1180419921875,
81
+ 605.384765625
82
+ ],
83
+ "size": [
84
+ 210,
85
+ 26
86
+ ],
87
+ "flags": {},
88
+ "order": 2,
89
+ "mode": 0,
90
+ "inputs": [
91
+ {
92
+ "name": "mask",
93
+ "type": "MASK",
94
+ "link": 16
95
+ }
96
+ ],
97
+ "outputs": [
98
+ {
99
+ "name": "MASK",
100
+ "type": "MASK",
101
+ "links": [
102
+ 17
103
+ ],
104
+ "slot_index": 0
105
+ }
106
+ ],
107
+ "properties": {
108
+ "Node name for S&R": "InvertMask"
109
+ },
110
+ "widgets_values": []
111
+ },
112
+ {
113
+ "id": 17,
114
+ "type": "SPAR3DSave",
115
+ "pos": [
116
+ 1133.669921875,
117
+ 439.6551513671875
118
+ ],
119
+ "size": [
120
+ 315,
121
+ 58
122
+ ],
123
+ "flags": {},
124
+ "order": 4,
125
+ "mode": 0,
126
+ "inputs": [
127
+ {
128
+ "name": "mesh",
129
+ "type": "MESH",
130
+ "link": 18
131
+ }
132
+ ],
133
+ "outputs": [],
134
+ "properties": {
135
+ "Node name for S&R": "SPAR3DSave"
136
+ },
137
+ "widgets_values": [
138
+ "SPAR3D"
139
+ ]
140
+ },
141
+ {
142
+ "id": 11,
143
+ "type": "SPAR3DSampler",
144
+ "pos": [
145
+ 673.0637817382812,
146
+ 441.2229309082031
147
+ ],
148
+ "size": [
149
+ 315,
150
+ 286
151
+ ],
152
+ "flags": {},
153
+ "order": 3,
154
+ "mode": 0,
155
+ "inputs": [
156
+ {
157
+ "name": "model",
158
+ "type": "SPAR3D_MODEL",
159
+ "link": 10
160
+ },
161
+ {
162
+ "name": "image",
163
+ "type": "IMAGE",
164
+ "link": 11
165
+ },
166
+ {
167
+ "name": "mask",
168
+ "type": "MASK",
169
+ "link": 17,
170
+ "shape": 7
171
+ },
172
+ {
173
+ "name": "pointcloud",
174
+ "type": "POINTCLOUD",
175
+ "link": null,
176
+ "shape": 7
177
+ }
178
+ ],
179
+ "outputs": [
180
+ {
181
+ "name": "mesh",
182
+ "type": "MESH",
183
+ "links": [
184
+ 18
185
+ ],
186
+ "slot_index": 0
187
+ },
188
+ {
189
+ "name": "pointcloud",
190
+ "type": "POINTCLOUD",
191
+ "links": null
192
+ }
193
+ ],
194
+ "properties": {
195
+ "Node name for S&R": "SPAR3DSampler"
196
+ },
197
+ "widgets_values": [
198
+ 1.3,
199
+ 1024,
200
+ "none",
201
+ 1000,
202
+ 3,
203
+ 3727502160,
204
+ "randomize",
205
+ "none"
206
+ ]
207
+ }
208
+ ],
209
+ "links": [
210
+ [
211
+ 10,
212
+ 10,
213
+ 0,
214
+ 11,
215
+ 0,
216
+ "SPAR3D_MODEL"
217
+ ],
218
+ [
219
+ 11,
220
+ 13,
221
+ 0,
222
+ 11,
223
+ 1,
224
+ "IMAGE"
225
+ ],
226
+ [
227
+ 16,
228
+ 13,
229
+ 1,
230
+ 16,
231
+ 0,
232
+ "MASK"
233
+ ],
234
+ [
235
+ 17,
236
+ 16,
237
+ 0,
238
+ 11,
239
+ 2,
240
+ "MASK"
241
+ ],
242
+ [
243
+ 18,
244
+ 11,
245
+ 0,
246
+ 17,
247
+ 0,
248
+ "MESH"
249
+ ]
250
+ ],
251
+ "groups": [],
252
+ "config": {},
253
+ "extra": {
254
+ "ds": {
255
+ "scale": 0.953502721998243,
256
+ "offset": [
257
+ 266.21995970220667,
258
+ 116.75398112171928
259
+ ]
260
+ }
261
+ },
262
+ "version": 0.4
263
+ }
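For readers unfamiliar with ComfyUI workflow JSON: each row of the `links` array above appears to follow the usual `[link_id, source_node, source_slot, target_node, target_slot, type]` layout, so link 10 routes the loader's `SPAR3D_MODEL` output (node 10, slot 0) into input slot 0 of the sampler (node 11). A small sketch that walks those rows (the link values are copied from the workflow; the field names are my reading of the format, not part of the file):

```python
# Link rows copied from demo_files/workflows/spar3d_example.json.
links = [
    [10, 10, 0, 11, 0, "SPAR3D_MODEL"],
    [11, 13, 0, 11, 1, "IMAGE"],
    [16, 13, 1, 16, 0, "MASK"],
    [17, 16, 0, 11, 2, "MASK"],
    [18, 11, 0, 17, 0, "MESH"],
]

# Assumed field order: [link_id, src_node, src_slot, dst_node, dst_slot, type].
for link_id, src, src_slot, dst, dst_slot, kind in links:
    print(f"link {link_id}: node {src}[{src_slot}] -> node {dst}[{dst_slot}] ({kind})")
```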
deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0150bf4575b20f53ebd7d495afa2d12d922ad2d430a2a510f12c5febce50d08
+ size 1057831
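The bundled `pynim` wheel backs the optional quad-remeshing path: `spar3d.models.mesh` exposes `QUAD_REMESH_AVAILABLE` and `TRIANGLE_REMESH_AVAILABLE`, and the error messages in `__init__.py` tie quad remeshing to `pynim` and triangle remeshing to `gpytoolbox` (pinned in requirements.txt). A hedged sketch of the availability pattern this implies; it is illustrative only, not the actual `spar3d` source:

```python
# Feature detection assumed to sit behind the two availability flags.
try:
    import pynim  # noqa: F401  (quad remesher, shipped as deps/pynim-0.0.3-*.whl)
    QUAD_REMESH_AVAILABLE = True
except ImportError:
    QUAD_REMESH_AVAILABLE = False

try:
    import gpytoolbox  # noqa: F401  (triangle remesher from requirements.txt)
    TRIANGLE_REMESH_AVAILABLE = True
except ImportError:
    TRIANGLE_REMESH_AVAILABLE = False

# Callers then build the user-facing choices the same way run.py and the
# ComfyUI sampler node do.
remesh_choices = ["none"]
if TRIANGLE_REMESH_AVAILABLE:
    remesh_choices.append("triangle")
if QUAD_REMESH_AVAILABLE:
    remesh_choices.append("quad")
print(remesh_choices)
```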
gradio_app.py ADDED
@@ -0,0 +1,227 @@
1
+ import spaces
2
+ import os
3
+ import tempfile
4
+ from typing import Any
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import trimesh
10
+ from transparent_background import Remover
11
+
12
+ # Import and setup SPAR3D
13
+ os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
14
+ import spar3d.utils as spar3d_utils
15
+ from spar3d.system import SPAR3D
16
+
17
+ # Constants
18
+ COND_WIDTH = 512
19
+ COND_HEIGHT = 512
20
+ COND_DISTANCE = 2.2
21
+ COND_FOVY = 0.591627
22
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
23
+
24
+ # Initialize models
25
+ device = spar3d_utils.get_device()
26
+ bg_remover = Remover()
27
+ spar3d_model = SPAR3D.from_pretrained(
28
+ "stabilityai/stable-point-aware-3d",
29
+ config_name="config.yaml",
30
+ weight_name="model.safetensors"
31
+ ).eval().to(device)
32
+
33
+ # Initialize camera parameters
34
+ c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
35
+ intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
36
+ COND_FOVY, COND_HEIGHT, COND_WIDTH
37
+ )
38
+
39
+ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
40
+ """Create an RGBA image from RGB image and optional mask."""
41
+ rgba_image = rgb_image.convert('RGBA')
42
+ if mask is not None:
43
+ # Ensure mask is 2D before converting to alpha
44
+ if len(mask.shape) > 2:
45
+ mask = mask.squeeze()
46
+ alpha = Image.fromarray((mask * 255).astype(np.uint8))
47
+ rgba_image.putalpha(alpha)
48
+ return rgba_image
49
+
50
+ def create_batch(input_image: Image.Image) -> dict[str, Any]:
51
+ """Prepare image batch for model input."""
52
+ # Resize and convert input image to numpy array
53
+ resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
54
+ img_array = np.array(resized_image).astype(np.float32) / 255.0
55
+
56
+ # Extract RGB and alpha channels
57
+ if img_array.shape[-1] == 4: # RGBA
58
+ rgb = img_array[..., :3]
59
+ mask = img_array[..., 3:4]
60
+ else: # RGB
61
+ rgb = img_array
62
+ mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
63
+
64
+ # Convert to tensors while keeping channel-last format
65
+ rgb = torch.from_numpy(rgb).float() # [H, W, 3]
66
+ mask = torch.from_numpy(mask).float() # [H, W, 1]
67
+
68
+ # Create background blend (match channel-last format)
69
+ bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
70
+
71
+ # Blend RGB with background using mask (all in channel-last format)
72
+ rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
73
+
74
+ # Move channels to correct dimension and add batch dimension
75
+ # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
76
+ rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
77
+ mask = mask.unsqueeze(0) # [1, H, W, 1]
78
+
79
+ # Create the batch dictionary
80
+ batch = {
81
+ "rgb_cond": rgb_cond, # [1, H, W, 3]
82
+ "mask_cond": mask, # [1, H, W, 1]
83
+ "c2w_cond": c2w_cond.unsqueeze(0), # [1, 4, 4]
84
+ "intrinsic_cond": intrinsic.unsqueeze(0), # [1, 3, 3]
85
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
86
+ }
87
+
88
+ for k, v in batch.items():
89
+ print(f"[debug] {k} final shape:", v.shape)
90
+
91
+ return batch
92
+
93
+ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
94
+ """Process batch through model and generate point cloud."""
95
+
96
+ batch_size = batch["rgb_cond"].shape[0]
97
+ assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
98
+
99
+ # Generate point cloud tokens
100
+ try:
101
+ cond_tokens = system.forward_pdiff_cond(batch)
102
+ except Exception as e:
103
+ print("\n[ERROR] Failed in forward_pdiff_cond:")
104
+ print(e)
105
+ print("\nInput tensor properties:")
106
+ print("rgb_cond dtype:", batch["rgb_cond"].dtype)
107
+ print("rgb_cond device:", batch["rgb_cond"].device)
108
+ print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
109
+ raise
110
+
111
+ # Sample points
112
+ sample_iter = system.sampler.sample_batch_progressive(
113
+ batch_size,
114
+ cond_tokens,
115
+ guidance_scale=guidance_scale,
116
+ device=device
117
+ )
118
+
119
+ # Get final samples
120
+ for x in sample_iter:
121
+ samples = x["xstart"]
122
+
123
+ pc_cond = samples.permute(0, 2, 1).float()
124
+
125
+ # Normalize point cloud
126
+ pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
127
+
128
+ # Subsample to 512 points
129
+ pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
130
+
131
+ return pc_cond
132
+
133
+ @spaces.GPU
134
+ @torch.inference_mode()
135
+ def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image | None]:
136
+ """Generate image from prompt and convert to 3D model."""
137
+
138
+ # Generate random seed
139
+ seed = np.random.randint(0, np.iinfo(np.int32).max)
140
+
141
+ try:
142
+ rgb_image = image.convert('RGB')
143
+
144
+ # bg_remover returns a PIL Image already, no need to convert
145
+ no_bg_image = bg_remover.process(rgb_image)
146
+ print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
147
+
148
+ # Convert to RGBA if not already
149
+ rgba_image = no_bg_image.convert('RGBA')
150
+ print(f"[debug] rgba_image mode: {rgba_image.mode}")
151
+
152
+ processed_image = spar3d_utils.foreground_crop(
153
+ rgba_image,
154
+ crop_ratio=1.3,
155
+ newsize=(COND_WIDTH, COND_HEIGHT),
156
+ no_crop=False
157
+ )
158
+
159
+ # Show the processed image alpha channel for debugging
160
+ alpha = np.array(processed_image)[:, :, 3]
161
+ print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
162
+
163
+ # Prepare batch for processing
164
+ batch = create_batch(processed_image)
165
+ batch = {k: v.to(device) for k, v in batch.items()}
166
+
167
+ # Generate point cloud
168
+ pc_cond = forward_model(
169
+ batch,
170
+ spar3d_model,
171
+ guidance_scale=3.0,
172
+ seed=seed,
173
+ device=device
174
+ )
175
+ batch["pc_cond"] = pc_cond
176
+
177
+ # Generate mesh
178
+ with torch.no_grad():
179
+ with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
180
+ trimesh_mesh, _ = spar3d_model.generate_mesh(
181
+ batch,
182
+ 1024, # texture_resolution
183
+ remesh="none",
184
+ vertex_count=-1,
185
+ estimate_illumination=True
186
+ )
187
+ trimesh_mesh = trimesh_mesh[0]
188
+
189
+ # Export to GLB
190
+ temp_dir = tempfile.mkdtemp()
191
+ output_path = os.path.join(temp_dir, 'output.glb')
192
+
193
+ trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
194
+
195
+ return output_path
196
+
197
+ except Exception as e:
198
+ print(f"Error during generation: {str(e)}")
199
+ import traceback
200
+ traceback.print_exc()
201
+ return None
202
+
203
+ # Create Gradio app using Blocks
204
+ with gr.Blocks() as demo:
205
+ gr.Markdown("This space is based on [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d) by Stability AI, [Text to 3D](https://huggingface.co/spaces/jbilcke-hf/text-to-3d) by jbilcke-hf.")
206
+
207
+ with gr.Row():
208
+ input_img = gr.Image(
209
+ type="pil", label="Input Image", sources="upload", image_mode="RGBA"
210
+ )
211
+
212
+ with gr.Row():
213
+ model_output = gr.Model3D(
214
+ label="Generated .GLB model",
215
+ clear_color=[0.0, 0.0, 0.0, 0.0],
216
+ )
217
+
218
+ # Event handler
219
+ input_img.upload(
220
+ fn=generate_and_process_3d,
221
+ inputs=[input_img],
222
+ outputs=[model_output],
223
+ api_name="generate"
224
+ )
225
+
226
+ if __name__ == "__main__":
227
+ demo.queue().launch()
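`create_batch` above keeps every tensor channel-last and composites the subject over a neutral grey with `torch.lerp`, relying on broadcasting between the `[H, W, 3]` image, the `[H, W, 1]` alpha mask, and the `[1, 1, 3]` background colour. A self-contained sketch of just that blending step, with tiny dummy tensors standing in for the 512x512 conditioning image:

```python
import torch

H, W = 4, 4                          # dummy resolution instead of 512x512
rgb = torch.rand(H, W, 3)            # [H, W, 3] colours in 0-1
mask = torch.zeros(H, W, 1)          # [H, W, 1] alpha; 0 = background
mask[1:3, 1:3] = 1.0                 # pretend the centre is foreground

background = torch.tensor([0.5, 0.5, 0.5]).view(1, 1, 3)  # BACKGROUND_COLOR

# torch.lerp(a, b, w) = a + w * (b - a); broadcasting keeps the layout channel-last.
rgb_cond = torch.lerp(background, rgb, mask)  # [H, W, 3]

assert rgb_cond.shape == (H, W, 3)
# Background pixels become exactly grey, foreground pixels keep their colour.
assert torch.allclose(rgb_cond[0, 0], torch.tensor([0.5, 0.5, 0.5]))
assert torch.allclose(rgb_cond[1, 1], rgb[1, 1])
```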
load/tets/160_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
+ size 15408790
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ wheel
+ setuptools>=69.5.1
+ torch
+ torchvision
+ einops==0.7.0
+ jaxtyping==0.2.31
+ omegaconf>=2.3.0
+ transformers>=4.42.4
+ loralib>=0.1.2
+ git+https://github.com/openai/CLIP.git
+ git+https://github.com/SunzeY/AlphaCLIP.git
+ trimesh==4.4.1
+ numpy==1.26.4
+ huggingface-hub>=0.27.0
+ transparent-background==1.3.3
+ gradio>=4.43.0
+ gradio-litmodel3d>=0.0.1
+ gradio-pointcloudeditor>=0.0.9
+ opencv-python>=4.10.0.84
+ gpytoolbox==0.2.0
+ # ./texture_baker/
+ # ./uv_unwrapper/
+ accelerate
ruff.toml ADDED
@@ -0,0 +1,3 @@
+ [lint]
+ ignore = ["F722", "F821"]
+ extend-select = ["I"]
run.py ADDED
@@ -0,0 +1,190 @@
1
+ import argparse
2
+ import os
3
+ from contextlib import nullcontext
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ from transparent_background import Remover
9
+
10
+ from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
11
+ from spar3d.system import SPAR3D
12
+ from spar3d.utils import foreground_crop, get_device, remove_background
13
+
14
+
15
+ def check_positive(value):
16
+ ivalue = int(value)
17
+ if ivalue <= 0:
18
+ raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
19
+ return ivalue
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "image", type=str, nargs="+", help="Path to input image(s) or folder."
26
+ )
27
+ parser.add_argument(
28
+ "--device",
29
+ default=get_device(),
30
+ type=str,
31
+ help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
32
+ )
33
+ parser.add_argument(
34
+ "--pretrained-model",
35
+ default="stabilityai/stable-point-aware-3d",
36
+ type=str,
37
+ help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-point-aware-3d'",
38
+ )
39
+ parser.add_argument(
40
+ "--foreground-ratio",
41
+ default=1.3,
42
+ type=float,
43
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
44
+ )
45
+ parser.add_argument(
46
+ "--output-dir",
47
+ default="output/",
48
+ type=str,
49
+ help="Output directory to save the results. Default: 'output/'",
50
+ )
51
+ parser.add_argument(
52
+ "--texture-resolution",
53
+ default=1024,
54
+ type=int,
55
+ help="Texture atlas resolution. Default: 1024",
56
+ )
57
+ parser.add_argument(
58
+ "--low-vram-mode",
59
+ action="store_true",
60
+ help=(
61
+ "Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
62
+ "This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
63
+ "the model will be slower. Default: False"
64
+ ),
65
+ )
66
+
67
+ remesh_choices = ["none"]
68
+ if TRIANGLE_REMESH_AVAILABLE:
69
+ remesh_choices.append("triangle")
70
+ if QUAD_REMESH_AVAILABLE:
71
+ remesh_choices.append("quad")
72
+ parser.add_argument(
73
+ "--remesh_option",
74
+ choices=remesh_choices,
75
+ default="none",
76
+ help="Remeshing option",
77
+ )
78
+ if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
79
+ parser.add_argument(
80
+ "--reduction_count_type",
81
+ choices=["keep", "vertex", "faces"],
82
+ default="keep",
83
+ help="Vertex count type",
84
+ )
85
+ parser.add_argument(
86
+ "--target_count",
87
+ type=check_positive,
88
+ help="Selected target count.",
89
+ default=2000,
90
+ )
91
+ parser.add_argument(
92
+ "--batch_size", default=1, type=int, help="Batch size for inference"
93
+ )
94
+ args = parser.parse_args()
95
+
96
+ # Ensure args.device contains cuda
97
+ devices = ["cuda", "mps", "cpu"]
98
+ if not any(args.device in device for device in devices):
99
+ raise ValueError("Invalid device. Use cuda, mps or cpu")
100
+
101
+ output_dir = args.output_dir
102
+ os.makedirs(output_dir, exist_ok=True)
103
+
104
+ device = args.device
105
+ if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
106
+ device = "cpu"
107
+
108
+ print("Device used: ", device)
109
+
110
+ model = SPAR3D.from_pretrained(
111
+ args.pretrained_model,
112
+ config_name="config.yaml",
113
+ weight_name="model.safetensors",
114
+ low_vram_mode=args.low_vram_mode,
115
+ )
116
+ model.to(device)
117
+ model.eval()
118
+
119
+ bg_remover = Remover(device=device)
120
+ images = []
121
+ idx = 0
122
+ for image_path in args.image:
123
+
124
+ def handle_image(image_path, idx):
125
+ image = remove_background(
126
+ Image.open(image_path).convert("RGBA"), bg_remover
127
+ )
128
+ image = foreground_crop(image, args.foreground_ratio)
129
+ os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
130
+ image.save(os.path.join(output_dir, str(idx), "input.png"))
131
+ images.append(image)
132
+
133
+ if os.path.isdir(image_path):
134
+ image_paths = [
135
+ os.path.join(image_path, f)
136
+ for f in os.listdir(image_path)
137
+ if f.endswith((".png", ".jpg", ".jpeg"))
138
+ ]
139
+ for image_path in image_paths:
140
+ handle_image(image_path, idx)
141
+ idx += 1
142
+ else:
143
+ handle_image(image_path, idx)
144
+ idx += 1
145
+
146
+ vertex_count = (
147
+ -1
148
+ if args.reduction_count_type == "keep"
149
+ else (
150
+ args.target_count
151
+ if args.reduction_count_type == "vertex"
152
+ else args.target_count // 2
153
+ )
154
+ )
155
+
156
+ for i in tqdm(range(0, len(images), args.batch_size)):
157
+ image = images[i : i + args.batch_size]
158
+ if torch.cuda.is_available():
159
+ torch.cuda.reset_peak_memory_stats()
160
+ with torch.no_grad():
161
+ with (
162
+ torch.autocast(device_type=device, dtype=torch.bfloat16)
163
+ if "cuda" in device
164
+ else nullcontext()
165
+ ):
166
+ mesh, glob_dict = model.run_image(
167
+ image,
168
+ bake_resolution=args.texture_resolution,
169
+ remesh=args.remesh_option,
170
+ vertex_count=vertex_count,
171
+ return_points=True,
172
+ )
173
+ if torch.cuda.is_available():
174
+ print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
175
+ elif torch.backends.mps.is_available():
176
+ print(
177
+ "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
178
+ )
179
+
180
+ if len(image) == 1:
181
+ out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
182
+ mesh.export(out_mesh_path, include_normals=True)
183
+ out_points_path = os.path.join(output_dir, str(i), "points.ply")
184
+ glob_dict["point_clouds"][0].export(out_points_path)
185
+ else:
186
+ for j in range(len(mesh)):
187
+ out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
188
+ mesh[j].export(out_mesh_path, include_normals=True)
189
+ out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
190
+ glob_dict["point_clouds"][j].export(out_points_path)
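One detail in `run.py` worth spelling out: with `--reduction_count_type faces`, the requested `--target_count` is halved before being passed to the model as `vertex_count`, because a closed triangle mesh has roughly twice as many faces as vertices (Euler's formula gives F ≈ 2V), while `keep` maps to `-1`, i.e. no decimation. A small sketch of that mapping (the helper name is mine, not part of run.py):

```python
def resolve_vertex_count(reduction_count_type: str, target_count: int) -> int:
    """Mirror run.py's mapping from CLI options to SPAR3D's vertex_count."""
    if reduction_count_type == "keep":
        return -1                    # keep the original vertex count
    if reduction_count_type == "vertex":
        return target_count          # the target is already a vertex budget
    return target_count // 2         # "faces": F is roughly 2V on closed triangle meshes


print(resolve_vertex_count("keep", 2000))    # -1
print(resolve_vertex_count("vertex", 2000))  # 2000
print(resolve_vertex_count("faces", 2000))   # 1000
```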
spar3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
+ from dataclasses import dataclass, field
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+
+ from spar3d.models.utils import BaseModule
+
+
+ class LinearCameraEmbedder(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         in_channels: int = 25
+         out_channels: int = 768
+         conditions: List[str] = field(default_factory=list)
+
+     cfg: Config
+
+     def configure(self) -> None:
+         self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
+
+     def forward(self, **kwargs):
+         cond_tensors = []
+         for cond_name in self.cfg.conditions:
+             assert cond_name in kwargs
+             cond = kwargs[cond_name]
+             # cond in shape (B, Nv, ...)
+             cond_tensors.append(cond.view(*cond.shape[:2], -1))
+         cond_tensor = torch.cat(cond_tensors, dim=-1)
+         assert cond_tensor.shape[-1] == self.cfg.in_channels
+         embedding = self.linear(cond_tensor)
+         return embedding
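The default `in_channels: int = 25` lines up with the camera conditioning assembled in `gradio_app.py`: a flattened 4x4 `c2w_cond` matrix (16 values) plus a flattened 3x3 `intrinsic_normed_cond` matrix (9 values). A sketch of that flattening with dummy tensors; the batch keys match `gradio_app.py`, but treating exactly these two as the configured `conditions` list is an assumption:

```python
import torch
import torch.nn as nn

B, Nv = 1, 1                                               # one batch element, one view
c2w_cond = torch.eye(4).expand(B, Nv, 4, 4)                # [B, Nv, 4, 4]
intrinsic_normed_cond = torch.eye(3).expand(B, Nv, 3, 3)   # [B, Nv, 3, 3]

# LinearCameraEmbedder flattens everything after the (B, Nv) dims and concatenates.
flat = torch.cat(
    [c2w_cond.reshape(B, Nv, -1), intrinsic_normed_cond.reshape(B, Nv, -1)], dim=-1
)
print(flat.shape)  # torch.Size([1, 1, 25])  -> matches in_channels=25

embedding = nn.Linear(25, 768)(flat)  # the same projection configure() builds
print(embedding.shape)  # torch.Size([1, 1, 768])
```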
spar3d/models/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,524 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from: https://github.com/openai/point-e
3
+ # Licensed under the MIT License
4
+ # Copyright (c) 2022 OpenAI
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ # --------------------------------------------------------
25
+
26
+ import math
27
+ from typing import Any, Dict, Iterable, Optional, Sequence, Union
28
+
29
+ import numpy as np
30
+ import torch as th
31
+
32
+
33
+ def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
34
+ def sigmoid(x):
35
+ return 1 / (1 + np.exp(-x))
36
+
37
+ v_start = sigmoid(start / tau)
38
+ v_end = sigmoid(end / tau)
39
+ output = sigmoid((t * (end - start) + start) / tau)
40
+ output = (v_end - output) / (v_end - v_start)
41
+ return np.clip(output, clip_min, 1.0)
42
+
43
+
44
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
45
+ """
46
+ This is the deprecated API for creating beta schedules.
47
+
48
+ See get_named_beta_schedule() for the new library of schedules.
49
+ """
50
+ if beta_schedule == "linear":
51
+ betas = np.linspace(
52
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
53
+ )
54
+ else:
55
+ raise NotImplementedError(beta_schedule)
56
+ assert betas.shape == (num_diffusion_timesteps,)
57
+ return betas
58
+
59
+
60
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, exp_p=12):
61
+ """
62
+ Get a pre-defined beta schedule for the given name.
63
+
64
+ The beta schedule library consists of beta schedules which remain similar
65
+ in the limit of num_diffusion_timesteps.
66
+ Beta schedules may be added, but should not be removed or changed once
67
+ they are committed to maintain backwards compatibility.
68
+ """
69
+ if schedule_name == "linear":
70
+ # Linear schedule from Ho et al, extended to work for any number of
71
+ # diffusion steps.
72
+ scale = 1000 / num_diffusion_timesteps
73
+ return get_beta_schedule(
74
+ "linear",
75
+ beta_start=scale * 0.0001,
76
+ beta_end=scale * 0.02,
77
+ num_diffusion_timesteps=num_diffusion_timesteps,
78
+ )
79
+ elif schedule_name == "cosine":
80
+ return betas_for_alpha_bar(
81
+ num_diffusion_timesteps,
82
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
83
+ )
84
+ elif schedule_name == "sigmoid":
85
+ # Sigmoid schedule passed through betas_for_alpha_bar
86
+ return betas_for_alpha_bar(
87
+ num_diffusion_timesteps, lambda t: sigmoid_schedule(t)
88
+ )
89
+ else:
90
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
91
+
92
+
93
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
94
+ """
95
+ Create a beta schedule that discretizes the given alpha_t_bar function,
96
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
97
+
98
+ :param num_diffusion_timesteps: the number of betas to produce.
99
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
100
+ produces the cumulative product of (1-beta) up to that
101
+ part of the diffusion process.
102
+ :param max_beta: the maximum beta to use; use values lower than 1 to
103
+ prevent singularities.
104
+ """
105
+ betas = []
106
+ for i in range(num_diffusion_timesteps):
107
+ t1 = i / num_diffusion_timesteps
108
+ t2 = (i + 1) / num_diffusion_timesteps
109
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
110
+ return np.array(betas)
111
+
112
+
113
+ def space_timesteps(num_timesteps, section_counts):
114
+ """
115
+ Create a list of timesteps to use from an original diffusion process,
116
+ given the number of timesteps we want to take from equally-sized portions
117
+ of the original process.
118
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
119
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
120
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
121
+ :param num_timesteps: the number of diffusion steps in the original
122
+ process to divide up.
123
+ :param section_counts: either a list of numbers, or a string containing
124
+ comma-separated numbers, indicating the step count
125
+ per section. As a special case, use "ddimN" where N
126
+ is a number of steps to use the striding from the
127
+ DDIM paper.
128
+ :return: a set of diffusion steps from the original process to use.
129
+ """
130
+ if isinstance(section_counts, str):
131
+ if section_counts.startswith("ddim"):
132
+ desired_count = int(section_counts[len("ddim") :])
133
+ for i in range(1, num_timesteps):
134
+ if len(range(0, num_timesteps, i)) == desired_count:
135
+ return set(range(0, num_timesteps, i))
136
+ raise ValueError(
137
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
138
+ )
139
+ elif section_counts.startswith("exact"):
140
+ res = set(int(x) for x in section_counts[len("exact") :].split(","))
141
+ for x in res:
142
+ if x < 0 or x >= num_timesteps:
143
+ raise ValueError(f"timestep out of bounds: {x}")
144
+ return res
145
+ section_counts = [int(x) for x in section_counts.split(",")]
146
+ size_per = num_timesteps // len(section_counts)
147
+ extra = num_timesteps % len(section_counts)
148
+ start_idx = 0
149
+ all_steps = []
150
+ for i, section_count in enumerate(section_counts):
151
+ size = size_per + (1 if i < extra else 0)
152
+ if size < section_count:
153
+ raise ValueError(
154
+ f"cannot divide section of {size} steps into {section_count}"
155
+ )
156
+ if section_count <= 1:
157
+ frac_stride = 1
158
+ else:
159
+ frac_stride = (size - 1) / (section_count - 1)
160
+ cur_idx = 0.0
161
+ taken_steps = []
162
+ for _ in range(section_count):
163
+ taken_steps.append(start_idx + round(cur_idx))
164
+ cur_idx += frac_stride
165
+ all_steps += taken_steps
166
+ start_idx += size
167
+ return set(all_steps)
168
+
169
+
170
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
171
+ """Extract values from a 1-D numpy array for a batch of indices."""
172
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
173
+ while len(res.shape) < len(broadcast_shape):
174
+ res = res[..., None]
175
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
176
+
177
+
178
+ class GaussianDiffusion:
179
+ """
180
+ Utilities for sampling from Gaussian diffusion models.
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ *,
186
+ betas: Sequence[float],
187
+ model_mean_type: str,
188
+ model_var_type: str,
189
+ channel_scales: Optional[np.ndarray] = None,
190
+ channel_biases: Optional[np.ndarray] = None,
191
+ ):
192
+ self.model_mean_type = model_mean_type
193
+ self.model_var_type = model_var_type
194
+ self.channel_scales = channel_scales
195
+ self.channel_biases = channel_biases
196
+
197
+ # Use float64 for accuracy
198
+ betas = np.array(betas, dtype=np.float64)
199
+ self.betas = betas
200
+ assert len(betas.shape) == 1, "betas must be 1-D"
201
+ assert (betas > 0).all() and (betas <= 1).all()
202
+
203
+ self.num_timesteps = int(betas.shape[0])
204
+
205
+ alphas = 1.0 - betas
206
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
207
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
208
+
209
+ # calculations for diffusion q(x_t | x_{t-1}) and others
210
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
211
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
212
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
213
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
214
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
215
+ self.posterior_variance = (
216
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
217
+ )
218
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
219
+ self.posterior_log_variance_clipped = np.log(
220
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
221
+ )
222
+
223
+ self.posterior_mean_coef1 = (
224
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
225
+ )
226
+ self.posterior_mean_coef2 = (
227
+ (1.0 - self.alphas_cumprod_prev)
228
+ * np.sqrt(alphas)
229
+ / (1.0 - self.alphas_cumprod)
230
+ )
231
+
232
+ def scale_channels(self, x: th.Tensor) -> th.Tensor:
233
+ """Apply channel-wise scaling."""
234
+ if self.channel_scales is not None:
235
+ x = x * th.from_numpy(self.channel_scales).to(x).reshape(
236
+ [1, -1, *([1] * (len(x.shape) - 2))]
237
+ )
238
+ if self.channel_biases is not None:
239
+ x = x + th.from_numpy(self.channel_biases).to(x).reshape(
240
+ [1, -1, *([1] * (len(x.shape) - 2))]
241
+ )
242
+ return x
243
+
244
+ def unscale_channels(self, x: th.Tensor) -> th.Tensor:
245
+ """Remove channel-wise scaling."""
246
+ if self.channel_biases is not None:
247
+ x = x - th.from_numpy(self.channel_biases).to(x).reshape(
248
+ [1, -1, *([1] * (len(x.shape) - 2))]
249
+ )
250
+ if self.channel_scales is not None:
251
+ x = x / th.from_numpy(self.channel_scales).to(x).reshape(
252
+ [1, -1, *([1] * (len(x.shape) - 2))]
253
+ )
254
+ return x
255
+
256
+ def unscale_out_dict(
257
+ self, out: Dict[str, Union[th.Tensor, Any]]
258
+ ) -> Dict[str, Union[th.Tensor, Any]]:
259
+ return {
260
+ k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v)
261
+ for k, v in out.items()
262
+ }
263
+
264
+ def q_posterior_mean_variance(self, x_start, x_t, t):
265
+ """
266
+ Compute the mean and variance of the diffusion posterior:
267
+
268
+ q(x_{t-1} | x_t, x_0)
269
+
270
+ """
271
+ assert x_start.shape == x_t.shape
272
+ posterior_mean = (
273
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
274
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
275
+ )
276
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
277
+ posterior_log_variance_clipped = _extract_into_tensor(
278
+ self.posterior_log_variance_clipped, t, x_t.shape
279
+ )
280
+ assert (
281
+ posterior_mean.shape[0]
282
+ == posterior_variance.shape[0]
283
+ == posterior_log_variance_clipped.shape[0]
284
+ == x_start.shape[0]
285
+ )
286
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
287
+
288
+ def p_mean_variance(
289
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
290
+ ):
291
+ """
292
+ Apply the model to get p(x_{t-1} | x_t).
293
+ """
294
+ if model_kwargs is None:
295
+ model_kwargs = {}
296
+
297
+ B, C = x.shape[:2]
298
+ assert t.shape == (B,)
299
+
300
+ # Direct prediction of eps
301
+ model_output = model(x, t, **model_kwargs)
302
+ if isinstance(model_output, tuple):
303
+ model_output, prev_latent = model_output
304
+ model_kwargs["prev_latent"] = prev_latent
305
+
306
+ # Convert model output to mean and variance
307
+ model_variance, model_log_variance = {
308
+ # for fixedlarge, we set the initial (log-)variance like so
309
+ # to get a better decoder log likelihood.
310
+ "fixed_large": (
311
+ np.append(self.posterior_variance[1], self.betas[1:]),
312
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
313
+ ),
314
+ "fixed_small": (
315
+ self.posterior_variance,
316
+ self.posterior_log_variance_clipped,
317
+ ),
318
+ }[self.model_var_type]
319
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
320
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
321
+
322
+ def process_xstart(x):
323
+ if denoised_fn is not None:
324
+ x = denoised_fn(x)
325
+ if clip_denoised:
326
+ x = x.clamp(
327
+ -self.channel_scales[0] * 0.67, self.channel_scales[0] * 0.67
328
+ )
329
+ x[:, 3:] = x[:, 3:].clamp(
330
+ -self.channel_scales[3] * 0.5, self.channel_scales[3] * 0.5
331
+ )
332
+ return x
333
+ return x
334
+
335
+ if self.model_mean_type == "x_prev":
336
+ pred_xstart = process_xstart(
337
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
338
+ )
339
+ model_mean = model_output
340
+ elif self.model_mean_type in ["x_start", "epsilon"]:
341
+ if self.model_mean_type == "x_start":
342
+ pred_xstart = process_xstart(model_output)
343
+ else:
344
+ pred_xstart = process_xstart(
345
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
346
+ )
347
+ model_mean, _, _ = self.q_posterior_mean_variance(
348
+ x_start=pred_xstart, x_t=x, t=t
349
+ )
350
+ # print('p_mean_variance:', pred_xstart.min(), pred_xstart.max())
351
+ else:
352
+ raise NotImplementedError(self.model_mean_type)
353
+
354
+ assert (
355
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
356
+ )
357
+ return {
358
+ "mean": model_mean,
359
+ "variance": model_variance,
360
+ "log_variance": model_log_variance,
361
+ "pred_xstart": pred_xstart,
362
+ }
363
+
364
+ def _predict_xstart_from_eps(self, x_t, t, eps):
365
+ assert x_t.shape == eps.shape
366
+ return (
367
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
368
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
369
+ )
370
+
371
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
372
+ assert x_t.shape == xprev.shape
373
+ return ( # (xprev - coef2*x_t) / coef1
374
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
375
+ - _extract_into_tensor(
376
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
377
+ )
378
+ * x_t
379
+ )
380
+
381
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
382
+ return (
383
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
384
+ - pred_xstart
385
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
386
+
387
+ def ddim_sample_loop_progressive(
388
+ self,
389
+ model,
390
+ shape,
391
+ noise=None,
392
+ clip_denoised=True,
393
+ denoised_fn=None,
394
+ model_kwargs=None,
395
+ device=None,
396
+ progress=False,
397
+ eta=0.0,
398
+ ):
399
+ """
400
+ Use DDIM to sample from the model and yield intermediate samples.
401
+ """
402
+ if device is None:
403
+ device = next(model.parameters()).device
404
+ assert isinstance(shape, (tuple, list))
405
+ if noise is not None:
406
+ img = noise
407
+ else:
408
+ img = th.randn(*shape, device=device)
409
+
410
+ indices = list(range(self.num_timesteps))[::-1]
411
+
412
+ if progress:
413
+ from tqdm.auto import tqdm
414
+
415
+ indices = tqdm(indices)
416
+
417
+ for i in indices:
418
+ t = th.tensor([i] * shape[0], device=device)
419
+ with th.no_grad():
420
+ out = self.ddim_sample(
421
+ model,
422
+ img,
423
+ t,
424
+ clip_denoised=clip_denoised,
425
+ denoised_fn=denoised_fn,
426
+ model_kwargs=model_kwargs,
427
+ eta=eta,
428
+ )
429
+ yield self.unscale_out_dict(out)
430
+ img = out["sample"]
431
+
432
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
433
+ return (
434
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
435
+ - pred_xstart
436
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
437
+
438
+ def ddim_sample(
439
+ self,
440
+ model,
441
+ x,
442
+ t,
443
+ clip_denoised=True,
444
+ denoised_fn=None,
445
+ model_kwargs=None,
446
+ eta=0.0,
447
+ ):
448
+ """
449
+ Sample x_{t-1} from the model using DDIM.
450
+ """
451
+ out = self.p_mean_variance(
452
+ model,
453
+ x,
454
+ t,
455
+ clip_denoised=clip_denoised,
456
+ denoised_fn=denoised_fn,
457
+ model_kwargs=model_kwargs,
458
+ )
459
+
460
+ # Usually our model outputs epsilon, but we re-derive it
461
+ # in case we used x_start or x_prev prediction.
462
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
463
+
464
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
465
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
466
+ sigma = (
467
+ eta
468
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
469
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
470
+ )
471
+
472
+ # Equation 12.
473
+ noise = th.randn_like(x)
474
+ mean_pred = (
475
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
476
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
477
+ )
478
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
479
+ sample = mean_pred + nonzero_mask * sigma * noise
480
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
481
+
482
+
483
+ class SpacedDiffusion(GaussianDiffusion):
484
+ """
485
+ A diffusion process which can skip steps in a base diffusion process.
486
+ """
487
+
488
+ def __init__(self, use_timesteps: Iterable[int], **kwargs):
489
+ self.use_timesteps = set(use_timesteps)
490
+ self.timestep_map = []
491
+ self.original_num_steps = len(kwargs["betas"])
492
+
493
+ base_diffusion = GaussianDiffusion(**kwargs)
494
+ last_alpha_cumprod = 1.0
495
+ new_betas = []
496
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
497
+ if i in self.use_timesteps:
498
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
499
+ last_alpha_cumprod = alpha_cumprod
500
+ self.timestep_map.append(i)
501
+ kwargs["betas"] = np.array(new_betas)
502
+ super().__init__(**kwargs)
503
+
504
+ def p_mean_variance(self, model, *args, **kwargs):
505
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
506
+
507
+ def _wrap_model(self, model):
508
+ if isinstance(model, _WrappedModel):
509
+ return model
510
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
511
+
512
+
513
+ class _WrappedModel:
514
+ """Helper class to wrap models for SpacedDiffusion."""
515
+
516
+ def __init__(self, model, timestep_map, original_num_steps):
517
+ self.model = model
518
+ self.timestep_map = timestep_map
519
+ self.original_num_steps = original_num_steps
520
+
521
+ def __call__(self, x, ts, **kwargs):
522
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
523
+ new_ts = map_tensor[ts]
524
+ return self.model(x, new_ts, **kwargs)
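The pieces above compose as follows: pick a named beta schedule, choose a subset of timesteps with space_timesteps, build a SpacedDiffusion over that subset, and drive it with ddim_sample_loop_progressive. A minimal sketch of that flow follows; the toy epsilon-predicting model and the schedule/step counts are hypothetical, and clip_denoised is disabled because the clipping path assumes channel_scales is set:

```python
import torch
import torch.nn as nn

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)


class ToyEpsModel(nn.Module):
    """Hypothetical stand-in denoiser that predicts epsilon with the input's shape."""

    def __init__(self, channels=6):
        super().__init__()
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x, t, **kwargs):
        return self.proj(x)


# 1024 base steps collapsed onto a 64-step DDIM stride.
diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(1024, "ddim64"),
    betas=get_named_beta_schedule("cosine", 1024),
    model_mean_type="epsilon",
    model_var_type="fixed_small",
)

sample = None
for out in diffusion.ddim_sample_loop_progressive(
    ToyEpsModel(),
    shape=(2, 6, 512),    # (batch, channels, points)
    device="cpu",
    clip_denoised=False,  # clipping relies on channel_scales, which is unset here
):
    sample = out["sample"]
print(sample.shape)       # torch.Size([2, 6, 512])
```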
spar3d/models/diffusion/sampler.py ADDED
@@ -0,0 +1,134 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from: https://github.com/openai/point-e
3
+ # Licensed under the MIT License
4
+ # Copyright (c) 2022 OpenAI
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ # --------------------------------------------------------
25
+
26
+ from typing import Dict, Iterator
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from .gaussian_diffusion import GaussianDiffusion
32
+
33
+
34
+ class PointCloudSampler:
35
+ """
36
+ A wrapper around a model that produces conditional sample tensors.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ model: nn.Module,
42
+ diffusion: GaussianDiffusion,
43
+ num_points: int,
44
+ point_dim: int = 3,
45
+ guidance_scale: float = 3.0,
46
+ clip_denoised: bool = True,
47
+ sigma_min: float = 1e-3,
48
+ sigma_max: float = 120,
49
+ s_churn: float = 3,
50
+ ):
51
+ self.model = model
52
+ self.num_points = num_points
53
+ self.point_dim = point_dim
54
+ self.guidance_scale = guidance_scale
55
+ self.clip_denoised = clip_denoised
56
+ self.sigma_min = sigma_min
57
+ self.sigma_max = sigma_max
58
+ self.s_churn = s_churn
59
+
60
+ self.diffusion = diffusion
61
+
62
+ def sample_batch_progressive(
63
+ self,
64
+ batch_size: int,
65
+ condition: torch.Tensor,
66
+ noise=None,
67
+ device=None,
68
+ guidance_scale=None,
69
+ ) -> Iterator[Dict[str, torch.Tensor]]:
70
+ """
71
+ Generate samples progressively using classifier-free guidance.
72
+
73
+ Args:
74
+ batch_size: Number of samples to generate
75
+ condition: Conditioning tensor
76
+ noise: Optional initial noise tensor
77
+ device: Device to run on
78
+ guidance_scale: Optional override for guidance scale
79
+
80
+ Returns:
81
+ Iterator of dicts containing intermediate samples
82
+ """
83
+ if guidance_scale is None:
84
+ guidance_scale = self.guidance_scale
85
+
86
+ sample_shape = (batch_size, self.point_dim, self.num_points)
87
+
88
+ # Double the batch for classifier-free guidance
89
+ if guidance_scale != 1 and guidance_scale != 0:
90
+ condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
91
+ if noise is not None:
92
+ noise = torch.cat([noise, noise], dim=0)
93
+ model_kwargs = {"condition": condition}
94
+
95
+ internal_batch_size = batch_size
96
+ if guidance_scale != 1 and guidance_scale != 0:
97
+ model = self._uncond_guide_model(self.model, guidance_scale)
98
+ internal_batch_size *= 2
99
+ else:
100
+ model = self.model
101
+
102
+ samples_it = self.diffusion.ddim_sample_loop_progressive(
103
+ model,
104
+ shape=(internal_batch_size, *sample_shape[1:]),
105
+ model_kwargs=model_kwargs,
106
+ device=device,
107
+ clip_denoised=self.clip_denoised,
108
+ noise=noise,
109
+ )
110
+
111
+ for x in samples_it:
112
+ samples = {
113
+ "xstart": x["pred_xstart"][:batch_size],
114
+ "xprev": x["sample"][:batch_size] if "sample" in x else x["x"],
115
+ }
116
+ yield samples
117
+
118
+ def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
119
+ """
120
+ Wraps the model for classifier-free guidance.
121
+ """
122
+
123
+ def model_fn(x_t, ts, **kwargs):
124
+ half = x_t[: len(x_t) // 2]
125
+ combined = torch.cat([half, half], dim=0)
126
+ model_out = model(combined, ts, **kwargs)
127
+
128
+ eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
129
+ cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
130
+ half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
131
+ eps = torch.cat([half_eps, half_eps], dim=0)
132
+ return torch.cat([eps, rest], dim=1)
133
+
134
+ return model_fn
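PointCloudSampler implements classifier-free guidance by doubling the batch (conditioned plus zeroed condition) and mixing the two epsilon predictions inside model_fn. A self-contained sketch with a hypothetical conditional denoiser and assumed sizes (6 channels, 512 points, a 768-d conditioning vector) follows:

```python
import torch
import torch.nn as nn

from spar3d.models.diffusion.gaussian_diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)
from spar3d.models.diffusion.sampler import PointCloudSampler


class ToyCondModel(nn.Module):
    """Hypothetical denoiser: predicts epsilon conditioned on a global embedding."""

    def __init__(self, point_dim=6, cond_dim=768):
        super().__init__()
        self.proj = nn.Conv1d(point_dim, point_dim, 1)
        self.cond_proj = nn.Linear(cond_dim, point_dim)

    def forward(self, x, t, condition=None, **kwargs):
        out = self.proj(x)
        if condition is not None:
            out = out + self.cond_proj(condition)[..., None]
        return out


diffusion = SpacedDiffusion(
    use_timesteps=space_timesteps(1024, "ddim64"),
    betas=get_named_beta_schedule("cosine", 1024),
    model_mean_type="epsilon",
    model_var_type="fixed_small",
)
sampler = PointCloudSampler(
    model=ToyCondModel(),
    diffusion=diffusion,
    num_points=512,
    point_dim=6,
    guidance_scale=3.0,
    clip_denoised=False,  # channel scales/biases are unset in this sketch
)

condition = torch.randn(2, 768)  # e.g. one image embedding per sample
final = None
for step in sampler.sample_batch_progressive(
    batch_size=2, condition=condition, device="cpu"
):
    final = step["xstart"]
print(final.shape)  # torch.Size([2, 6, 512])
```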
spar3d/models/global_estimator/reni_estimator.py ADDED
@@ -0,0 +1,116 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from spar3d.models.illumination.reni.env_map import RENIEnvMap
11
+ from spar3d.models.utils import BaseModule
12
+
13
+
14
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
15
+ assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"
16
+
17
+ def proj_u2a(u, a):
18
+ r"""
19
+ u: batch x 3
20
+ a: batch x 3
21
+ """
22
+ inner_prod = torch.sum(u * a, dim=-1, keepdim=True)
23
+ norm2 = torch.sum(u**2, dim=-1, keepdim=True)
24
+ norm2 = torch.clamp(norm2, min=1e-8)
25
+ factor = inner_prod / (norm2 + 1e-10)
26
+ return factor * u
27
+
28
+ x_raw, y_raw = d6[..., :3], d6[..., 3:]
29
+
30
+ x = F.normalize(x_raw, dim=-1)
31
+ y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1)
32
+ z = torch.cross(x, y, dim=-1)
33
+
34
+ return torch.stack((x, y, z), dim=-1)
35
+
36
+
37
+ class ReniLatentCodeEstimator(BaseModule):
38
+ @dataclass
39
+ class Config(BaseModule.Config):
40
+ triplane_features: int = 40
41
+
42
+ n_layers: int = 5
43
+ hidden_features: int = 512
44
+ activation: str = "relu"
45
+
46
+ pool: str = "mean"
47
+
48
+ reni_env_config: dict = field(default_factory=dict)
49
+
50
+ cfg: Config
51
+
52
+ def configure(self):
53
+ layers = []
54
+ cur_features = self.cfg.triplane_features * 3
55
+ for _ in range(self.cfg.n_layers):
56
+ layers.append(
57
+ nn.Conv2d(
58
+ cur_features,
59
+ self.cfg.hidden_features,
60
+ kernel_size=3,
61
+ padding=0,
62
+ stride=2,
63
+ )
64
+ )
65
+ layers.append(self.make_activation(self.cfg.activation))
66
+
67
+ cur_features = self.cfg.hidden_features
68
+
69
+ self.layers = nn.Sequential(*layers)
70
+
71
+ self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
72
+ self.latent_dim = self.reni_env_map.field.latent_dim
73
+
74
+ self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
75
+ nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)
76
+
77
+ self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
78
+ nn.init.constant_(self.fc_rotations.bias, 0.0)
79
+ nn.init.normal_(
80
+ self.fc_rotations.weight, mean=0.0, std=0.01
81
+ ) # Small variance here
82
+
83
+ self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
84
+ nn.init.constant_(self.fc_scale.bias, 0.0)
85
+ nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01) # Small variance here
86
+
87
+ def make_activation(self, activation):
88
+ if activation == "relu":
89
+ return nn.ReLU(inplace=True)
90
+ elif activation == "silu":
91
+ return nn.SiLU(inplace=True)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ def forward(
96
+ self,
97
+ triplane: Float[Tensor, "B 3 F Ht Wt"],
98
+ rotation: Optional[Float[Tensor, "B 3 3"]] = None,
99
+ ) -> dict[str, Any]:
100
+ x = self.layers(
101
+ triplane.reshape(
102
+ triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
103
+ )
104
+ )
105
+ x = x.mean(dim=[-2, -1])
106
+
107
+ latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
108
+ rotations = rotation_6d_to_matrix(self.fc_rotations(x))
109
+ scale = self.fc_scale(x)
110
+
111
+ if rotation is not None:
112
+ rotations = rotations @ rotation.to(dtype=rotations.dtype)
113
+
114
+ env_map = self.reni_env_map(latents, rotations, scale)
115
+
116
+ return {"illumination": env_map["rgb"]}
spar3d/models/illumination/reni/components/film_siren.py ADDED
@@ -0,0 +1,148 @@
1
+ """FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/."""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ def kaiming_leaky_init(m):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Linear") != -1:
13
+ torch.nn.init.kaiming_normal_(
14
+ m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
15
+ )
16
+
17
+
18
+ def frequency_init(freq):
19
+ def init(m):
20
+ with torch.no_grad():
21
+ if isinstance(m, nn.Linear):
22
+ num_input = m.weight.size(-1)
23
+ m.weight.uniform_(
24
+ -np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq
25
+ )
26
+
27
+ return init
28
+
29
+
30
+ def first_layer_film_sine_init(m):
31
+ with torch.no_grad():
32
+ if isinstance(m, nn.Linear):
33
+ num_input = m.weight.size(-1)
34
+ m.weight.uniform_(-1 / num_input, 1 / num_input)
35
+
36
+
37
+ class CustomMappingNetwork(nn.Module):
38
+ def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
39
+ super().__init__()
40
+
41
+ self.network = []
42
+
43
+ for _ in range(map_hidden_layers):
44
+ self.network.append(nn.Linear(in_features, map_hidden_dim))
45
+ self.network.append(nn.LeakyReLU(0.2, inplace=True))
46
+ in_features = map_hidden_dim
47
+
48
+ self.network.append(nn.Linear(map_hidden_dim, map_output_dim))
49
+
50
+ self.network = nn.Sequential(*self.network)
51
+
52
+ self.network.apply(kaiming_leaky_init)
53
+ with torch.no_grad():
54
+ self.network[-1].weight *= 0.25
55
+
56
+ def forward(self, z):
57
+ frequencies_offsets = self.network(z)
58
+ frequencies = frequencies_offsets[
59
+ ..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
60
+ ]
61
+ phase_shifts = frequencies_offsets[
62
+ ..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") :
63
+ ]
64
+
65
+ return frequencies, phase_shifts
66
+
67
+
68
+ class FiLMLayer(nn.Module):
69
+ def __init__(self, input_dim, hidden_dim):
70
+ super().__init__()
71
+ self.layer = nn.Linear(input_dim, hidden_dim)
72
+
73
+ def forward(self, x, freq, phase_shift):
74
+ x = self.layer(x)
75
+ freq = freq.expand_as(x)
76
+ phase_shift = phase_shift.expand_as(x)
77
+ return torch.sin(freq * x + phase_shift)
78
+
79
+
80
+ class FiLMSiren(nn.Module):
81
+ """FiLM Conditioned Siren network."""
82
+
83
+ def __init__(
84
+ self,
85
+ in_dim: int,
86
+ hidden_layers: int,
87
+ hidden_features: int,
88
+ mapping_network_in_dim: int,
89
+ mapping_network_layers: int,
90
+ mapping_network_features: int,
91
+ out_dim: int,
92
+ outermost_linear: bool = False,
93
+ out_activation: Optional[nn.Module] = None,
94
+ ) -> None:
95
+ super().__init__()
96
+ self.in_dim = in_dim
97
+ assert self.in_dim > 0
98
+ self.out_dim = out_dim if out_dim is not None else hidden_features
99
+ self.hidden_layers = hidden_layers
100
+ self.hidden_features = hidden_features
101
+ self.mapping_network_in_dim = mapping_network_in_dim
102
+ self.mapping_network_layers = mapping_network_layers
103
+ self.mapping_network_features = mapping_network_features
104
+ self.outermost_linear = outermost_linear
105
+ self.out_activation = out_activation
106
+
107
+ self.net = nn.ModuleList()
108
+
109
+ self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
110
+
111
+ for _ in range(self.hidden_layers - 1):
112
+ self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))
113
+
114
+ self.final_layer = None
115
+ if self.outermost_linear:
116
+ self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
117
+ self.final_layer.apply(frequency_init(25))
118
+ else:
119
+ final_layer = FiLMLayer(self.hidden_features, self.out_dim)
120
+ self.net.append(final_layer)
121
+
122
+ self.mapping_network = CustomMappingNetwork(
123
+ in_features=self.mapping_network_in_dim,
124
+ map_hidden_layers=self.mapping_network_layers,
125
+ map_hidden_dim=self.mapping_network_features,
126
+ map_output_dim=(len(self.net)) * self.hidden_features * 2,
127
+ )
128
+
129
+ self.net.apply(frequency_init(25))
130
+ self.net[0].apply(first_layer_film_sine_init)
131
+
132
+ def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
133
+ """Get conditiional frequencies and phase shifts from mapping network."""
134
+ frequencies = frequencies * 15 + 30
135
+
136
+ for index, layer in enumerate(self.net):
137
+ start = index * self.hidden_features
138
+ end = (index + 1) * self.hidden_features
139
+ x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
140
+
141
+ x = self.final_layer(x) if self.final_layer is not None else x
142
+ output = self.out_activation(x) if self.out_activation is not None else x
143
+ return output
144
+
145
+ def forward(self, x, conditioning_input):
146
+ """Forward pass."""
147
+ frequencies, phase_shifts = self.mapping_network(conditioning_input)
148
+ return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)
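In a FiLM SIREN the mapping network turns a conditioning latent into per-layer frequencies and phase shifts that modulate each sine layer. A usage sketch with hypothetical sizes (3-d directions, a 36-d latent, RGB output) follows:

```python
import torch

from spar3d.models.illumination.reni.components.film_siren import FiLMSiren

net = FiLMSiren(
    in_dim=3,                   # e.g. a unit direction per sample
    hidden_layers=4,
    hidden_features=128,
    mapping_network_in_dim=36,  # hypothetical latent size
    mapping_network_layers=2,
    mapping_network_features=128,
    out_dim=3,                  # RGB
    outermost_linear=True,
    out_activation=None,
)

directions = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
latents = torch.randn(1024, 36)  # one conditioning vector per direction
rgb = net(directions, latents)
print(rgb.shape)                 # torch.Size([1024, 3])
```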
spar3d/models/illumination/reni/components/siren.py ADDED
@@ -0,0 +1,118 @@
1
+ """Siren MLP https://www.vincentsitzmann.com/siren/"""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class SineLayer(nn.Module):
11
+ """
12
+ Sine layer for the SIREN network.
13
+ """
14
+
15
+ def __init__(
16
+ self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
17
+ ):
18
+ super().__init__()
19
+ self.omega_0 = omega_0
20
+ self.is_first = is_first
21
+
22
+ self.in_features = in_features
23
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
24
+
25
+ self.init_weights()
26
+
27
+ def init_weights(self):
28
+ with torch.no_grad():
29
+ if self.is_first:
30
+ self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
31
+ else:
32
+ self.linear.weight.uniform_(
33
+ -np.sqrt(6 / self.in_features) / self.omega_0,
34
+ np.sqrt(6 / self.in_features) / self.omega_0,
35
+ )
36
+
37
+ def forward(self, x):
38
+ return torch.sin(self.omega_0 * self.linear(x))
39
+
40
+
41
+ class Siren(nn.Module):
42
+ """Siren network.
43
+
44
+ Args:
45
+ in_dim: Input layer dimension
46
+ num_layers: Number of network layers
47
+ layer_width: Width of each MLP layer
48
+ out_dim: Output layer dimension. Uses layer_width if None.
49
+ activation: intermediate layer activation function.
50
+ out_activation: output activation function.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ in_dim: int,
56
+ hidden_layers: int,
57
+ hidden_features: int,
58
+ out_dim: Optional[int] = None,
59
+ outermost_linear: bool = False,
60
+ first_omega_0: float = 30,
61
+ hidden_omega_0: float = 30,
62
+ out_activation: Optional[nn.Module] = None,
63
+ ) -> None:
64
+ super().__init__()
65
+ self.in_dim = in_dim
66
+ assert self.in_dim > 0
67
+ self.out_dim = out_dim if out_dim is not None else hidden_features
68
+ self.outermost_linear = outermost_linear
69
+ self.first_omega_0 = first_omega_0
70
+ self.hidden_omega_0 = hidden_omega_0
71
+ self.hidden_layers = hidden_layers
72
+ self.layer_width = hidden_features
73
+ self.out_activation = out_activation
74
+
75
+ self.net = []
76
+ self.net.append(
77
+ SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
78
+ )
79
+
80
+ for _ in range(hidden_layers):
81
+ self.net.append(
82
+ SineLayer(
83
+ hidden_features,
84
+ hidden_features,
85
+ is_first=False,
86
+ omega_0=hidden_omega_0,
87
+ )
88
+ )
89
+
90
+ if outermost_linear:
91
+ final_layer = nn.Linear(hidden_features, self.out_dim)
92
+
93
+ with torch.no_grad():
94
+ final_layer.weight.uniform_(
95
+ -np.sqrt(6 / hidden_features) / hidden_omega_0,
96
+ np.sqrt(6 / hidden_features) / hidden_omega_0,
97
+ )
98
+
99
+ self.net.append(final_layer)
100
+ else:
101
+ self.net.append(
102
+ SineLayer(
103
+ hidden_features,
104
+ self.out_dim,
105
+ is_first=False,
106
+ omega_0=hidden_omega_0,
107
+ )
108
+ )
109
+
110
+ if self.out_activation is not None:
111
+ self.net.append(self.out_activation)
112
+
113
+ self.net = nn.Sequential(*self.net)
114
+
115
+ def forward(self, model_input):
116
+ """Forward pass through the network"""
117
+ output = self.net(model_input)
118
+ return output
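A usage sketch for the plain Siren MLP; the input/output sizes and the Sigmoid output activation are illustrative assumptions:

```python
import torch

from spar3d.models.illumination.reni.components.siren import Siren

siren = Siren(
    in_dim=3,
    hidden_layers=4,
    hidden_features=256,
    out_dim=3,
    outermost_linear=True,
    out_activation=torch.nn.Sigmoid(),  # hypothetical choice, keeps RGB in [0, 1]
)

coords = torch.rand(2048, 3) * 2 - 1  # points in [-1, 1]^3
rgb = siren(coords)
print(rgb.shape)                      # torch.Size([2048, 3])
```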
spar3d/models/illumination/reni/components/transformer_decoder.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class MultiHeadAttention(nn.Module):
8
+ def __init__(
9
+ self,
10
+ direction_input_dim: int,
11
+ conditioning_input_dim: int,
12
+ latent_dim: int,
13
+ num_heads: int,
14
+ ):
15
+ """
16
+ Multi-Head Attention module.
17
+
18
+ Args:
19
+ direction_input_dim (int): The input dimension of the directional input.
20
+ conditioning_input_dim (int): The input dimension of the conditioning input.
21
+ latent_dim (int): The latent dimension of the module.
22
+ num_heads (int): The number of heads to use in the attention mechanism.
23
+ """
24
+ super().__init__()
25
+ assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
26
+ self.num_heads = num_heads
27
+ self.head_dim = latent_dim // num_heads
28
+ self.scale = self.head_dim**-0.5
29
+
30
+ self.query = nn.Linear(direction_input_dim, latent_dim)
31
+ self.key = nn.Linear(conditioning_input_dim, latent_dim)
32
+ self.value = nn.Linear(conditioning_input_dim, latent_dim)
33
+ self.fc_out = nn.Linear(latent_dim, latent_dim)
34
+
35
+ def forward(
36
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
37
+ ) -> torch.Tensor:
38
+ """
39
+ Forward pass of the Multi-Head Attention module.
40
+
41
+ Args:
42
+ query (torch.Tensor): The directional input tensor.
43
+ key (torch.Tensor): The conditioning input tensor for the keys.
44
+ value (torch.Tensor): The conditioning input tensor for the values.
45
+
46
+ Returns:
47
+ torch.Tensor: The output tensor of the Multi-Head Attention module.
48
+ """
49
+ batch_size = query.size(0)
50
+
51
+ Q = (
52
+ self.query(query)
53
+ .view(batch_size, -1, self.num_heads, self.head_dim)
54
+ .transpose(1, 2)
55
+ )
56
+ K = (
57
+ self.key(key)
58
+ .view(batch_size, -1, self.num_heads, self.head_dim)
59
+ .transpose(1, 2)
60
+ )
61
+ V = (
62
+ self.value(value)
63
+ .view(batch_size, -1, self.num_heads, self.head_dim)
64
+ .transpose(1, 2)
65
+ )
66
+
67
+ attention = (
68
+ torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
69
+ )
70
+ attention = torch.softmax(attention, dim=-1)
71
+
72
+ out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
73
+ out = (
74
+ out.transpose(1, 2)
75
+ .contiguous()
76
+ .view(batch_size, -1, self.num_heads * self.head_dim)
77
+ )
78
+
79
+ out = self.fc_out(out).squeeze(1)
80
+ return out
81
+
82
+
83
+ class AttentionLayer(nn.Module):
84
+ def __init__(
85
+ self,
86
+ direction_input_dim: int,
87
+ conditioning_input_dim: int,
88
+ latent_dim: int,
89
+ num_heads: int,
90
+ ):
91
+ """
92
+ Attention Layer module.
93
+
94
+ Args:
95
+ direction_input_dim (int): The input dimension of the directional input.
96
+ conditioning_input_dim (int): The input dimension of the conditioning input.
97
+ latent_dim (int): The latent dimension of the module.
98
+ num_heads (int): The number of heads to use in the attention mechanism.
99
+ """
100
+ super().__init__()
101
+ self.mha = MultiHeadAttention(
102
+ direction_input_dim, conditioning_input_dim, latent_dim, num_heads
103
+ )
104
+ self.norm1 = nn.LayerNorm(latent_dim)
105
+ self.norm2 = nn.LayerNorm(latent_dim)
106
+ self.fc = nn.Sequential(
107
+ nn.Linear(latent_dim, latent_dim),
108
+ nn.ReLU(),
109
+ nn.Linear(latent_dim, latent_dim),
110
+ )
111
+
112
+ def forward(
113
+ self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
114
+ ) -> torch.Tensor:
115
+ """
116
+ Forward pass of the Attention Layer module.
117
+
118
+ Args:
119
+ directional_input (torch.Tensor): The directional input tensor.
120
+ conditioning_input (torch.Tensor): The conditioning input tensor.
121
+
122
+ Returns:
123
+ torch.Tensor: The output tensor of the Attention Layer module.
124
+ """
125
+ attn_output = self.mha(
126
+ directional_input, conditioning_input, conditioning_input
127
+ )
128
+ out1 = self.norm1(attn_output + directional_input)
129
+ fc_output = self.fc(out1)
130
+ out2 = self.norm2(fc_output + out1)
131
+ return out2
132
+
133
+
134
+ class Decoder(nn.Module):
135
+ def __init__(
136
+ self,
137
+ in_dim: int,
138
+ conditioning_input_dim: int,
139
+ hidden_features: int,
140
+ num_heads: int,
141
+ num_layers: int,
142
+ out_activation: Optional[nn.Module],
143
+ ):
144
+ """
145
+ Decoder module.
146
+
147
+ Args:
148
+ in_dim (int): The input dimension of the module.
149
+ conditioning_input_dim (int): The input dimension of the conditioning input.
150
+ hidden_features (int): The number of hidden features in the module.
151
+ num_heads (int): The number of heads to use in the attention mechanism.
152
+ num_layers (int): The number of layers in the module.
153
+ out_activation (nn.Module): The activation function to use on the output tensor.
154
+ """
155
+ super().__init__()
156
+ self.residual_projection = nn.Linear(
157
+ in_dim, hidden_features
158
+ ) # projection for residual connection
159
+ self.layers = nn.ModuleList(
160
+ [
161
+ AttentionLayer(
162
+ hidden_features, conditioning_input_dim, hidden_features, num_heads
163
+ )
164
+ for i in range(num_layers)
165
+ ]
166
+ )
167
+ self.fc = nn.Linear(hidden_features, 3) # 3 for RGB
168
+ self.out_activation = out_activation
169
+
170
+ def forward(
171
+ self, x: torch.Tensor, conditioning_input: torch.Tensor
172
+ ) -> torch.Tensor:
173
+ """
174
+ Forward pass of the Decoder module.
175
+
176
+ Args:
177
+ x (torch.Tensor): The input tensor.
178
+ conditioning_input (torch.Tensor): The conditioning input tensor.
179
+
180
+ Returns:
181
+ torch.Tensor: The output tensor of the Decoder module.
182
+ """
183
+ x = self.residual_projection(x)
184
+ for layer in self.layers:
185
+ x = layer(x, conditioning_input)
186
+ x = self.fc(x)
187
+ if self.out_activation is not None:
188
+ x = self.out_activation(x)
189
+ return x
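The Decoder cross-attends query features (here, projected directions) to a set of conditioning tokens and maps the result to RGB. A sketch with assumed shapes (100 query directions, 16 conditioning tokens of width 64) follows:

```python
import torch

from spar3d.models.illumination.reni.components.transformer_decoder import Decoder

decoder = Decoder(
    in_dim=3,                   # one query direction per sample point
    conditioning_input_dim=64,  # hypothetical latent token width
    hidden_features=128,
    num_heads=4,
    num_layers=2,
    out_activation=None,
)

directions = torch.nn.functional.normalize(torch.randn(8, 100, 3), dim=-1)
latent_tokens = torch.randn(8, 16, 64)  # 16 conditioning tokens per sample
rgb = decoder(directions, latent_tokens)
print(rgb.shape)                        # torch.Size([8, 100, 3])
```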
spar3d/models/illumination/reni/components/vn_layers.py ADDED
@@ -0,0 +1,548 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Phil Wang
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ """All code taken from https://github.com/lucidrains/VN-transformer"""
24
+
25
+ from collections import namedtuple
26
+ from functools import wraps
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from einops import rearrange, reduce
31
+ from einops.layers.torch import Rearrange
32
+ from packaging import version
33
+ from torch import einsum, nn
34
+
35
+ # constants
36
+
37
+ FlashAttentionConfig = namedtuple(
38
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
39
+ )
40
+
41
+ # helpers
42
+
43
+
44
+ def exists(val):
45
+ return val is not None
46
+
47
+
48
+ def once(fn):
49
+ called = False
50
+
51
+ @wraps(fn)
52
+ def inner(x):
53
+ nonlocal called
54
+ if called:
55
+ return
56
+ called = True
57
+ return fn(x)
58
+
59
+ return inner
60
+
61
+
62
+ print_once = once(print)
63
+
64
+ # main class
65
+
66
+
67
+ class Attend(nn.Module):
68
+ def __init__(self, dropout=0.0, flash=False, l2_dist=False):
69
+ super().__init__()
70
+ assert not (
71
+ flash and l2_dist
72
+ ), "flash attention is not compatible with l2 distance"
73
+ self.l2_dist = l2_dist
74
+
75
+ self.dropout = dropout
76
+ self.attn_dropout = nn.Dropout(dropout)
77
+
78
+ self.flash = flash
79
+ assert not (
80
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
81
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
82
+
83
+ # determine efficient attention configs for cuda and cpu
84
+
85
+ self.cpu_config = FlashAttentionConfig(True, True, True)
86
+ self.cuda_config = None
87
+
88
+ if not torch.cuda.is_available() or not flash:
89
+ return
90
+
91
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
92
+
93
+ if device_properties.major == 8 and device_properties.minor == 0:
94
+ print_once(
95
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
96
+ )
97
+ self.cuda_config = FlashAttentionConfig(True, False, False)
98
+ else:
99
+ print_once(
100
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
101
+ )
102
+ self.cuda_config = FlashAttentionConfig(False, True, True)
103
+
104
+ def flash_attn(self, q, k, v, mask=None):
105
+ _, heads, q_len, _, _, is_cuda = (
106
+ *q.shape,
107
+ k.shape[-2],
108
+ q.is_cuda,
109
+ )
110
+
111
+ # Check if mask exists and expand to compatible shape
112
+ # The mask is B L, so it would have to be expanded to B H N L
113
+
114
+ if exists(mask):
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ )
131
+
132
+ return out
133
+
134
+ def forward(self, q, k, v, mask=None):
135
+ """
136
+ einstein notation
137
+ b - batch
138
+ h - heads
139
+ n, i, j - sequence length (base sequence length, source, target)
140
+ d - feature dimension
141
+ """
142
+ scale = q.shape[-1] ** -0.5
143
+
144
+ if exists(mask) and mask.ndim != 4:
145
+ mask = rearrange(mask, "b j -> b 1 1 j")
146
+
147
+ if self.flash:
148
+ return self.flash_attn(q, k, v, mask=mask)
149
+
150
+ # similarity
151
+
152
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
153
+
154
+ # l2 distance
155
+
156
+ if self.l2_dist:
157
+ # -cdist squared == (-q^2 + 2qk - k^2)
158
+ # so simply work off the qk above
159
+ q_squared = reduce(q**2, "b h i d -> b h i 1", "sum")
160
+ k_squared = reduce(k**2, "b h j d -> b h 1 j", "sum")
161
+ sim = sim * 2 - q_squared - k_squared
162
+
163
+ # key padding mask
164
+
165
+ if exists(mask):
166
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
167
+
168
+ # attention
169
+
170
+ attn = sim.softmax(dim=-1)
171
+ attn = self.attn_dropout(attn)
172
+
173
+ # aggregate values
174
+
175
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
176
+
177
+ return out
178
+
179
+
180
+ # helper
181
+
182
+
183
+ def exists(val): # noqa: F811
184
+ return val is not None
185
+
186
+
187
+ def default(val, d):
188
+ return val if exists(val) else d
189
+
190
+
191
+ def inner_dot_product(x, y, *, dim=-1, keepdim=True):
192
+ return (x * y).sum(dim=dim, keepdim=keepdim)
193
+
194
+
195
+ # layernorm
196
+
197
+
198
+ class LayerNorm(nn.Module):
199
+ def __init__(self, dim):
200
+ super().__init__()
201
+ self.gamma = nn.Parameter(torch.ones(dim))
202
+ self.register_buffer("beta", torch.zeros(dim))
203
+
204
+ def forward(self, x):
205
+ return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
206
+
207
+
208
+ # equivariant modules
209
+
210
+
211
+ class VNLinear(nn.Module):
212
+ def __init__(self, dim_in, dim_out, bias_epsilon=0.0):
213
+ super().__init__()
214
+ self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
215
+
216
+ self.bias = None
217
+ self.bias_epsilon = bias_epsilon
218
+
219
+ # in this paper, they propose going for quasi-equivariance with a small bias, controllable with epsilon, which they claim lead to better stability and results
220
+
221
+ if bias_epsilon > 0.0:
222
+ self.bias = nn.Parameter(torch.randn(dim_out))
223
+
224
+ def forward(self, x):
225
+ out = einsum("... i c, o i -> ... o c", x, self.weight)
226
+
227
+ if exists(self.bias):
228
+ bias = F.normalize(self.bias, dim=-1) * self.bias_epsilon
229
+ out = out + rearrange(bias, "... -> ... 1")
230
+
231
+ return out
232
+
233
+
234
+ class VNReLU(nn.Module):
235
+ def __init__(self, dim, eps=1e-6):
236
+ super().__init__()
237
+ self.eps = eps
238
+ self.W = nn.Parameter(torch.randn(dim, dim))
239
+ self.U = nn.Parameter(torch.randn(dim, dim))
240
+
241
+ def forward(self, x):
242
+ q = einsum("... i c, o i -> ... o c", x, self.W)
243
+ k = einsum("... i c, o i -> ... o c", x, self.U)
244
+
245
+ qk = inner_dot_product(q, k)
246
+
247
+ k_norm = k.norm(dim=-1, keepdim=True).clamp(min=self.eps)
248
+ q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k
249
+
250
+ out = torch.where(qk >= 0.0, q, q_projected_on_k)
251
+
252
+ return out
253
+
254
+
255
+ class VNAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ dim,
259
+ dim_head=64,
260
+ heads=8,
261
+ dim_coor=3,
262
+ bias_epsilon=0.0,
263
+ l2_dist_attn=False,
264
+ flash=False,
265
+ num_latents=None, # setting this would enable perceiver-like cross attention from latents to sequence, with the latents derived from VNWeightedPool
266
+ ):
267
+ super().__init__()
268
+ assert not (
269
+ l2_dist_attn and flash
270
+ ), "l2 distance attention is not compatible with flash attention"
271
+
272
+ self.scale = (dim_coor * dim_head) ** -0.5
273
+ dim_inner = dim_head * heads
274
+ self.heads = heads
275
+
276
+ self.to_q_input = None
277
+ if exists(num_latents):
278
+ self.to_q_input = VNWeightedPool(
279
+ dim, num_pooled_tokens=num_latents, squeeze_out_pooled_dim=False
280
+ )
281
+
282
+ self.to_q = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
283
+ self.to_k = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
284
+ self.to_v = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
285
+ self.to_out = VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon)
286
+
287
+ if l2_dist_attn and not exists(num_latents):
288
+ # tied queries and keys for l2 distance attention, and not perceiver-like attention
289
+ self.to_k = self.to_q
290
+
291
+ self.attend = Attend(flash=flash, l2_dist=l2_dist_attn)
292
+
293
+ def forward(self, x, mask=None):
294
+ """
295
+ einstein notation
296
+ b - batch
297
+ n - sequence
298
+ h - heads
299
+ d - feature dimension (channels)
300
+ c - coordinate dimension (3 for 3d space)
301
+ i - source sequence dimension
302
+ j - target sequence dimension
303
+ """
304
+
305
+ c = x.shape[-1]
306
+
307
+ if exists(self.to_q_input):
308
+ q_input = self.to_q_input(x, mask=mask)
309
+ else:
310
+ q_input = x
311
+
312
+ q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
313
+ q, k, v = map(
314
+ lambda t: rearrange(t, "b n (h d) c -> b h n (d c)", h=self.heads),
315
+ (q, k, v),
316
+ )
317
+
318
+ out = self.attend(q, k, v, mask=mask)
319
+
320
+ out = rearrange(out, "b h n (d c) -> b n (h d) c", c=c)
321
+ return self.to_out(out)
322
+
323
+
324
+ def VNFeedForward(dim, mult=4, bias_epsilon=0.0):
325
+ dim_inner = int(dim * mult)
326
+ return nn.Sequential(
327
+ VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon),
328
+ VNReLU(dim_inner),
329
+ VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon),
330
+ )
331
+
332
+
333
+ class VNLayerNorm(nn.Module):
334
+ def __init__(self, dim, eps=1e-6):
335
+ super().__init__()
336
+ self.eps = eps
337
+ self.ln = LayerNorm(dim)
338
+
339
+ def forward(self, x):
340
+ norms = x.norm(dim=-1)
341
+ x = x / rearrange(norms.clamp(min=self.eps), "... -> ... 1")
342
+ ln_out = self.ln(norms)
343
+ return x * rearrange(ln_out, "... -> ... 1")
344
+
345
+
346
+ class VNWeightedPool(nn.Module):
347
+ def __init__(
348
+ self, dim, dim_out=None, num_pooled_tokens=1, squeeze_out_pooled_dim=True
349
+ ):
350
+ super().__init__()
351
+ dim_out = default(dim_out, dim)
352
+ self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))
353
+ self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim
354
+
355
+ def forward(self, x, mask=None):
356
+ if exists(mask):
357
+ mask = rearrange(mask, "b n -> b n 1 1")
358
+ x = x.masked_fill(~mask, 0.0)
359
+ numer = reduce(x, "b n d c -> b d c", "sum")
360
+ denom = mask.sum(dim=1)
361
+ mean_pooled = numer / denom.clamp(min=1e-6)
362
+ else:
363
+ mean_pooled = reduce(x, "b n d c -> b d c", "mean")
364
+
365
+ out = einsum("b d c, m d e -> b m e c", mean_pooled, self.weight)
366
+
367
+ if not self.squeeze_out_pooled_dim:
368
+ return out
369
+
370
+ out = rearrange(out, "b 1 d c -> b d c")
371
+ return out
372
+
373
+
374
+ # equivariant VN transformer encoder
375
+
376
+
377
+ class VNTransformerEncoder(nn.Module):
378
+ def __init__(
379
+ self,
380
+ dim,
381
+ *,
382
+ depth,
383
+ dim_head=64,
384
+ heads=8,
385
+ dim_coor=3,
386
+ ff_mult=4,
387
+ final_norm=False,
388
+ bias_epsilon=0.0,
389
+ l2_dist_attn=False,
390
+ flash_attn=False,
391
+ ):
392
+ super().__init__()
393
+ self.dim = dim
394
+ self.dim_coor = dim_coor
395
+
396
+ self.layers = nn.ModuleList([])
397
+
398
+ for _ in range(depth):
399
+ self.layers.append(
400
+ nn.ModuleList(
401
+ [
402
+ VNAttention(
403
+ dim=dim,
404
+ dim_head=dim_head,
405
+ heads=heads,
406
+ bias_epsilon=bias_epsilon,
407
+ l2_dist_attn=l2_dist_attn,
408
+ flash=flash_attn,
409
+ ),
410
+ VNLayerNorm(dim),
411
+ VNFeedForward(dim=dim, mult=ff_mult, bias_epsilon=bias_epsilon),
412
+ VNLayerNorm(dim),
413
+ ]
414
+ )
415
+ )
416
+
417
+ self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()
418
+
419
+ def forward(self, x, mask=None):
420
+ *_, d, c = x.shape
421
+
422
+ assert (
423
+ x.ndim == 4 and d == self.dim and c == self.dim_coor
424
+ ), "input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))"
425
+
426
+ for attn, attn_post_ln, ff, ff_post_ln in self.layers:
427
+ x = attn_post_ln(attn(x, mask=mask)) + x
428
+ x = ff_post_ln(ff(x)) + x
429
+
430
+ return self.norm(x)
431
+
432
+
433
+ # invariant layers
434
+
435
+
436
+ class VNInvariant(nn.Module):
437
+ def __init__(
438
+ self,
439
+ dim,
440
+ dim_coor=3,
441
+ ):
442
+ super().__init__()
443
+ self.mlp = nn.Sequential(
444
+ VNLinear(dim, dim_coor), VNReLU(dim_coor), Rearrange("... d e -> ... e d")
445
+ )
446
+
447
+ def forward(self, x):
448
+ return einsum("b n d i, b n i o -> b n o", x, self.mlp(x))
449
+
450
+
451
+ # main class
452
+
453
+
454
+ class VNTransformer(nn.Module):
455
+ def __init__(
456
+ self,
457
+ *,
458
+ dim,
459
+ depth,
460
+ num_tokens=None,
461
+ dim_feat=None,
462
+ dim_head=64,
463
+ heads=8,
464
+ dim_coor=3,
465
+ reduce_dim_out=True,
466
+ bias_epsilon=0.0,
467
+ l2_dist_attn=False,
468
+ flash_attn=False,
469
+ translation_equivariance=False,
470
+ translation_invariant=False,
471
+ ):
472
+ super().__init__()
473
+ self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
474
+
475
+ dim_feat = default(dim_feat, 0)
476
+ self.dim_feat = dim_feat
477
+ self.dim_coor_total = dim_coor + dim_feat
478
+
479
+ assert (int(translation_equivariance) + int(translation_invariant)) <= 1
480
+ self.translation_equivariance = translation_equivariance
481
+ self.translation_invariant = translation_invariant
482
+
483
+ self.vn_proj_in = nn.Sequential(
484
+ Rearrange("... c -> ... 1 c"), VNLinear(1, dim, bias_epsilon=bias_epsilon)
485
+ )
486
+
487
+ self.encoder = VNTransformerEncoder(
488
+ dim=dim,
489
+ depth=depth,
490
+ dim_head=dim_head,
491
+ heads=heads,
492
+ bias_epsilon=bias_epsilon,
493
+ dim_coor=self.dim_coor_total,
494
+ l2_dist_attn=l2_dist_attn,
495
+ flash_attn=flash_attn,
496
+ )
497
+
498
+ if reduce_dim_out:
499
+ self.vn_proj_out = nn.Sequential(
500
+ VNLayerNorm(dim),
501
+ VNLinear(dim, 1, bias_epsilon=bias_epsilon),
502
+ Rearrange("... 1 c -> ... c"),
503
+ )
504
+ else:
505
+ self.vn_proj_out = nn.Identity()
506
+
507
+ def forward(
508
+ self, coors, *, feats=None, mask=None, return_concatted_coors_and_feats=False
509
+ ):
510
+ if self.translation_equivariance or self.translation_invariant:
511
+ coors_mean = reduce(coors, "... c -> c", "mean")
512
+ coors = coors - coors_mean
513
+
514
+ x = coors # [batch, num_points, 3]
515
+
516
+ if exists(feats):
517
+ if feats.dtype == torch.long:
518
+ assert exists(
519
+ self.token_emb
520
+ ), "num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices"
521
+ feats = self.token_emb(feats)
522
+
523
+ assert (
524
+ feats.shape[-1] == self.dim_feat
525
+ ), f"dim_feat should be set to {feats.shape[-1]}"
526
+ x = torch.cat((x, feats), dim=-1) # [batch, num_points, 3 + dim_feat]
527
+
528
+ assert x.shape[-1] == self.dim_coor_total
529
+
530
+ x = self.vn_proj_in(x) # [batch, num_points, hidden_dim, 3 + dim_feat]
531
+ x = self.encoder(x, mask=mask) # [batch, num_points, hidden_dim, 3 + dim_feat]
532
+ x = self.vn_proj_out(x) # [batch, num_points, 3 + dim_feat]
533
+
534
+ coors_out, feats_out = (
535
+ x[..., :3],
536
+ x[..., 3:],
537
+ ) # [batch, num_points, 3], [batch, num_points, dim_feat]
538
+
539
+ if self.translation_equivariance:
540
+ coors_out = coors_out + coors_mean
541
+
542
+ if not exists(feats):
543
+ return coors_out
544
+
545
+ if return_concatted_coors_and_feats:
546
+ return torch.cat((coors_out, feats_out), dim=-1)
547
+
548
+ return coors_out, feats_out
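A minimal usage sketch for the `VNTransformer` above (dimensions and depth are illustrative assumptions, not values used elsewhere in this upload; the module is assumed importable from the repo root):

```python
import torch

from spar3d.models.illumination.reni.components.vn_layers import VNTransformer

model = VNTransformer(dim=32, depth=2, dim_feat=4)
coors = torch.randn(1, 128, 3)   # point coordinates [batch, num_points, 3]
feats = torch.randn(1, 128, 4)   # per-point features, matching dim_feat
coors_out, feats_out = model(coors, feats=feats)
print(coors_out.shape, feats_out.shape)  # [1, 128, 3] and [1, 128, 4]
```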
spar3d/models/illumination/reni/env_map.py ADDED
@@ -0,0 +1,93 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from spar3d.models.utils import BaseModule
9
+
10
+ from .field import RENIField
11
+
12
+
13
+ def _direction_from_coordinate(
14
+ coordinate: Float[Tensor, "*B 2"],
15
+ ) -> Float[Tensor, "*B 3"]:
16
+ # OpenGL Convention
17
+ # +X Right
18
+ # +Y Up
19
+ # +Z Backward
20
+
21
+ u, v = coordinate.unbind(-1)
22
+ theta = (2 * torch.pi * u) - torch.pi
23
+ phi = torch.pi * v
24
+
25
+ dir = torch.stack(
26
+ [
27
+ theta.sin() * phi.sin(),
28
+ phi.cos(),
29
+ -1 * theta.cos() * phi.sin(),
30
+ ],
31
+ -1,
32
+ )
33
+ return dir
34
+
35
+
36
+ def _get_sample_coordinates(
37
+ resolution: List[int], device: Optional[torch.device] = None
38
+ ) -> Float[Tensor, "H W 2"]:
39
+ return torch.stack(
40
+ torch.meshgrid(
41
+ (torch.arange(resolution[1], device=device) + 0.5) / resolution[1],
42
+ (torch.arange(resolution[0], device=device) + 0.5) / resolution[0],
43
+ indexing="xy",
44
+ ),
45
+ -1,
46
+ )
47
+
48
+
49
+ class RENIEnvMap(BaseModule):
50
+ @dataclass
51
+ class Config(BaseModule.Config):
52
+ reni_config: dict = field(default_factory=dict)
53
+ resolution: int = 128
54
+
55
+ cfg: Config
56
+
57
+ def configure(self):
58
+ self.field = RENIField(self.cfg.reni_config)
59
+ resolution = (self.cfg.resolution, self.cfg.resolution * 2)
60
+ sample_directions = _direction_from_coordinate(
61
+ _get_sample_coordinates(resolution)
62
+ )
63
+ self.img_shape = sample_directions.shape[:-1]
64
+
65
+ sample_directions_flat = sample_directions.view(-1, 3)
66
+ # These directions are y-up, but RENI expects z-up, so rotate 90 degrees about the x axis
67
+ sample_directions_flat = torch.stack(
68
+ [
69
+ sample_directions_flat[:, 0],
70
+ -sample_directions_flat[:, 2],
71
+ sample_directions_flat[:, 1],
72
+ ],
73
+ -1,
74
+ )
75
+ self.sample_directions = torch.nn.Parameter(
76
+ sample_directions_flat, requires_grad=False
77
+ )
78
+
79
+ def forward(
80
+ self,
81
+ latent_codes: Float[Tensor, "B latent_dim 3"],
82
+ rotation: Optional[Float[Tensor, "B 3 3"]] = None,
83
+ scale: Optional[Float[Tensor, "B"]] = None,
84
+ ) -> Dict[str, Tensor]:
85
+ return {
86
+ k: v.view(latent_codes.shape[0], *self.img_shape, -1)
87
+ for k, v in self.field(
88
+ self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1),
89
+ latent_codes,
90
+ rotation=rotation,
91
+ scale=scale,
92
+ ).items()
93
+ }
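A small sanity check for the sampling helpers above (the resolution is arbitrary): the returned vectors are unit length in the OpenGL y-up convention, before `configure` swaps them to z-up for RENI.

```python
import torch

coords = _get_sample_coordinates([4, 8])      # (u, v) grid, shape [4, 8, 2]
dirs = _direction_from_coordinate(coords)     # unit directions, shape [4, 8, 3]
print(dirs.shape)
print(torch.allclose(dirs.norm(dim=-1), torch.ones(4, 8)))  # True
```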
spar3d/models/illumination/reni/field.py ADDED
@@ -0,0 +1,736 @@
1
+ # Copyright 2023 The University of York. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Mark Boss
16
+
17
+ """RENI field"""
18
+
19
+ import contextlib
20
+ from dataclasses import dataclass
21
+ from typing import Dict, Literal, Optional
22
+
23
+ import torch
24
+ from einops.layers.torch import Rearrange
25
+ from jaxtyping import Float
26
+ from torch import Tensor, nn
27
+
28
+ from spar3d.models.network import get_activation_module, trunc_exp
29
+ from spar3d.models.utils import BaseModule
30
+
31
+ from .components.film_siren import FiLMSiren
32
+ from .components.siren import Siren
33
+ from .components.transformer_decoder import Decoder
34
+ from .components.vn_layers import VNInvariant, VNLinear
35
+
36
+ # from nerfstudio.cameras.rays import RaySamples
37
+
38
+
39
+ def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
40
+ """Computes the expected value of sin(y) where y ~ N(x_means, x_vars)
41
+
42
+ Args:
43
+ x_means: Mean values.
44
+ x_vars: Variance of values.
45
+
46
+ Returns:
47
+ torch.Tensor: The expected value of sin.
48
+ """
49
+
50
+ return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
51
+
52
+
53
+ class NeRFEncoding(torch.nn.Module):
54
+ """Multi-scale sinousoidal encodings. Support ``integrated positional encodings`` if covariances are provided.
55
+ Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.
56
+
57
+ Args:
58
+ in_dim: Input dimension of tensor
59
+ num_frequencies: Number of encoded frequencies per axis
60
+ min_freq_exp: Minimum frequency exponent
61
+ max_freq_exp: Maximum frequency exponent
62
+ include_input: Append the input coordinate to the encoding
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ in_dim: int,
68
+ num_frequencies: int,
69
+ min_freq_exp: float,
70
+ max_freq_exp: float,
71
+ include_input: bool = False,
72
+ off_axis: bool = False,
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ self.in_dim = in_dim
77
+ self.num_frequencies = num_frequencies
78
+ self.min_freq = min_freq_exp
79
+ self.max_freq = max_freq_exp
80
+ self.include_input = include_input
81
+
82
+ self.off_axis = off_axis
83
+
84
+ self.P = torch.tensor(
85
+ [
86
+ [0.8506508, 0, 0.5257311],
87
+ [0.809017, 0.5, 0.309017],
88
+ [0.5257311, 0.8506508, 0],
89
+ [1, 0, 0],
90
+ [0.809017, 0.5, -0.309017],
91
+ [0.8506508, 0, -0.5257311],
92
+ [0.309017, 0.809017, -0.5],
93
+ [0, 0.5257311, -0.8506508],
94
+ [0.5, 0.309017, -0.809017],
95
+ [0, 1, 0],
96
+ [-0.5257311, 0.8506508, 0],
97
+ [-0.309017, 0.809017, -0.5],
98
+ [0, 0.5257311, 0.8506508],
99
+ [-0.309017, 0.809017, 0.5],
100
+ [0.309017, 0.809017, 0.5],
101
+ [0.5, 0.309017, 0.809017],
102
+ [0.5, -0.309017, 0.809017],
103
+ [0, 0, 1],
104
+ [-0.5, 0.309017, 0.809017],
105
+ [-0.809017, 0.5, 0.309017],
106
+ [-0.809017, 0.5, -0.309017],
107
+ ]
108
+ ).T
109
+
110
+ def get_out_dim(self) -> int:
111
+ if self.in_dim is None:
112
+ raise ValueError("Input dimension has not been set")
113
+ out_dim = self.in_dim * self.num_frequencies * 2
114
+
115
+ if self.off_axis:
116
+ out_dim = self.P.shape[1] * self.num_frequencies * 2
117
+
118
+ if self.include_input:
119
+ out_dim += self.in_dim
120
+ return out_dim
121
+
122
+ def forward(
123
+ self,
124
+ in_tensor: Float[Tensor, "*b input_dim"],
125
+ covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
126
+ ) -> Float[Tensor, "*b output_dim"]:
127
+ """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
128
+ in mip-NeRF.
129
+
130
+ Args:
131
+ in_tensor: For best performance, the input tensor should be between 0 and 1.
132
+ covs: Covariances of input points.
133
+ Returns:
134
+ Output values will be between -1 and 1
135
+ """
136
+ # TODO: check the input scaling here; the rescaling below is left commented out for now
137
+ # in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
138
+ freqs = 2 ** torch.linspace(
139
+ self.min_freq, self.max_freq, self.num_frequencies
140
+ ).to(in_tensor.device)
141
+ # freqs = 2 ** (
142
+ # torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq
143
+ # ).to(in_tensor.device)
144
+ # freqs = 2 ** (
145
+ # torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq
146
+ # )
147
+
148
+ if self.off_axis:
149
+ scaled_inputs = (
150
+ torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs
151
+ )
152
+ else:
153
+ scaled_inputs = (
154
+ in_tensor[..., None] * freqs
155
+ ) # [..., "input_dim", "num_scales"]
156
+ scaled_inputs = scaled_inputs.view(
157
+ *scaled_inputs.shape[:-2], -1
158
+ ) # [..., "input_dim" * "num_scales"]
159
+
160
+ if covs is None:
161
+ encoded_inputs = torch.sin(
162
+ torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
163
+ )
164
+ else:
165
+ input_var = (
166
+ torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None]
167
+ * freqs[None, :] ** 2
168
+ )
169
+ input_var = input_var.reshape((*input_var.shape[:-2], -1))
170
+ encoded_inputs = expected_sin(
171
+ torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1),
172
+ torch.cat(2 * [input_var], dim=-1),
173
+ )
174
+
175
+ if self.include_input:
176
+ encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
177
+ return encoded_inputs
178
+
179
+
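A quick shape check for `NeRFEncoding` (the arguments mirror the values used by `create_nerf_encoding` in `setup_network` further below; the input values are arbitrary):

```python
import torch

enc = NeRFEncoding(in_dim=3, num_frequencies=2, min_freq_exp=0.0, max_freq_exp=2.0, include_input=True)
x = torch.rand(4, 3)                   # inputs roughly in [0, 1]
y = enc(x)                             # sin/cos features at two frequencies, plus the raw input
print(y.shape[-1], enc.get_out_dim())  # 15 15  (3 * 2 * 2 + 3)
```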
180
+ class RENIField(BaseModule):
181
+ @dataclass
182
+ class Config(BaseModule.Config):
183
+ """Configuration for model instantiation"""
184
+
185
+ fixed_decoder: bool = False
186
+ """Whether to fix the decoder weights"""
187
+ equivariance: str = "SO2"
188
+ """Type of equivariance to use: None, SO2, SO3"""
189
+ axis_of_invariance: str = "y"
190
+ """Which axis should SO2 equivariance be invariant to: x, y, z"""
191
+ invariant_function: str = "GramMatrix"
192
+ """Type of invariant function to use: GramMatrix, VN"""
193
+ conditioning: str = "Concat"
194
+ """Type of conditioning to use: FiLM, Concat, Attention"""
195
+ positional_encoding: str = "NeRF"
196
+ """Type of positional encoding to use. Currently only NeRF is supported"""
197
+ encoded_input: str = "Directions"
198
+ """Type of input to encode: None, Directions, Conditioning, Both"""
199
+ latent_dim: int = 36
200
+ """Dimensionality of latent code, N for a latent code size of (N x 3)"""
201
+ hidden_layers: int = 3
202
+ """Number of hidden layers"""
203
+ hidden_features: int = 128
204
+ """Number of hidden features"""
205
+ mapping_layers: int = 3
206
+ """Number of mapping layers"""
207
+ mapping_features: int = 128
208
+ """Number of mapping features"""
209
+ num_attention_heads: int = 8
210
+ """Number of attention heads"""
211
+ num_attention_layers: int = 3
212
+ """Number of attention layers"""
213
+ out_features: int = 3 # RGB
214
+ """Number of output features"""
215
+ last_layer_linear: bool = False
216
+ """Whether to use a linear layer as the last layer"""
217
+ output_activation: str = "exp"
218
+ """Activation function for output layer: sigmoid, tanh, relu, exp, None"""
219
+ first_omega_0: float = 30.0
220
+ """Omega_0 for first layer"""
221
+ hidden_omega_0: float = 30.0
222
+ """Omega_0 for hidden layers"""
225
+ old_implementation: bool = False
226
+ """Whether to match implementation of old RENI, when using old checkpoints"""
227
+
228
+ cfg: Config
229
+
230
+ def configure(self):
231
+ self.equivariance = self.cfg.equivariance
232
+ self.conditioning = self.cfg.conditioning
233
+ self.latent_dim = self.cfg.latent_dim
234
+ self.hidden_layers = self.cfg.hidden_layers
235
+ self.hidden_features = self.cfg.hidden_features
236
+ self.mapping_layers = self.cfg.mapping_layers
237
+ self.mapping_features = self.cfg.mapping_features
238
+ self.out_features = self.cfg.out_features
239
+ self.last_layer_linear = self.cfg.last_layer_linear
240
+ self.output_activation = self.cfg.output_activation
241
+ self.first_omega_0 = self.cfg.first_omega_0
242
+ self.hidden_omega_0 = self.cfg.hidden_omega_0
243
+ self.old_implementation = self.cfg.old_implementation
244
+ self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)
245
+
246
+ self.fixed_decoder = self.cfg.fixed_decoder
247
+ if self.cfg.invariant_function == "GramMatrix":
248
+ self.invariant_function = self.gram_matrix_invariance
249
+ else:
250
+ self.vn_proj_in = nn.Sequential(
251
+ Rearrange("... c -> ... 1 c"),
252
+ VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
253
+ )
254
+ dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
255
+ self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
256
+ self.invariant_function = self.vn_invariance
257
+
258
+ self.network = self.setup_network()
259
+
260
+ if self.fixed_decoder:
261
+ for param in self.network.parameters():
262
+ param.requires_grad = False
263
+
264
+ if self.cfg.invariant_function == "VN":
265
+ for param in self.vn_proj_in.parameters():
266
+ param.requires_grad = False
267
+ for param in self.vn_invar.parameters():
268
+ param.requires_grad = False
269
+
270
+ @contextlib.contextmanager
271
+ def hold_decoder_fixed(self):
272
+ """Context manager to fix the decoder weights
273
+
274
+ Example usage:
275
+ ```
276
+ with instance_of_RENIField.hold_decoder_fixed():
277
+ # do stuff
278
+ ```
279
+ """
280
+ prev_state_network = {
281
+ name: p.requires_grad for name, p in self.network.named_parameters()
282
+ }
283
+ for param in self.network.parameters():
284
+ param.requires_grad = False
285
+ if self.cfg.invariant_function == "VN":
286
+ prev_state_proj_in = {
287
+ k: p.requires_grad for k, p in self.vn_proj_in.named_parameters()
288
+ }
289
+ prev_state_invar = {
290
+ k: p.requires_grad for k, p in self.vn_invar.named_parameters()
291
+ }
292
+ for param in self.vn_proj_in.parameters():
293
+ param.requires_grad = False
294
+ for param in self.vn_invar.parameters():
295
+ param.requires_grad = False
296
+
297
+ prev_decoder_state = self.fixed_decoder
298
+ self.fixed_decoder = True
299
+ try:
300
+ yield
301
+ finally:
302
+ # Restore the previous requires_grad state
303
+ for name, param in self.network.named_parameters():
304
+ param.requires_grad = prev_state_network[name]
305
+ if self.cfg.invariant_function == "VN":
306
+ for name, param in self.vn_proj_in.named_parameters():
307
+ param.requires_grad_(prev_state_proj_in[name])
308
+ for name, param in self.vn_invar.named_parameters():
309
+ param.requires_grad_(prev_state_invar[name])
310
+ self.fixed_decoder = prev_decoder_state
311
+
312
+ def vn_invariance(
313
+ self,
314
+ Z: Float[Tensor, "B latent_dim 3"],
315
+ D: Float[Tensor, "B num_rays 3"],
316
+ equivariance: Literal["None", "SO2", "SO3"] = "SO2",
317
+ axis_of_invariance: int = 1,
318
+ ):
319
+ """Generates a batched invariant representation from latent code Z and direction coordinates D.
320
+
321
+ Args:
322
+ Z: [B, latent_dim, 3] - Latent code.
323
+ D: [B num_rays, 3] - Direction coordinates.
324
+ equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
325
+ axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
326
+
327
+ Returns:
328
+ Tuple[Tensor, Tensor]: directional_input, conditioning_input
329
+ """
330
+ assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
331
+ other_axes = [i for i in range(3) if i != axis_of_invariance]
332
+
333
+ B, latent_dim, _ = Z.shape
334
+ _, num_rays, _ = D.shape
335
+
336
+ if equivariance == "None":
337
+ # get inner product between latent code and direction coordinates
338
+ innerprod = torch.sum(
339
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
340
+ ) # [B, num_rays, latent_dim]
341
+ z_input = (
342
+ Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
343
+ ) # [B, num_rays, latent_dim * 3]
344
+ return innerprod, z_input
345
+
346
+ if equivariance == "SO2":
347
+ z_other = torch.stack(
348
+ (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
349
+ ) # [B, latent_dim, 2]
350
+ d_other = torch.stack(
351
+ (D[..., other_axes[0]], D[..., other_axes[1]]), -1
352
+ ).unsqueeze(2) # [B, num_rays, 1, 2]
353
+ d_other = d_other.expand(
354
+ B, num_rays, latent_dim, 2
355
+ ) # [B, num_rays, latent_dim, 2]
356
+
357
+ z_other_emb = self.vn_proj_in(z_other) # [B, latent_dim, 1, 2]
358
+ z_other_invar = self.vn_invar(z_other_emb) # [B, latent_dim, 2]
359
+
360
+ # Get invariant component of Z along the axis of invariance
361
+ z_invar = Z[..., axis_of_invariance].unsqueeze(-1) # [B, latent_dim, 1]
362
+
363
+ # Inner product between the projections of Z and D onto the plane orthogonal to the axis of invariance.
364
+ # This encodes the rotational information: it is equivariant to a rotation of either Z
365
+ # or D alone and invariant to a joint rotation of both Z and D.
366
+ innerprod = (z_other.unsqueeze(1) * d_other).sum(
367
+ dim=-1
368
+ ) # [B, num_rays, latent_dim]
369
+
370
+ # Compute norm along the axes orthogonal to the axis of invariance
371
+ d_other_norm = torch.sqrt(
372
+ D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
373
+ ).unsqueeze(-1) # [B num_rays, 1]
374
+
375
+ # Get invariant component of D along the axis of invariance
376
+ d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
377
+
378
+ directional_input = torch.cat(
379
+ (innerprod, d_invar, d_other_norm), -1
380
+ ) # [B, num_rays, latent_dim + 2]
381
+ conditioning_input = (
382
+ torch.cat((z_other_invar, z_invar), dim=-1)
383
+ .flatten(1)
384
+ .unsqueeze(1)
385
+ .expand(B, num_rays, latent_dim * 3)
386
+ ) # [B, num_rays, latent_dim * 3]
387
+
388
+ return directional_input, conditioning_input
389
+
390
+ if equivariance == "SO3":
391
+ z = self.vn_proj_in(Z) # [B, latent_dim, 1, 3]
392
+ z_invar = self.vn_invar(z) # [B, latent_dim, 3]
393
+ conditioning_input = (
394
+ z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
395
+ ) # [B, num_rays, latent_dim * 3]
396
+ # D [B, num_rays, 3] -> [B, num_rays, 1, 3]
397
+ # Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3]
398
+ innerprod = torch.sum(
399
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
400
+ ) # [B, num_rays, latent_dim]
401
+ return innerprod, conditioning_input
402
+
403
+ def gram_matrix_invariance(
404
+ self,
405
+ Z: Float[Tensor, "B latent_dim 3"],
406
+ D: Float[Tensor, "B num_rays 3"],
407
+ equivariance: Literal["None", "SO2", "SO3"] = "SO2",
408
+ axis_of_invariance: int = 1,
409
+ ):
410
+ """Generates an invariant representation from latent code Z and direction coordinates D.
411
+
412
+ Args:
413
+ Z (torch.Tensor): Latent code (B x latent_dim x 3)
414
+ D (torch.Tensor): Direction coordinates (B x num_rays x 3)
415
+ equivariance (str): Type of equivariance to use. Options are 'none', 'SO2', and 'SO3'
416
+ axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
417
+ Default is 1 (y-axis).
418
+ Returns:
419
+ torch.Tensor: Invariant representation
420
+ """
421
+ assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
422
+ other_axes = [i for i in range(3) if i != axis_of_invariance]
423
+
424
+ B, latent_dim, _ = Z.shape
425
+ _, num_rays, _ = D.shape
426
+
427
+ if equivariance == "None":
428
+ # get inner product between latent code and direction coordinates
429
+ innerprod = torch.sum(
430
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
431
+ ) # [B, num_rays, latent_dim]
432
+ z_input = (
433
+ Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
434
+ ) # [B, num_rays, latent_dim * 3]
435
+ return innerprod, z_input
436
+
437
+ if equivariance == "SO2":
438
+ # Select components along axes orthogonal to the axis of invariance
439
+ z_other = torch.stack(
440
+ (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
441
+ ) # [B, latent_dim, 2]
442
+ d_other = torch.stack(
443
+ (D[..., other_axes[0]], D[..., other_axes[1]]), -1
444
+ ).unsqueeze(2) # [B, num_rays, 1, 2]
445
+ d_other = d_other.expand(
446
+ B, num_rays, latent_dim, 2
447
+ ) # size becomes [B, num_rays, latent_dim, 2]
448
+
449
+ # Invariant representation of Z, gram matrix G=Z*Z' is size num_rays x latent_dim x latent_dim
450
+ G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))
451
+
452
+ # Flatten G to be size B x latent_dim^2
453
+ z_other_invar = G.flatten(start_dim=1)
454
+
455
+ # Get invariant component of Z along the axis of invariance
456
+ z_invar = Z[..., axis_of_invariance] # [B, latent_dim]
457
+
458
+ # Innerprod is size num_rays x latent_dim
459
+ innerprod = (z_other.unsqueeze(1) * d_other).sum(
460
+ dim=-1
461
+ ) # [B, num_rays, latent_dim]
462
+
463
+ # Compute norm along the axes orthogonal to the axis of invariance
464
+ d_other_norm = torch.sqrt(
465
+ D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
466
+ ).unsqueeze(-1) # [B, num_rays, 1]
467
+
468
+ # Get invariant component of D along the axis of invariance
469
+ d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
470
+
471
+ if not self.old_implementation:
472
+ directional_input = torch.cat(
473
+ (innerprod, d_invar, d_other_norm), -1
474
+ ) # [B, num_rays, latent_dim + 2]
475
+ conditioning_input = (
476
+ torch.cat((z_other_invar, z_invar), -1)
477
+ .unsqueeze(1)
478
+ .expand(B, num_rays, latent_dim**2 + latent_dim)
479
+ ) # [B, num_rays, latent_dim^2 + latent_dim]
480
+ else:
481
+ # this is matching the previous implementation of RENI, needed if using old checkpoints
482
+ z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
483
+ z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
484
+ return torch.cat(
485
+ (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), 1
486
+ )
487
+
488
+ return directional_input, conditioning_input
489
+
490
+ if equivariance == "SO3":
491
+ G = Z @ torch.transpose(Z, 1, 2) # [B, latent_dim, latent_dim]
492
+ innerprod = torch.sum(
493
+ Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
494
+ ) # [B, num_rays, latent_dim]
495
+ z_invar = (
496
+ G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
497
+ ) # [B, num_rays, latent_dim^2]
498
+ return innerprod, z_invar
499
+
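A tiny numeric check of why the Gram-matrix features above are invariant: for any rotation R in the plane orthogonal to the axis of invariance, (Z R)(Z R)^T = Z R R^T Z^T = Z Z^T. The angle and latent size below are arbitrary.

```python
import torch

theta = torch.tensor(0.3)
R = torch.tensor([[theta.cos(), -theta.sin()],
                  [theta.sin(), theta.cos()]])  # in-plane rotation
Z = torch.randn(1, 6, 2)                        # [B, latent_dim, 2] in-plane components
G = Z @ Z.transpose(1, 2)
G_rot = (Z @ R) @ (Z @ R).transpose(1, 2)
print(torch.allclose(G, G_rot, atol=1e-5))      # True
```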
500
+ def setup_network(self):
501
+ """Sets up the network architecture"""
502
+ base_input_dims = {
503
+ "VN": {
504
+ "None": {
505
+ "direction": self.latent_dim,
506
+ "conditioning": self.latent_dim * 3,
507
+ },
508
+ "SO2": {
509
+ "direction": self.latent_dim + 2,
510
+ "conditioning": self.latent_dim * 3,
511
+ },
512
+ "SO3": {
513
+ "direction": self.latent_dim,
514
+ "conditioning": self.latent_dim * 3,
515
+ },
516
+ },
517
+ "GramMatrix": {
518
+ "None": {
519
+ "direction": self.latent_dim,
520
+ "conditioning": self.latent_dim * 3,
521
+ },
522
+ "SO2": {
523
+ "direction": self.latent_dim + 2,
524
+ "conditioning": self.latent_dim**2 + self.latent_dim,
525
+ },
526
+ "SO3": {
527
+ "direction": self.latent_dim,
528
+ "conditioning": self.latent_dim**2,
529
+ },
530
+ },
531
+ }
532
+
533
+ # Extract the necessary input dimensions
534
+ input_types = ["direction", "conditioning"]
535
+ input_dims = {
536
+ key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
537
+ key
538
+ ]
539
+ for key in input_types
540
+ }
541
+
542
+ # Helper function to create NeRF encoding
543
+ def create_nerf_encoding(in_dim):
544
+ return NeRFEncoding(
545
+ in_dim=in_dim,
546
+ num_frequencies=2,
547
+ min_freq_exp=0.0,
548
+ max_freq_exp=2.0,
549
+ include_input=True,
550
+ )
551
+
552
+ # Dictionary-based encoding setup
553
+ encoding_setup = {
554
+ "None": [],
555
+ "Conditioning": ["conditioning"],
556
+ "Directions": ["direction"],
557
+ "Both": ["direction", "conditioning"],
558
+ }
559
+
560
+ # Setting up the required encodings
561
+ for input_type in encoding_setup.get(self.cfg.encoded_input, []):
562
+ # create self.{input_type}_encoding and update input_dims
563
+ setattr(
564
+ self,
565
+ f"{input_type}_encoding",
566
+ create_nerf_encoding(input_dims[input_type]),
567
+ )
568
+ input_dims[input_type] = getattr(
569
+ self, f"{input_type}_encoding"
570
+ ).get_out_dim()
571
+
572
+ output_activation = get_activation_module(self.cfg.output_activation)
573
+
574
+ network = None
575
+ if self.conditioning == "Concat":
576
+ network = Siren(
577
+ in_dim=input_dims["direction"] + input_dims["conditioning"],
578
+ hidden_layers=self.hidden_layers,
579
+ hidden_features=self.hidden_features,
580
+ out_dim=self.out_features,
581
+ outermost_linear=self.last_layer_linear,
582
+ first_omega_0=self.first_omega_0,
583
+ hidden_omega_0=self.hidden_omega_0,
584
+ out_activation=output_activation,
585
+ )
586
+ elif self.conditioning == "FiLM":
587
+ network = FiLMSiren(
588
+ in_dim=input_dims["direction"],
589
+ hidden_layers=self.hidden_layers,
590
+ hidden_features=self.hidden_features,
591
+ mapping_network_in_dim=input_dims["conditioning"],
592
+ mapping_network_layers=self.mapping_layers,
593
+ mapping_network_features=self.mapping_features,
594
+ out_dim=self.out_features,
595
+ outermost_linear=True,
596
+ out_activation=output_activation,
597
+ )
598
+ elif self.conditioning == "Attention":
599
+ # transformer where K, V is from conditioning input and Q is from pos encoded directional input
600
+ network = Decoder(
601
+ in_dim=input_dims["direction"],
602
+ conditioning_input_dim=input_dims["conditioning"],
603
+ hidden_features=self.cfg.hidden_features,
604
+ num_heads=self.cfg.num_attention_heads,
605
+ num_layers=self.cfg.num_attention_layers,
606
+ out_activation=output_activation,
607
+ )
608
+ assert network is not None, "unknown conditioning type"
609
+ return network
610
+
611
+ def apply_positional_encoding(self, directional_input, conditioning_input):
612
+ # conditioning on just invariant directional input
613
+ if self.cfg.encoded_input == "Conditioning":
614
+ conditioning_input = self.conditioning_encoding(
615
+ conditioning_input
616
+ ) # [num_rays, embedding_dim]
617
+ elif self.cfg.encoded_input == "Directions":
618
+ directional_input = self.direction_encoding(
619
+ directional_input
620
+ ) # [num_rays, embedding_dim]
621
+ elif self.cfg.encoded_input == "Both":
622
+ directional_input = self.direction_encoding(directional_input)
623
+ conditioning_input = self.conditioning_encoding(conditioning_input)
624
+
625
+ return directional_input, conditioning_input
626
+
627
+ def get_outputs(
628
+ self,
629
+ rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
630
+ latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
631
+ rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
632
+ scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
633
+ ) -> Dict[str, Tensor]:
634
+ """Returns the outputs of the field.
635
+
636
+ Args:
637
+ rays_d: [batch_size, num_rays, 3]
638
+ latent_codes: [batch_size, latent_dim, 3]
639
+ rotation: [batch_size, 3, 3]
640
+ scale: [batch_size]
641
+ """
642
+ if rotation is not None:
643
+ if len(rotation.shape) == 3: # [batch_size, 3, 3]
644
+ # Expand latent_codes to match [batch_size, latent_dim, 3]
645
+ latent_codes = torch.einsum(
646
+ "bik,blk->bli",
647
+ rotation,
648
+ latent_codes,
649
+ )
650
+ else:
651
+ raise NotImplementedError(
652
+ "Unsupported rotation shape. Expected [batch_size, 3, 3]."
653
+ )
654
+
655
+ B, num_rays, _ = rays_d.shape
656
+ _, latent_dim, _ = latent_codes.shape
657
+
658
+ if not self.old_implementation:
659
+ directional_input, conditioning_input = self.invariant_function(
660
+ latent_codes,
661
+ rays_d,
662
+ equivariance=self.equivariance,
663
+ axis_of_invariance=self.axis_of_invariance,
664
+ ) # [B, num_rays, 3]
665
+
666
+ if self.cfg.positional_encoding == "NeRF":
667
+ directional_input, conditioning_input = self.apply_positional_encoding(
668
+ directional_input, conditioning_input
669
+ )
670
+
671
+ if self.conditioning == "Concat":
672
+ model_outputs = self.network(
673
+ torch.cat((directional_input, conditioning_input), dim=-1).reshape(
674
+ B * num_rays, -1
675
+ )
676
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
677
+ elif self.conditioning == "FiLM":
678
+ model_outputs = self.network(
679
+ directional_input.reshape(B * num_rays, -1),
680
+ conditioning_input.reshape(B * num_rays, -1),
681
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
682
+ elif self.conditioning == "Attention":
683
+ model_outputs = self.network(
684
+ directional_input.reshape(B * num_rays, -1),
685
+ conditioning_input.reshape(B * num_rays, -1),
686
+ ).view(B, num_rays, 3) # returns -> [B num_rays, 3]
687
+ else:
688
+ # in the old implementation directions were sampled with y-up not z-up so need to swap y and z in directions
689
+ directions = torch.stack(
690
+ (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
691
+ )
692
+ model_input = self.invariant_function(
693
+ latent_codes,
694
+ directions,
695
+ equivariance=self.equivariance,
696
+ axis_of_invariance=self.axis_of_invariance,
697
+ ) # [B, num_rays, 3]
698
+
699
+ model_outputs = self.network(model_input.view(B * num_rays, -1)).view(
700
+ B, num_rays, 3
701
+ )
702
+
703
+ outputs = {}
704
+
705
+ if scale is not None:
706
+ scale = trunc_exp(scale) # [num_rays] exp to ensure positive
707
+ model_outputs = model_outputs * scale.view(-1, 1, 1) # [num_rays, 3]
708
+
709
+ outputs["rgb"] = model_outputs
710
+
711
+ return outputs
712
+
713
+ def forward(
714
+ self,
715
+ rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
716
+ latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
717
+ rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
718
+ scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
719
+ ) -> Dict[str, Tensor]:
720
+ """Evaluates spherical field for a given ray bundle and rotation.
721
+
722
+ Args:
723
+ rays_d: [B, num_rays, 3]
724
+ latent_codes: [B, latent_dim, 3]
725
+ rotation: [batch_size, 3, 3]
726
+ scale: [batch_size]
727
+
728
+ Returns:
729
+ Dict[str, Tensor]: A dictionary containing the outputs of the field.
730
+ """
731
+ return self.get_outputs(
732
+ rays_d=rays_d,
733
+ latent_codes=latent_codes,
734
+ rotation=rotation,
735
+ scale=scale,
736
+ )
spar3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,184 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import alpha_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from spar3d.models.network import get_activation
12
+ from spar3d.models.utils import BaseModule
13
+
14
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
15
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
16
+
17
+
18
+ @dataclass
19
+ class HeadSpec:
20
+ name: str
21
+ out_channels: int
22
+ n_hidden_layers: int
23
+ output_activation: Optional[str] = None
24
+ output_bias: float = 0.0
25
+ add_to_decoder_features: bool = False
26
+ shape: Optional[list[int]] = None
27
+ distribution_eval: str = "sample"
28
+
29
+
30
+ class ClipBasedHeadEstimator(BaseModule):
31
+ @dataclass
32
+ class Config(BaseModule.Config):
33
+ model: str = "ViT-L/14@336px"
34
+
35
+ distribution: str = "beta"
36
+
37
+ # ["mean", "mode", "sample", "sample_mean"]
38
+ distribution_eval: str = "mode"
39
+
40
+ activation: str = "relu"
41
+ hidden_features: int = 512
42
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
43
+
44
+ cfg: Config
45
+
46
+ def configure(self):
47
+ self.model, _ = alpha_clip.load(
48
+ self.cfg.model,
49
+ ) # change to your own ckpt path
50
+ self.model.eval()
51
+
52
+ if not hasattr(self.model.visual, "input_resolution"):
53
+ self.img_size = 224
54
+ else:
55
+ self.img_size = self.model.visual.input_resolution
56
+ # Check if img_size is subscriptable and pick the first element
57
+ if hasattr(self.img_size, "__getitem__"):
58
+ self.img_size = self.img_size[0]
59
+
60
+ # Do not add the weights in self.model to the optimizer
61
+ for param in self.model.parameters():
62
+ param.requires_grad = False
63
+
64
+ assert len(self.cfg.heads) > 0
65
+ heads = {}
66
+ for head in self.cfg.heads:
67
+ head_layers = []
68
+ in_feature = self.model.visual.output_dim
69
+
70
+ for i in range(head.n_hidden_layers):
71
+ head_layers += [
72
+ nn.Linear(
73
+ in_feature if i == 0 else self.cfg.hidden_features,
74
+ self.cfg.hidden_features,
75
+ ),
76
+ self.make_activation(self.cfg.activation),
77
+ ]
78
+
79
+ head_layers = [nn.Sequential(*head_layers)]
80
+ head_layers += [
81
+ nn.Sequential(
82
+ nn.Linear(
83
+ self.cfg.hidden_features,
84
+ self.cfg.hidden_features,
85
+ ),
86
+ self.make_activation(self.cfg.activation),
87
+ nn.Linear(self.cfg.hidden_features, 1),
88
+ )
89
+ for _ in range(2)
90
+ ]
91
+ heads[head.name] = nn.ModuleList(head_layers)
92
+ self.heads = nn.ModuleDict(heads)
93
+
94
+ def make_activation(self, activation):
95
+ if activation == "relu":
96
+ return nn.ReLU(inplace=True)
97
+ elif activation == "silu":
98
+ return nn.SiLU(inplace=True)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ def forward(
103
+ self,
104
+ cond_image: Float[Tensor, "B 1 H W 4"],
105
+ sample: bool = True,
106
+ ) -> dict[str, Any]:
107
+ # Run the model
108
+ # Resize cond_image to 224
109
+ cond_image = cond_image.flatten(0, 1)
110
+ cond_image = nn.functional.interpolate(
111
+ cond_image.permute(0, 3, 1, 2),
112
+ size=(self.img_size, self.img_size),
113
+ mode="bilinear",
114
+ align_corners=False,
115
+ )
116
+ mask = cond_image[:, 3:4]
117
+ cond_image = cond_image[:, :3] * mask
118
+ cond_image = Normalize(
119
+ mean=OPENAI_DATASET_MEAN,
120
+ std=OPENAI_DATASET_STD,
121
+ )(cond_image)
122
+ mask = Normalize(0.5, 0.26)(mask).half()
123
+ image_features = self.model.visual(cond_image.half(), mask).float()
124
+
125
+ # Run the heads
126
+ outputs = {}
127
+
128
+ for head_dict in self.cfg.heads:
129
+ head_name = head_dict.name
130
+ shared_head, d1_h, d2_h = self.heads[head_name]
131
+ shared_features = shared_head(image_features)
132
+ d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
133
+ if self.cfg.distribution == "normal":
134
+ mean = d1
135
+ var = d2
136
+ if mean.shape[-1] == 1:
137
+ outputs[head_name] = torch.distributions.Normal(
138
+ mean + head_dict.output_bias,
139
+ torch.nn.functional.softplus(var),
140
+ )
141
+ else:
142
+ outputs[head_name] = torch.distributions.MultivariateNormal(
143
+ mean + head_dict.output_bias,
144
+ torch.nn.functional.softplus(var).diag_embed(),
145
+ )
146
+ elif self.cfg.distribution == "beta":
147
+ outputs[head_name] = torch.distributions.Beta(
148
+ torch.nn.functional.softplus(d1 + head_dict.output_bias),
149
+ torch.nn.functional.softplus(d2 + head_dict.output_bias),
150
+ )
151
+ else:
152
+ raise NotImplementedError
153
+
154
+ if sample:
155
+ for head_dict in self.cfg.heads:
156
+ head_name = head_dict.name
157
+ dist = outputs[head_name]
158
+
159
+ if head_dict.distribution_eval == "mean":
160
+ out = dist.mean
161
+ elif head_dict.distribution_eval == "mode":
162
+ out = dist.mode
163
+ elif head_dict.distribution_eval == "sample_mean":
164
+ out = dist.sample([10]).mean(-1)
165
+ else:
166
+ # use rsample if gradient is needed
167
+ out = dist.rsample() if self.training else dist.sample()
168
+
169
+ outputs[head_name] = get_activation(head_dict.output_activation)(out)
170
+ outputs[f"{head_name}_dist"] = dist
171
+
172
+ for head in self.cfg.heads:
173
+ if head.shape:
174
+ if not sample:
175
+ raise ValueError(
176
+ "Cannot reshape non-sampled probabilisitic outputs"
177
+ )
178
+ outputs[head.name] = outputs[head.name].reshape(*head.shape)
179
+
180
+ if head.add_to_decoder_features:
181
+ outputs[f"decoder_{head.name}"] = outputs[head.name]
182
+ del outputs[head.name]
183
+
184
+ return outputs
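A hypothetical head configuration for the estimator above (names and values are illustrative only), plus how a pair of raw head outputs becomes a Beta distribution when `cfg.distribution == "beta"`:

```python
import torch

heads = [
    HeadSpec(name="scale", out_channels=1, n_hidden_layers=1, distribution_eval="mode"),
    HeadSpec(name="z_offset", out_channels=1, n_hidden_layers=1, output_bias=0.5),
]

# Each head's two linear branches output raw values (d1, d2); softplus turns them into
# the two concentration parameters of a Beta distribution:
d1, d2 = torch.tensor(0.2), torch.tensor(-0.1)
dist = torch.distributions.Beta(
    torch.nn.functional.softplus(d1), torch.nn.functional.softplus(d2)
)
print(dist.mean)
```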
spar3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
+ class IsosurfaceHelper(nn.Module):
13
+ points_range: Tuple[float, float] = (0, 1)
14
+
15
+ @property
16
+ def grid_vertices(self) -> Float[Tensor, "N 3"]:
17
+ raise NotImplementedError
18
+
19
+ @property
20
+ def requires_instance_per_batch(self) -> bool:
21
+ return False
22
+
23
+
24
+ class MarchingTetrahedraHelper(IsosurfaceHelper):
25
+ def __init__(self, resolution: int, tets_path: str):
26
+ super().__init__()
27
+ self.resolution = resolution
28
+ self.tets_path = tets_path
29
+
30
+ self.triangle_table: Float[Tensor, "..."]
31
+ self.register_buffer(
32
+ "triangle_table",
33
+ torch.as_tensor(
34
+ [
35
+ [-1, -1, -1, -1, -1, -1],
36
+ [1, 0, 2, -1, -1, -1],
37
+ [4, 0, 3, -1, -1, -1],
38
+ [1, 4, 2, 1, 3, 4],
39
+ [3, 1, 5, -1, -1, -1],
40
+ [2, 3, 0, 2, 5, 3],
41
+ [1, 4, 0, 1, 5, 4],
42
+ [4, 2, 5, -1, -1, -1],
43
+ [4, 5, 2, -1, -1, -1],
44
+ [4, 1, 0, 4, 5, 1],
45
+ [3, 2, 0, 3, 5, 2],
46
+ [1, 3, 5, -1, -1, -1],
47
+ [4, 1, 2, 4, 3, 1],
48
+ [3, 0, 4, -1, -1, -1],
49
+ [2, 0, 1, -1, -1, -1],
50
+ [-1, -1, -1, -1, -1, -1],
51
+ ],
52
+ dtype=torch.long,
53
+ ),
54
+ persistent=False,
55
+ )
56
+ self.num_triangles_table: Integer[Tensor, "..."]
57
+ self.register_buffer(
58
+ "num_triangles_table",
59
+ torch.as_tensor(
60
+ [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
61
+ ),
62
+ persistent=False,
63
+ )
64
+ self.base_tet_edges: Integer[Tensor, "..."]
65
+ self.register_buffer(
66
+ "base_tet_edges",
67
+ torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
68
+ persistent=False,
69
+ )
70
+
71
+ tets = np.load(self.tets_path)
72
+ self._grid_vertices: Float[Tensor, "..."]
73
+ self.register_buffer(
74
+ "_grid_vertices",
75
+ torch.from_numpy(tets["vertices"]).float(),
76
+ persistent=False,
77
+ )
78
+ self.indices: Integer[Tensor, "..."]
79
+ self.register_buffer(
80
+ "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
81
+ )
82
+
83
+ self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
84
+
85
+ center_indices, boundary_indices = self.get_center_boundary_index(
86
+ self._grid_vertices
87
+ )
88
+ self.center_indices: Integer[Tensor, "..."]
89
+ self.register_buffer("center_indices", center_indices, persistent=False)
90
+ self.boundary_indices: Integer[Tensor, "..."]
91
+ self.register_buffer("boundary_indices", boundary_indices, persistent=False)
92
+
93
+ def get_center_boundary_index(self, verts):
94
+ magn = torch.sum(verts**2, dim=-1)
95
+
96
+ center_idx = torch.argmin(magn)
97
+ boundary_neg = verts == verts.max()
98
+ boundary_pos = verts == verts.min()
99
+
100
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
101
+ boundary = torch.sum(boundary.float(), dim=-1)
102
+
103
+ boundary_idx = torch.nonzero(boundary)
104
+ return center_idx, boundary_idx.squeeze(dim=-1)
105
+
106
+ def normalize_grid_deformation(
107
+ self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
108
+ ) -> Float[Tensor, "Nv 3"]:
109
+ return (
110
+ (self.points_range[1] - self.points_range[0])
111
+ / self.resolution # half tet size is approximately 1 / self.resolution
112
+ * torch.tanh(grid_vertex_offsets)
113
+ ) # FIXME: hard-coded activation
114
+
115
+ @property
116
+ def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
117
+ return self._grid_vertices
118
+
119
+ @property
120
+ def all_edges(self) -> Integer[Tensor, "Ne 2"]:
121
+ if self._all_edges is None:
122
+ # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
123
+ edges = torch.tensor(
124
+ [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
125
+ dtype=torch.long,
126
+ device=self.indices.device,
127
+ )
128
+ _all_edges = self.indices[:, edges].reshape(-1, 2)
129
+ _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
130
+ _all_edges = torch.unique(_all_edges_sorted, dim=0)
131
+ self._all_edges = _all_edges
132
+ return self._all_edges
133
+
134
+ def sort_edges(self, edges_ex2):
135
+ with torch.no_grad():
136
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
137
+ order = order.unsqueeze(dim=1)
138
+
139
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
140
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
141
+
142
+ return torch.stack([a, b], -1)
143
+
144
+ def _forward(self, pos_nx3, sdf_n, tet_fx4):
145
+ with torch.no_grad():
146
+ occ_n = sdf_n > 0
147
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
148
+ occ_sum = torch.sum(occ_fx4, -1)
149
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
150
+ occ_sum = occ_sum[valid_tets]
151
+
152
+ # find all vertices
153
+ all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
154
+ all_edges = self.sort_edges(all_edges)
155
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
156
+
157
+ unique_edges = unique_edges.long()
158
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
159
+ mapping = (
160
+ torch.ones(
161
+ (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
162
+ )
163
+ * -1
164
+ )
165
+ mapping[mask_edges] = torch.arange(
166
+ mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
167
+ )
168
+ idx_map = mapping[idx_map] # map edges to verts
169
+
170
+ interp_v = unique_edges[mask_edges]
171
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
172
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
173
+ edges_to_interp_sdf[:, -1] *= -1
174
+
175
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
176
+
177
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
178
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
179
+
180
+ idx_map = idx_map.reshape(-1, 6)
181
+
182
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
183
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
184
+ num_triangles = self.num_triangles_table[tetindex]
185
+
186
+ # Generate triangle indices
187
+ faces = torch.cat(
188
+ (
189
+ torch.gather(
190
+ input=idx_map[num_triangles == 1],
191
+ dim=1,
192
+ index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
193
+ ).reshape(-1, 3),
194
+ torch.gather(
195
+ input=idx_map[num_triangles == 2],
196
+ dim=1,
197
+ index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
198
+ ).reshape(-1, 3),
199
+ ),
200
+ dim=0,
201
+ )
202
+
203
+ return verts, faces
204
+
205
+ def forward(
206
+ self,
207
+ level: Float[Tensor, "N3 1"],
208
+ deformation: Optional[Float[Tensor, "N3 3"]] = None,
209
+ ) -> Mesh:
210
+ if deformation is not None:
211
+ grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
212
+ deformation
213
+ )
214
+ else:
215
+ grid_vertices = self.grid_vertices
216
+
217
+ v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
218
+
219
+ mesh = Mesh(
220
+ v_pos=v_pos,
221
+ t_pos_idx=t_pos_idx,
222
+ # extras
223
+ grid_vertices=grid_vertices,
224
+ tet_edges=self.all_edges,
225
+ grid_level=level,
226
+ grid_deformation=deformation,
227
+ )
228
+
229
+ return mesh
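A minimal sketch of extracting a mesh with the helper above. The tets path points at `load/tets/160_tets.npz` from this upload; the [0, 1] grid range and the sphere SDF are assumptions, and building the resulting `Mesh` requires `uv_unwrapper` to be installed.

```python
import torch

helper = MarchingTetrahedraHelper(resolution=160, tets_path="load/tets/160_tets.npz")
# Signed distance of a sphere centred in the (assumed) [0, 1] tet grid, positive inside
sdf = 0.3 - (helper.grid_vertices - 0.5).norm(dim=-1, keepdim=True)
mesh = helper(sdf)   # Mesh with v_pos / t_pos_idx extracted at the zero level set
print(mesh.v_pos.shape, mesh.t_pos_idx.shape)
```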
spar3d/models/mesh.py ADDED
@@ -0,0 +1,317 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Dict, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import trimesh
10
+ from jaxtyping import Float, Integer
11
+ from torch import Tensor
12
+
13
+ from spar3d.models.utils import dot
14
+
15
+ try:
16
+ from uv_unwrapper import Unwrapper
17
+ except ImportError:
18
+ import logging
19
+
20
+ logging.warning(
21
+ "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
22
+ )
23
+ # Exit early to avoid further errors
24
+ raise ImportError("uv_unwrapper not found")
25
+
26
+ try:
27
+ import gpytoolbox
28
+
29
+ TRIANGLE_REMESH_AVAILABLE = True
30
+ except ImportError:
31
+ TRIANGLE_REMESH_AVAILABLE = False
32
+ import logging
33
+
34
+ logging.warning(
35
+ "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
36
+ "Install via `pip install gpytoolbox`"
37
+ )
38
+
39
+ try:
40
+ import pynim
41
+
42
+ QUAD_REMESH_AVAILABLE = True
43
+ except ImportError:
44
+ QUAD_REMESH_AVAILABLE = False
45
+ import logging
46
+
47
+ logging.warning(
48
+ "Could not import pynim. Quad remeshing functionality will be disabled. "
49
+ "Install via `pip install git+https://github.com/vork/PyNanoInstantMeshes.git@v0.0.3`"
50
+ )
51
+
52
+
53
+ class Mesh:
54
+ def __init__(
55
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
56
+ ) -> None:
57
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
58
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
59
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
60
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
61
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
62
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
63
+ self.extras: Dict[str, Any] = {}
64
+ for k, v in kwargs.items():
65
+ self.add_extra(k, v)
66
+
67
+ self.unwrapper = Unwrapper()
68
+
69
+ def add_extra(self, k, v) -> None:
70
+ self.extras[k] = v
71
+
72
+ @property
73
+ def requires_grad(self):
74
+ return self.v_pos.requires_grad
75
+
76
+ @property
77
+ def v_nrm(self):
78
+ if self._v_nrm is None:
79
+ self._v_nrm = self._compute_vertex_normal()
80
+ return self._v_nrm
81
+
82
+ @property
83
+ def v_tng(self):
84
+ if self._v_tng is None:
85
+ self._v_tng = self._compute_vertex_tangent()
86
+ return self._v_tng
87
+
88
+ @property
89
+ def v_tex(self):
90
+ if self._v_tex is None:
91
+ self.unwrap_uv()
92
+ return self._v_tex
93
+
94
+ @property
95
+ def edges(self):
96
+ if self._edges is None:
97
+ self._edges = self._compute_edges()
98
+ return self._edges
99
+
100
+ def _compute_vertex_normal(self):
101
+ i0 = self.t_pos_idx[:, 0]
102
+ i1 = self.t_pos_idx[:, 1]
103
+ i2 = self.t_pos_idx[:, 2]
104
+
105
+ v0 = self.v_pos[i0, :]
106
+ v1 = self.v_pos[i1, :]
107
+ v2 = self.v_pos[i2, :]
108
+
109
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
110
+
111
+ # Splat face normals to vertices
112
+ v_nrm = torch.zeros_like(self.v_pos)
113
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
114
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
115
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
116
+
117
+ # Normalize, replace zero (degenerated) normals with some default value
118
+ v_nrm = torch.where(
119
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
120
+ )
121
+ v_nrm = F.normalize(v_nrm, dim=1)
122
+
123
+ if torch.is_anomaly_enabled():
124
+ assert torch.all(torch.isfinite(v_nrm))
125
+
126
+ return v_nrm
127
+
128
+ def _compute_vertex_tangent(self):
129
+ vn_idx = [None] * 3
130
+ pos = [None] * 3
131
+ tex = [None] * 3
132
+ for i in range(0, 3):
133
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
134
+ tex[i] = self.v_tex[self.t_pos_idx[:, i]]
135
+ # t_nrm_idx is always the same as t_pos_idx
136
+ vn_idx[i] = self.t_pos_idx[:, i]
137
+
138
+ tangents = torch.zeros_like(self.v_nrm)
139
+ tansum = torch.zeros_like(self.v_nrm)
140
+
141
+ # Compute tangent space for each triangle
142
+ duv1 = tex[1] - tex[0]
143
+ duv2 = tex[2] - tex[0]
144
+ dpos1 = pos[1] - pos[0]
145
+ dpos2 = pos[2] - pos[0]
146
+
147
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
148
+
149
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
150
+
151
+ # Avoid division by zero for degenerated texture coordinates
152
+ denom_safe = denom.clip(1e-6)
153
+ tang = tng_nom / denom_safe
154
+
155
+ # Update all 3 vertices
156
+ for i in range(0, 3):
157
+ idx = vn_idx[i][:, None].repeat(1, 3)
158
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
159
+ tansum.scatter_add_(
160
+ 0, idx, torch.ones_like(tang)
161
+ ) # tansum[n_i] = tansum[n_i] + 1
162
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
163
+ # triangles influence the tangent space more
164
+ tangents = tangents / tansum
165
+
166
+ # Normalize and make sure tangent is perpendicular to normal
167
+ tangents = F.normalize(tangents, dim=1)
168
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
169
+
170
+ if torch.is_anomaly_enabled():
171
+ assert torch.all(torch.isfinite(tangents))
172
+
173
+ return tangents
174
+
175
+ def quad_remesh(
176
+ self,
177
+ quad_vertex_count: int = -1,
178
+ quad_rosy: int = 4,
179
+ quad_crease_angle: float = -1.0,
180
+ quad_smooth_iter: int = 2,
181
+ quad_align_to_boundaries: bool = False,
182
+ ) -> Mesh:
183
+ if not QUAD_REMESH_AVAILABLE:
184
+ raise ImportError("Quad remeshing requires pynim to be installed")
185
+ if quad_vertex_count < 0:
186
+ quad_vertex_count = self.v_pos.shape[0]
187
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
188
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)
189
+
190
+ new_vert, new_faces = pynim.remesh(
191
+ v_pos,
192
+ t_pos_idx,
193
+ quad_vertex_count // 4,
194
+ rosy=quad_rosy,
195
+ posy=4,
196
+ creaseAngle=quad_crease_angle,
197
+ align_to_boundaries=quad_align_to_boundaries,
198
+ smooth_iter=quad_smooth_iter,
199
+ deterministic=False,
200
+ )
201
+
202
+ # Briefly load in trimesh
203
+ mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))
204
+
205
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
206
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()
207
+
208
+ # Create new mesh
209
+ return Mesh(v_pos, t_pos_idx)
210
+
211
+ def triangle_remesh(
212
+ self,
213
+ triangle_average_edge_length_multiplier: Optional[float] = None,
214
+ triangle_remesh_steps: int = 10,
215
+ triangle_vertex_count=-1,
216
+ ):
217
+ if not TRIANGLE_REMESH_AVAILABLE:
218
+ raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
219
+ if triangle_vertex_count > 0:
220
+ reduction = triangle_vertex_count / self.v_pos.shape[0]
221
+ print("Triangle reduction:", reduction)
222
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
223
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
224
+ if reduction > 1.0:
225
+ subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
226
+ print("Subdivide iters:", subdivide_iters)
227
+ v_pos, t_pos_idx = gpytoolbox.subdivide(
228
+ v_pos,
229
+ t_pos_idx,
230
+ iters=subdivide_iters,
231
+ )
232
+ reduction = triangle_vertex_count / v_pos.shape[0]
233
+
234
+ # Simplify
235
+ points_out, faces_out, _, _ = gpytoolbox.decimate(
236
+ v_pos,
237
+ t_pos_idx,
238
+ face_ratio=reduction,
239
+ )
240
+
241
+ # Convert back to torch
242
+ self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
243
+ self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
244
+ self._edges = None
245
+ triangle_average_edge_length_multiplier = None
246
+
247
+ edges = self.edges
248
+ if triangle_average_edge_length_multiplier is None:
249
+ h = None
250
+ else:
251
+ h = float(
252
+ torch.linalg.norm(
253
+ self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
254
+ )
255
+ .mean()
256
+ .item()
257
+ * triangle_average_edge_length_multiplier
258
+ )
259
+
260
+ # Convert to numpy
261
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
262
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
263
+
264
+ # Remesh
265
+ v_remesh, f_remesh = gpytoolbox.remesh_botsch(
266
+ v_pos,
267
+ t_pos_idx,
268
+ triangle_remesh_steps,
269
+ h,
270
+ )
271
+
272
+ # Convert back to torch
273
+ v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
274
+ t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()
275
+
276
+ # Create new mesh
277
+ return Mesh(v_pos, t_pos_idx)
278
+
279
+ @torch.no_grad()
280
+ def unwrap_uv(
281
+ self,
282
+ island_padding: float = 0.02,
283
+ ) -> Mesh:
284
+ uv, indices = self.unwrapper(
285
+ self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
286
+ )
287
+
288
+ # Store per-vertex UVs.
289
+ # This means we need to duplicate some vertices at the seams
290
+ individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
291
+ individual_faces = torch.arange(
292
+ individual_vertices.shape[0],
293
+ device=individual_vertices.device,
294
+ dtype=self.t_pos_idx.dtype,
295
+ ).reshape(-1, 3)
296
+ uv_flat = uv[indices].reshape((-1, 2))
297
+ # uv_flat[:, 1] = 1 - uv_flat[:, 1]
298
+
299
+ self.v_pos = individual_vertices
300
+ self.t_pos_idx = individual_faces
301
+ self._v_tex = uv_flat
302
+ self._v_nrm = self._compute_vertex_normal()
303
+ self._v_tng = self._compute_vertex_tangent()
304
+
305
+ def _compute_edges(self):
306
+ # Compute edges
307
+ edges = torch.cat(
308
+ [
309
+ self.t_pos_idx[:, [0, 1]],
310
+ self.t_pos_idx[:, [1, 2]],
311
+ self.t_pos_idx[:, [2, 0]],
312
+ ],
313
+ dim=0,
314
+ )
315
+ edges = edges.sort()[0]
316
+ edges = torch.unique(edges, dim=0)
317
+ return edges
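A quick illustration of the vertex budgeting in `triangle_remesh` above: when the requested vertex count exceeds the current one, the mesh is subdivided first and only afterwards decimated down to the target ratio. The sketch below mirrors that arithmetic as a hypothetical helper in plain Python (no gpytoolbox required); the 4x-per-iteration growth estimate is only a rough stand-in for re-measuring `v_pos.shape[0]` after `gpytoolbox.subdivide`.

```python
import math

def plan_triangle_remesh(current_vertices: int, target_vertices: int):
    """Hypothetical helper mirroring Mesh.triangle_remesh's budgeting."""
    reduction = target_vertices / current_vertices
    subdivide_iters = 0
    if reduction > 1.0:
        # same step count as the method above: ceil(log2(reduction)),
        # which errs on the side of an extra subdivision pass
        subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
        # rough estimate: each subdivision pass roughly quadruples the
        # vertex count (the real code re-reads the actual count instead)
        estimated_vertices = current_vertices * 4**subdivide_iters
        reduction = target_vertices / estimated_vertices
    return subdivide_iters, reduction

print(plan_triangle_remesh(5_000, 40_000))  # (3, 0.125)
```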
spar3d/models/network.py ADDED
@@ -0,0 +1,226 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.amp import custom_bwd, custom_fwd
11
+ from torch.autograd import Function
12
+
13
+ from spar3d.models.utils import BaseModule, normalize
14
+ from spar3d.utils import get_device
15
+
16
+
17
+ def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
18
+ def wrapper(fn):
19
+ if condition:
20
+ if len(kwargs) == 0:
21
+ return decorator_with_args(fn)
22
+ return decorator_with_args(*args, **kwargs)(fn)
23
+ else:
24
+ return fn
25
+
26
+ return wrapper
27
+
28
+
29
+ class PixelShuffleUpsampleNetwork(BaseModule):
30
+ @dataclass
31
+ class Config(BaseModule.Config):
32
+ in_channels: int = 1024
33
+ out_channels: int = 40
34
+ scale_factor: int = 4
35
+
36
+ conv_layers: int = 4
37
+ conv_kernel_size: int = 3
38
+
39
+ cfg: Config
40
+
41
+ def configure(self) -> None:
42
+ layers = []
43
+ output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
44
+
45
+ in_channels = self.cfg.in_channels
46
+ for i in range(self.cfg.conv_layers):
47
+ cur_out_channels = (
48
+ in_channels if i != self.cfg.conv_layers - 1 else output_channels
49
+ )
50
+ layers.append(
51
+ nn.Conv2d(
52
+ in_channels,
53
+ cur_out_channels,
54
+ self.cfg.conv_kernel_size,
55
+ padding=(self.cfg.conv_kernel_size - 1) // 2,
56
+ )
57
+ )
58
+ if i != self.cfg.conv_layers - 1:
59
+ layers.append(nn.ReLU(inplace=True))
60
+
61
+ layers.append(nn.PixelShuffle(self.cfg.scale_factor))
62
+
63
+ self.upsample = nn.Sequential(*layers)
64
+
65
+ def forward(
66
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
67
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
68
+ return rearrange(
69
+ self.upsample(
70
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
71
+ ),
72
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
73
+ Np=3,
74
+ )
75
+
76
+
77
+ class _TruncExp(Function): # pylint: disable=abstract-method
78
+ # Implementation from torch-ngp:
79
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
80
+ @staticmethod
81
+ @conditional_decorator(
82
+ custom_fwd,
83
+ "cuda" in get_device(),
84
+ cast_inputs=torch.float32,
85
+ device_type="cuda",
86
+ )
87
+ def forward(ctx, x): # pylint: disable=arguments-differ
88
+ ctx.save_for_backward(x)
89
+ return torch.exp(x)
90
+
91
+ @staticmethod
92
+ @conditional_decorator(custom_bwd, "cuda" in get_device())
93
+ def backward(ctx, g): # pylint: disable=arguments-differ
94
+ x = ctx.saved_tensors[0]
95
+ return g * torch.exp(torch.clamp(x, max=15))
96
+
97
+
98
+ trunc_exp = _TruncExp.apply
99
+
100
+
101
+ def get_activation(name) -> Callable:
102
+ if name is None:
103
+ return lambda x: x
104
+ name = name.lower()
105
+ if name == "none" or name == "linear" or name == "identity":
106
+ return lambda x: x
107
+ elif name == "lin2srgb":
108
+ return lambda x: torch.where(
109
+ x > 0.0031308,
110
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
111
+ 12.92 * x,
112
+ ).clamp(0.0, 1.0)
113
+ elif name == "exp":
114
+ return lambda x: torch.exp(x)
115
+ elif name == "shifted_exp":
116
+ return lambda x: torch.exp(x - 1.0)
117
+ elif name == "trunc_exp":
118
+ return trunc_exp
119
+ elif name == "shifted_trunc_exp":
120
+ return lambda x: trunc_exp(x - 1.0)
121
+ elif name == "sigmoid":
122
+ return lambda x: torch.sigmoid(x)
123
+ elif name == "tanh":
124
+ return lambda x: torch.tanh(x)
125
+ elif name == "shifted_softplus":
126
+ return lambda x: F.softplus(x - 1.0)
127
+ elif name == "scale_-11_01":
128
+ return lambda x: x * 0.5 + 0.5
129
+ elif name == "negative":
130
+ return lambda x: -x
131
+ elif name == "normalize_channel_last":
132
+ return lambda x: normalize(x)
133
+ elif name == "normalize_channel_first":
134
+ return lambda x: normalize(x, dim=1)
135
+ else:
136
+ try:
137
+ return getattr(F, name)
138
+ except AttributeError:
139
+ raise ValueError(f"Unknown activation function: {name}")
140
+
141
+
142
+ class LambdaModule(torch.nn.Module):
143
+ def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]):
144
+ super().__init__()
145
+ self.lambd = lambd
146
+
147
+ def forward(self, x):
148
+ return self.lambd(x)
149
+
150
+
151
+ def get_activation_module(name) -> torch.nn.Module:
152
+ return LambdaModule(get_activation(name))
153
+
154
+
155
+ @dataclass
156
+ class HeadSpec:
157
+ name: str
158
+ out_channels: int
159
+ n_hidden_layers: int
160
+ output_activation: Optional[str] = None
161
+ out_bias: float = 0.0
162
+
163
+
164
+ class MaterialMLP(BaseModule):
165
+ @dataclass
166
+ class Config(BaseModule.Config):
167
+ in_channels: int = 120
168
+ n_neurons: int = 64
169
+ activation: str = "silu"
170
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
171
+
172
+ cfg: Config
173
+
174
+ def configure(self) -> None:
175
+ assert len(self.cfg.heads) > 0
176
+ heads = {}
177
+ for head in self.cfg.heads:
178
+ head_layers = []
179
+ for i in range(head.n_hidden_layers):
180
+ head_layers += [
181
+ nn.Linear(
182
+ self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
183
+ self.cfg.n_neurons,
184
+ ),
185
+ self.make_activation(self.cfg.activation),
186
+ ]
187
+ head_layers += [
188
+ nn.Linear(
189
+ self.cfg.n_neurons,
190
+ head.out_channels,
191
+ ),
192
+ ]
193
+ heads[head.name] = nn.Sequential(*head_layers)
194
+ self.heads = nn.ModuleDict(heads)
195
+
196
+ def make_activation(self, activation):
197
+ if activation == "relu":
198
+ return nn.ReLU(inplace=True)
199
+ elif activation == "silu":
200
+ return nn.SiLU(inplace=True)
201
+ else:
202
+ raise NotImplementedError
203
+
204
+ def keys(self):
205
+ return self.heads.keys()
206
+
207
+ def forward(
208
+ self, x, include: Optional[List] = None, exclude: Optional[List] = None
209
+ ):
210
+ if include is not None and exclude is not None:
211
+ raise ValueError("Cannot specify both include and exclude.")
212
+ if include is not None:
213
+ heads = [h for h in self.cfg.heads if h.name in include]
214
+ elif exclude is not None:
215
+ heads = [h for h in self.cfg.heads if h.name not in exclude]
216
+ else:
217
+ heads = self.cfg.heads
218
+
219
+ out = {
220
+ head.name: get_activation(head.output_activation)(
221
+ self.heads[head.name](x) + head.out_bias
222
+ )
223
+ for head in heads
224
+ }
225
+
226
+ return out
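`MaterialMLP` above keeps one small MLP per output head over the same input features and lets callers request only a subset of heads via `include`/`exclude`. A minimal stand-in for that dispatch pattern (the head names, sizes, and activations here are illustrative choices, not the SPAR3D configuration):

```python
import torch
import torch.nn as nn

# Hypothetical two-head material MLP; mirrors the ModuleDict dispatch above.
heads = nn.ModuleDict({
    "albedo": nn.Sequential(nn.Linear(120, 64), nn.SiLU(), nn.Linear(64, 3)),
    "roughness": nn.Sequential(nn.Linear(120, 64), nn.SiLU(), nn.Linear(64, 1)),
})
output_activation = {"albedo": torch.sigmoid, "roughness": torch.sigmoid}

x = torch.randn(4096, 120)   # per-point features queried from the triplane
include = ["albedo"]         # analogous to forward(x, include=["albedo"])
out = {
    name: output_activation[name](head(x))
    for name, head in heads.items()
    if name in include
}
print({k: tuple(v.shape) for k, v in out.items()})  # {'albedo': (4096, 3)}
```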
spar3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "Dinov2Config"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
+
61
+
62
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/dinov2-base",
64
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
+ ]
66
+
67
+
68
+ class Dinov2Embeddings(nn.Module):
69
+ """
70
+ Construct the CLS token, mask token, position and patch embeddings.
71
+ """
72
+
73
+ def __init__(self, config: Dinov2Config) -> None:
74
+ super().__init__()
75
+
76
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
77
+ # register the mask token as a buffer since it is not optimized here,
78
+ # which avoids needing find_unused_parameters=True in DDP
79
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
81
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
82
+ num_patches = self.patch_embeddings.num_patches
83
+ self.position_embeddings = nn.Parameter(
84
+ torch.randn(1, num_patches + 1, config.hidden_size)
85
+ )
86
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
87
+ self.config = config
88
+
89
+ def interpolate_pos_encoding(
90
+ self, embeddings: torch.Tensor, height: int, width: int
91
+ ) -> torch.Tensor:
92
+ """
93
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
94
+ resolution images.
95
+
96
+ Source:
97
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
+ """
99
+
100
+ num_patches = embeddings.shape[1] - 1
101
+ num_positions = self.position_embeddings.shape[1] - 1
102
+ if num_patches == num_positions and height == width:
103
+ return self.position_embeddings
104
+ class_pos_embed = self.position_embeddings[:, 0]
105
+ patch_pos_embed = self.position_embeddings[:, 1:]
106
+ dim = embeddings.shape[-1]
107
+ height = height // self.config.patch_size
108
+ width = width // self.config.patch_size
109
+ # we add a small number to avoid floating point error in the interpolation
110
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
111
+ height, width = height + 0.1, width + 0.1
112
+ patch_pos_embed = patch_pos_embed.reshape(
113
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
114
+ )
115
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
116
+ patch_pos_embed = nn.functional.interpolate(
117
+ patch_pos_embed,
118
+ scale_factor=(
119
+ height / math.sqrt(num_positions),
120
+ width / math.sqrt(num_positions),
121
+ ),
122
+ mode="bicubic",
123
+ align_corners=False,
124
+ )
125
+ if (
126
+ int(height) != patch_pos_embed.shape[-2]
127
+ or int(width) != patch_pos_embed.shape[-1]
128
+ ):
129
+ raise ValueError(
130
+ "Width or height does not match with the interpolated position embeddings"
131
+ )
132
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
134
+
135
+ def forward(
136
+ self,
137
+ pixel_values: torch.Tensor,
138
+ bool_masked_pos: Optional[torch.Tensor] = None,
139
+ ) -> torch.Tensor:
140
+ batch_size, _, height, width = pixel_values.shape
141
+ patch_embeddings = self.patch_embeddings(pixel_values)
142
+ embeddings = patch_embeddings
143
+
144
+ if bool_masked_pos is not None:
145
+ embeddings = torch.where(
146
+ bool_masked_pos.unsqueeze(-1),
147
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
148
+ embeddings,
149
+ )
150
+
151
+ # add the [CLS] token to the embedded patch tokens
152
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
153
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
154
+
155
+ # add positional encoding to each token
156
+ embeddings = embeddings + self.interpolate_pos_encoding(
157
+ embeddings, height, width
158
+ )
159
+
160
+ embeddings = self.dropout(embeddings)
161
+
162
+ return embeddings
163
+
164
+
165
+ class Dinov2PatchEmbeddings(nn.Module):
166
+ """
167
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
168
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
169
+ Transformer.
170
+ """
171
+
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ image_size, patch_size = config.image_size, config.patch_size
175
+ num_channels, hidden_size = config.num_channels, config.hidden_size
176
+
177
+ image_size = (
178
+ image_size
179
+ if isinstance(image_size, collections.abc.Iterable)
180
+ else (image_size, image_size)
181
+ )
182
+ patch_size = (
183
+ patch_size
184
+ if isinstance(patch_size, collections.abc.Iterable)
185
+ else (patch_size, patch_size)
186
+ )
187
+ num_patches = (image_size[1] // patch_size[1]) * (
188
+ image_size[0] // patch_size[0]
189
+ )
190
+ self.image_size = image_size
191
+ self.patch_size = patch_size
192
+ self.num_channels = num_channels
193
+ self.num_patches = num_patches
194
+
195
+ self.projection = nn.Conv2d(
196
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
197
+ )
198
+
199
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
200
+ """
201
+ num_channels = pixel_values.shape[1]
202
+ if num_channels != self.num_channels:
203
+ raise ValueError(
204
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
205
+ f" Expected {self.num_channels} but got {num_channels}."
206
+ )
207
+ """
208
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
209
+ return embeddings
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
+ class Dinov2SelfAttention(nn.Module):
214
+ def __init__(self, config: Dinov2Config) -> None:
215
+ super().__init__()
216
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
217
+ config, "embedding_size"
218
+ ):
219
+ raise ValueError(
220
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
221
+ f"heads {config.num_attention_heads}."
222
+ )
223
+
224
+ self.num_attention_heads = config.num_attention_heads
225
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
226
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
227
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
228
+
229
+ self.query = nn.Linear(
230
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
231
+ )
232
+ self.key = nn.Linear(
233
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
234
+ )
235
+ self.value = nn.Linear(
236
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
237
+ )
238
+
239
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
240
+
241
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
242
+ new_x_shape = x.size()[:-1] + (
243
+ self.num_attention_heads,
244
+ self.attention_head_size,
245
+ )
246
+ x = x.view(new_x_shape)
247
+ return x.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states,
252
+ head_mask: Optional[torch.Tensor] = None,
253
+ output_attentions: bool = False,
254
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
255
+ mixed_query_layer = self.query(hidden_states)
256
+
257
+ if hasattr(F, "scaled_dot_product_attention"):
258
+ assert head_mask is None and not output_attentions
259
+ new_size = hidden_states.size()[:-1] + (
260
+ self.num_attention_heads,
261
+ self.attention_head_size,
262
+ )
263
+ key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
264
+ value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
265
+ query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
266
+ context_layer = F.scaled_dot_product_attention(
267
+ query_layer,
268
+ key_layer,
269
+ value_layer,
270
+ dropout_p=self.attention_probs_dropout_prob,
271
+ is_causal=False,
272
+ )
273
+ context_layer = context_layer.transpose(1, 2).reshape(
274
+ *hidden_states.size()[:-1], -1
275
+ )
276
+ else:
277
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
278
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
279
+ query_layer = self.transpose_for_scores(mixed_query_layer)
280
+
281
+ # Take the dot product between "query" and "key" to get the raw attention scores.
282
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
+
284
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
285
+
286
+ # Normalize the attention scores to probabilities.
287
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.dropout(attention_probs)
292
+
293
+ # Mask heads if we want to
294
+ if head_mask is not None:
295
+ attention_probs = attention_probs * head_mask
296
+
297
+ context_layer = torch.matmul(attention_probs, value_layer)
298
+
299
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
301
+ context_layer = context_layer.view(new_context_layer_shape)
302
+
303
+ outputs = (
304
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
305
+ )
306
+
307
+ return outputs
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
+ class Dinov2SelfOutput(nn.Module):
312
+ """
313
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
314
+ layernorm applied before each block.
315
+ """
316
+
317
+ def __init__(self, config: Dinov2Config) -> None:
318
+ super().__init__()
319
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(
323
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
324
+ ) -> torch.Tensor:
325
+ hidden_states = self.dense(hidden_states)
326
+ hidden_states = self.dropout(hidden_states)
327
+
328
+ return hidden_states
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
+ class Dinov2Attention(nn.Module):
333
+ def __init__(self, config: Dinov2Config) -> None:
334
+ super().__init__()
335
+ self.attention = Dinov2SelfAttention(config)
336
+ self.output = Dinov2SelfOutput(config)
337
+ self.pruned_heads = set()
338
+
339
+ def prune_heads(self, heads: Set[int]) -> None:
340
+ if len(heads) == 0:
341
+ return
342
+ heads, index = find_pruneable_heads_and_indices(
343
+ heads,
344
+ self.attention.num_attention_heads,
345
+ self.attention.attention_head_size,
346
+ self.pruned_heads,
347
+ )
348
+
349
+ # Prune linear layers
350
+ self.attention.query = prune_linear_layer(self.attention.query, index)
351
+ self.attention.key = prune_linear_layer(self.attention.key, index)
352
+ self.attention.value = prune_linear_layer(self.attention.value, index)
353
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
354
+
355
+ # Update hyper params and store pruned heads
356
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
357
+ heads
358
+ )
359
+ self.attention.all_head_size = (
360
+ self.attention.attention_head_size * self.attention.num_attention_heads
361
+ )
362
+ self.pruned_heads = self.pruned_heads.union(heads)
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ output_attentions: bool = False,
369
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
+
372
+ attention_output = self.output(self_outputs[0], hidden_states)
373
+
374
+ outputs = (attention_output,) + self_outputs[
375
+ 1:
376
+ ] # add attentions if we output them
377
+ return outputs
378
+
379
+
380
+ class Dinov2LayerScale(nn.Module):
381
+ def __init__(self, config) -> None:
382
+ super().__init__()
383
+ self.lambda1 = nn.Parameter(
384
+ config.layerscale_value * torch.ones(config.hidden_size)
385
+ )
386
+
387
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
388
+ return hidden_state * self.lambda1
389
+
390
+
391
+ # Copied from transformers.models.beit.modeling_beit.drop_path
392
+ def drop_path(
393
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
394
+ ) -> torch.Tensor:
395
+ """
396
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
397
+
398
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
399
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
400
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
401
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
402
+ argument.
403
+ """
404
+ if drop_prob == 0.0 or not training:
405
+ return input
406
+ keep_prob = 1 - drop_prob
407
+ shape = (input.shape[0],) + (1,) * (
408
+ input.ndim - 1
409
+ ) # work with diff dim tensors, not just 2D ConvNets
410
+ random_tensor = keep_prob + torch.rand(
411
+ shape, dtype=input.dtype, device=input.device
412
+ )
413
+ random_tensor.floor_() # binarize
414
+ output = input.div(keep_prob) * random_tensor
415
+ return output
416
+
417
+
418
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
+ class Dinov2DropPath(nn.Module):
420
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
421
+
422
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
423
+ super().__init__()
424
+ self.drop_prob = drop_prob
425
+
426
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
427
+ return drop_path(hidden_states, self.drop_prob, self.training)
428
+
429
+ def extra_repr(self) -> str:
430
+ return "p={}".format(self.drop_prob)
431
+
432
+
433
+ class Dinov2MLP(nn.Module):
434
+ def __init__(self, config) -> None:
435
+ super().__init__()
436
+ in_features = out_features = config.hidden_size
437
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
438
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
439
+ if isinstance(config.hidden_act, str):
440
+ self.activation = ACT2FN[config.hidden_act]
441
+ else:
442
+ self.activation = config.hidden_act
443
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
444
+
445
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
446
+ hidden_state = self.fc1(hidden_state)
447
+ hidden_state = self.activation(hidden_state)
448
+ hidden_state = self.fc2(hidden_state)
449
+ return hidden_state
450
+
451
+
452
+ class Dinov2SwiGLUFFN(nn.Module):
453
+ def __init__(self, config) -> None:
454
+ super().__init__()
455
+ in_features = out_features = config.hidden_size
456
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
457
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
458
+
459
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
460
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
461
+
462
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
463
+ hidden_state = self.weights_in(hidden_state)
464
+ x1, x2 = hidden_state.chunk(2, dim=-1)
465
+ hidden = nn.functional.silu(x1) * x2
466
+ return self.weights_out(hidden)
467
+
468
+
469
+ class Dinov2Layer(nn.Module):
470
+ """This corresponds to the Block class in the original implementation."""
471
+
472
+ def __init__(self, config: Dinov2Config) -> None:
473
+ super().__init__()
474
+
475
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.norm1_modulation = None
477
+ self.attention = Dinov2Attention(config)
478
+ self.layer_scale1 = Dinov2LayerScale(config)
479
+ self.drop_path1 = (
480
+ Dinov2DropPath(config.drop_path_rate)
481
+ if config.drop_path_rate > 0.0
482
+ else nn.Identity()
483
+ )
484
+
485
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
+ self.norm2_modulation = None
487
+
488
+ if config.use_swiglu_ffn:
489
+ self.mlp = Dinov2SwiGLUFFN(config)
490
+ else:
491
+ self.mlp = Dinov2MLP(config)
492
+ self.layer_scale2 = Dinov2LayerScale(config)
493
+ self.drop_path2 = (
494
+ Dinov2DropPath(config.drop_path_rate)
495
+ if config.drop_path_rate > 0.0
496
+ else nn.Identity()
497
+ )
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ head_mask: Optional[torch.Tensor] = None,
503
+ modulation_cond: Optional[torch.Tensor] = None,
504
+ output_attentions: bool = False,
505
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
+ hidden_states_norm = self.norm1(hidden_states)
507
+ if self.norm1_modulation is not None:
508
+ assert modulation_cond is not None
509
+ hidden_states_norm = self.norm1_modulation(
510
+ hidden_states_norm, modulation_cond
511
+ )
512
+ self_attention_outputs = self.attention(
513
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self_attention_outputs[0]
518
+
519
+ attention_output = self.layer_scale1(attention_output)
520
+ outputs = self_attention_outputs[
521
+ 1:
522
+ ] # add self attentions if we output attention weights
523
+
524
+ # first residual connection
525
+ hidden_states = attention_output + hidden_states
526
+
527
+ # in Dinov2, layernorm is also applied after self-attention
528
+ layer_output = self.norm2(hidden_states)
529
+ if self.norm2_modulation is not None:
530
+ assert modulation_cond is not None
531
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
+ layer_output = self.mlp(layer_output)
533
+ layer_output = self.layer_scale2(layer_output)
534
+
535
+ # second residual connection
536
+ layer_output = layer_output + hidden_states
537
+
538
+ outputs = (layer_output,) + outputs
539
+
540
+ return outputs
541
+
542
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
+ self.norm1_modulation = norm1_mod
544
+ self.norm2_modulation = norm2_mod
545
+
546
+
547
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
+ class Dinov2Encoder(nn.Module):
549
+ def __init__(self, config: Dinov2Config) -> None:
550
+ super().__init__()
551
+ self.config = config
552
+ self.layer = nn.ModuleList(
553
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
554
+ )
555
+ self.gradient_checkpointing = False
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ head_mask: Optional[torch.Tensor] = None,
561
+ modulation_cond: Optional[torch.Tensor] = None,
562
+ output_attentions: bool = False,
563
+ output_hidden_states: bool = False,
564
+ return_dict: bool = True,
565
+ ) -> Union[tuple, BaseModelOutput]:
566
+ all_hidden_states = () if output_hidden_states else None
567
+ all_self_attentions = () if output_attentions else None
568
+
569
+ for i, layer_module in enumerate(self.layer):
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ layer_head_mask = head_mask[i] if head_mask is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs, output_attentions)
580
+
581
+ return custom_forward
582
+
583
+ layer_outputs = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(layer_module),
585
+ hidden_states,
586
+ layer_head_mask,
587
+ modulation_cond,
588
+ use_reentrant=False,
589
+ )
590
+ else:
591
+ layer_outputs = layer_module(
592
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
593
+ )
594
+
595
+ hidden_states = layer_outputs[0]
596
+
597
+ if output_attentions:
598
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
599
+
600
+ if output_hidden_states:
601
+ all_hidden_states = all_hidden_states + (hidden_states,)
602
+
603
+ if not return_dict:
604
+ return tuple(
605
+ v
606
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
607
+ if v is not None
608
+ )
609
+ return BaseModelOutput(
610
+ last_hidden_state=hidden_states,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attentions,
613
+ )
614
+
615
+
616
+ class Dinov2PreTrainedModel(PreTrainedModel):
617
+ """
618
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
+ models.
620
+ """
621
+
622
+ config_class = Dinov2Config
623
+ base_model_prefix = "dinov2"
624
+ main_input_name = "pixel_values"
625
+ supports_gradient_checkpointing = True
626
+
627
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
628
+ """Initialize the weights"""
629
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
630
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
631
+ # `trunc_normal_cpu` not implemented in `half` issues
632
+ module.weight.data = nn.init.trunc_normal_(
633
+ module.weight.data.to(torch.float32),
634
+ mean=0.0,
635
+ std=self.config.initializer_range,
636
+ ).to(module.weight.dtype)
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+ elif isinstance(module, Dinov2Embeddings):
643
+ module.position_embeddings.data = nn.init.trunc_normal_(
644
+ module.position_embeddings.data.to(torch.float32),
645
+ mean=0.0,
646
+ std=self.config.initializer_range,
647
+ ).to(module.position_embeddings.dtype)
648
+
649
+ module.cls_token.data = nn.init.trunc_normal_(
650
+ module.cls_token.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=self.config.initializer_range,
653
+ ).to(module.cls_token.dtype)
654
+
655
+ def _set_gradient_checkpointing(
656
+ self, module: Dinov2Encoder, value: bool = False
657
+ ) -> None:
658
+ if isinstance(module, Dinov2Encoder):
659
+ module.gradient_checkpointing = value
660
+
661
+
662
+ DINOV2_START_DOCSTRING = r"""
663
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
664
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
665
+ behavior.
666
+
667
+ Parameters:
668
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
676
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
677
+ [`BitImageProcessor.preprocess`] for details.
678
+
679
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
680
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
681
+ pre-training.
682
+
683
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
684
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
685
+
686
+ - 1 indicates the head is **not masked**,
687
+ - 0 indicates the head is **masked**.
688
+
689
+ output_attentions (`bool`, *optional*):
690
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
691
+ tensors for more detail.
692
+ output_hidden_states (`bool`, *optional*):
693
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
694
+ more detail.
695
+ return_dict (`bool`, *optional*):
696
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
697
+ """
698
+
699
+ DINOV2_INPUTS_DOCSTRING = r"""
700
+ Args:
701
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
702
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
703
+ [`BitImageProcessor.preprocess`] for details.
704
+
705
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
706
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
707
+
708
+ - 1 indicates the head is **not masked**,
709
+ - 0 indicates the head is **masked**.
710
+
711
+ output_attentions (`bool`, *optional*):
712
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
+ tensors for more detail.
714
+ output_hidden_states (`bool`, *optional*):
715
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
+ more detail.
717
+ return_dict (`bool`, *optional*):
718
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
+ """
720
+
721
+
722
+ @dataclass
723
+ class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
724
+ patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
+ @add_start_docstrings(
728
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
729
+ DINOV2_START_DOCSTRING,
730
+ )
731
+ class Dinov2Model(Dinov2PreTrainedModel):
732
+ def __init__(self, config: Dinov2Config):
733
+ super().__init__(config)
734
+ self.config = config
735
+
736
+ self.embeddings = Dinov2Embeddings(config)
737
+ self.encoder = Dinov2Encoder(config)
738
+
739
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
745
+ return self.embeddings.patch_embeddings
746
+
747
+ def expand_input_channels(self, extra_input_channels: int) -> None:
748
+ if extra_input_channels == 0:
749
+ return
750
+ conv_old = self.embeddings.patch_embeddings.projection
751
+ conv_new = nn.Conv2d(
752
+ self.config.num_channels + extra_input_channels,
753
+ self.config.hidden_size,
754
+ kernel_size=self.config.patch_size,
755
+ stride=self.config.patch_size,
756
+ ).to(self.device)
757
+ with torch.no_grad():
758
+ conv_new.weight[:, :3] = conv_old.weight
759
+ conv_new.bias = conv_old.bias
760
+ self.embeddings.patch_embeddings.projection = conv_new
761
+ del conv_old
762
+
763
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
764
+ """
765
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
+ class PreTrainedModel
767
+ """
768
+ for layer, heads in heads_to_prune.items():
769
+ self.encoder.layer[layer].attention.prune_heads(heads)
770
+
771
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
772
+ @add_code_sample_docstrings(
773
+ checkpoint=_CHECKPOINT_FOR_DOC,
774
+ output_type=BaseModelOutputWithPooling,
775
+ config_class=_CONFIG_FOR_DOC,
776
+ modality="vision",
777
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
778
+ )
779
+ def forward(
780
+ self,
781
+ pixel_values: Optional[torch.Tensor] = None,
782
+ bool_masked_pos: Optional[torch.Tensor] = None,
783
+ head_mask: Optional[torch.Tensor] = None,
784
+ modulation_cond: Optional[torch.Tensor] = None,
785
+ output_attentions: Optional[bool] = None,
786
+ output_hidden_states: Optional[bool] = None,
787
+ return_dict: Optional[bool] = None,
788
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
789
+ output_attentions = (
790
+ output_attentions
791
+ if output_attentions is not None
792
+ else self.config.output_attentions
793
+ )
794
+ output_hidden_states = (
795
+ output_hidden_states
796
+ if output_hidden_states is not None
797
+ else self.config.output_hidden_states
798
+ )
799
+ return_dict = (
800
+ return_dict if return_dict is not None else self.config.use_return_dict
801
+ )
802
+
803
+ if pixel_values is None:
804
+ raise ValueError("You have to specify pixel_values")
805
+
806
+ # Prepare head mask if needed
807
+ # 1.0 in head_mask indicate we keep the head
808
+ # attention_probs has shape bsz x n_heads x N x N
809
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
810
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
811
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
812
+
813
+ embedding_output = self.embeddings(
814
+ pixel_values, bool_masked_pos=bool_masked_pos
815
+ )
816
+
817
+ encoder_outputs = self.encoder(
818
+ embedding_output,
819
+ head_mask=head_mask,
820
+ modulation_cond=modulation_cond,
821
+ output_attentions=output_attentions,
822
+ output_hidden_states=output_hidden_states,
823
+ return_dict=return_dict,
824
+ )
825
+ sequence_output = encoder_outputs[0]
826
+ sequence_output = self.layernorm(sequence_output)
827
+ pooled_output = sequence_output[:, 0, :]
828
+
829
+ if not return_dict:
830
+ head_outputs = (sequence_output, pooled_output)
831
+ return head_outputs + encoder_outputs[1:]
832
+
833
+ return CustomBaseModelOutputWithPooling(
834
+ last_hidden_state=sequence_output,
835
+ pooler_output=pooled_output,
836
+ hidden_states=encoder_outputs.hidden_states,
837
+ attentions=encoder_outputs.attentions,
838
+ patch_embeddings=embedding_output,
839
+ )
840
+
841
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
842
+ self._set_gradient_checkpointing(self.encoder, value)
843
+
844
+
845
+ @add_start_docstrings(
846
+ """
847
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
848
+ of the [CLS] token) e.g. for ImageNet.
849
+ """,
850
+ DINOV2_START_DOCSTRING,
851
+ )
852
+ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
853
+ def __init__(self, config: Dinov2Config) -> None:
854
+ super().__init__(config)
855
+
856
+ self.num_labels = config.num_labels
857
+ self.dinov2 = Dinov2Model(config)
858
+
859
+ # Classifier head
860
+ self.classifier = (
861
+ nn.Linear(config.hidden_size * 2, config.num_labels)
862
+ if config.num_labels > 0
863
+ else nn.Identity()
864
+ )
865
+
866
+ # Initialize weights and apply final processing
867
+ self.post_init()
868
+
869
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
870
+ @add_code_sample_docstrings(
871
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
872
+ output_type=ImageClassifierOutput,
873
+ config_class=_CONFIG_FOR_DOC,
874
+ )
875
+ def forward(
876
+ self,
877
+ pixel_values: Optional[torch.Tensor] = None,
878
+ head_mask: Optional[torch.Tensor] = None,
879
+ labels: Optional[torch.Tensor] = None,
880
+ output_attentions: Optional[bool] = None,
881
+ output_hidden_states: Optional[bool] = None,
882
+ return_dict: Optional[bool] = None,
883
+ ) -> Union[tuple, ImageClassifierOutput]:
884
+ r"""
885
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
886
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
887
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
888
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
889
+ """
890
+ return_dict = (
891
+ return_dict if return_dict is not None else self.config.use_return_dict
892
+ )
893
+
894
+ outputs = self.dinov2(
895
+ pixel_values,
896
+ head_mask=head_mask,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ )
901
+
902
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
903
+
904
+ cls_token = sequence_output[:, 0]
905
+ patch_tokens = sequence_output[:, 1:]
906
+
907
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
908
+
909
+ logits = self.classifier(linear_input)
910
+
911
+ loss = None
912
+ if labels is not None:
913
+ # move labels to correct device to enable model parallelism
914
+ labels = labels.to(logits.device)
915
+ if self.config.problem_type is None:
916
+ if self.num_labels == 1:
917
+ self.config.problem_type = "regression"
918
+ elif self.num_labels > 1 and (
919
+ labels.dtype == torch.long or labels.dtype == torch.int
920
+ ):
921
+ self.config.problem_type = "single_label_classification"
922
+ else:
923
+ self.config.problem_type = "multi_label_classification"
924
+
925
+ if self.config.problem_type == "regression":
926
+ loss_fct = MSELoss()
927
+ if self.num_labels == 1:
928
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
929
+ else:
930
+ loss = loss_fct(logits, labels)
931
+ elif self.config.problem_type == "single_label_classification":
932
+ loss_fct = CrossEntropyLoss()
933
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
934
+ elif self.config.problem_type == "multi_label_classification":
935
+ loss_fct = BCEWithLogitsLoss()
936
+ loss = loss_fct(logits, labels)
937
+
938
+ if not return_dict:
939
+ output = (logits,) + outputs[2:]
940
+ return ((loss,) + output) if loss is not None else output
941
+
942
+ return ImageClassifierOutput(
943
+ loss=loss,
944
+ logits=logits,
945
+ hidden_states=outputs.hidden_states,
946
+ attentions=outputs.attentions,
947
+ )
948
+
949
+
950
+ @add_start_docstrings(
951
+ """
952
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
953
+ """,
954
+ DINOV2_START_DOCSTRING,
955
+ )
956
+ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+ super()._init_backbone(config)
960
+
961
+ self.num_features = [
962
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
963
+ ]
964
+ self.embeddings = Dinov2Embeddings(config)
965
+ self.encoder = Dinov2Encoder(config)
966
+
967
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+
972
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
973
+ return self.embeddings.patch_embeddings
974
+
975
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
976
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
977
+ def forward(
978
+ self,
979
+ pixel_values: torch.Tensor,
980
+ output_hidden_states: Optional[bool] = None,
981
+ output_attentions: Optional[bool] = None,
982
+ return_dict: Optional[bool] = None,
983
+ ) -> BackboneOutput:
984
+ """
985
+ Returns:
986
+
987
+ Examples:
988
+
989
+ ```python
990
+ >>> from transformers import AutoImageProcessor, AutoBackbone
991
+ >>> import torch
992
+ >>> from PIL import Image
993
+ >>> import requests
994
+
995
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
996
+ >>> image = Image.open(requests.get(url, stream=True).raw)
997
+
998
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
999
+ >>> model = AutoBackbone.from_pretrained(
1000
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
1001
+ ... )
1002
+
1003
+ >>> inputs = processor(image, return_tensors="pt")
1004
+
1005
+ >>> outputs = model(**inputs)
1006
+ >>> feature_maps = outputs.feature_maps
1007
+ >>> list(feature_maps[-1].shape)
1008
+ [1, 768, 16, 16]
1009
+ ```"""
1010
+ return_dict = (
1011
+ return_dict if return_dict is not None else self.config.use_return_dict
1012
+ )
1013
+ output_hidden_states = (
1014
+ output_hidden_states
1015
+ if output_hidden_states is not None
1016
+ else self.config.output_hidden_states
1017
+ )
1018
+ output_attentions = (
1019
+ output_attentions
1020
+ if output_attentions is not None
1021
+ else self.config.output_attentions
1022
+ )
1023
+
1024
+ embedding_output = self.embeddings(pixel_values)
1025
+
1026
+ outputs = self.encoder(
1027
+ embedding_output,
1028
+ output_hidden_states=True,
1029
+ output_attentions=output_attentions,
1030
+ return_dict=return_dict,
1031
+ )
1032
+
1033
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
1034
+
1035
+ feature_maps = ()
1036
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1037
+ if stage in self.out_features:
1038
+ if self.config.apply_layernorm:
1039
+ hidden_state = self.layernorm(hidden_state)
1040
+ if self.config.reshape_hidden_states:
1041
+ batch_size, _, height, width = pixel_values.shape
1042
+ patch_size = self.config.patch_size
1043
+ hidden_state = hidden_state[:, 1:, :].reshape(
1044
+ batch_size, width // patch_size, height // patch_size, -1
1045
+ )
1046
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1047
+ feature_maps += (hidden_state,)
1048
+
1049
+ if not return_dict:
1050
+ if output_hidden_states:
1051
+ output = (feature_maps,) + outputs[1:]
1052
+ else:
1053
+ output = (feature_maps,) + outputs[2:]
1054
+ return output
1055
+
1056
+ return BackboneOutput(
1057
+ feature_maps=feature_maps,
1058
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1059
+ attentions=outputs.attentions if output_attentions else None,
1060
+ )
1061
+
1062
+
1063
+ class CustomPatchEmbeddings(nn.Module):
1064
+ """
1065
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
1066
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
1067
+ Transformer.
1068
+ """
1069
+
1070
+ def __init__(
1071
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1072
+ ):
1073
+ super().__init__()
1074
+
1075
+ image_size = (
1076
+ image_size
1077
+ if isinstance(image_size, collections.abc.Iterable)
1078
+ else (image_size, image_size)
1079
+ )
1080
+ patch_size = (
1081
+ patch_size
1082
+ if isinstance(patch_size, collections.abc.Iterable)
1083
+ else (patch_size, patch_size)
1084
+ )
1085
+ num_patches = (image_size[1] // patch_size[1]) * (
1086
+ image_size[0] // patch_size[0]
1087
+ )
1088
+ self.image_size = image_size
1089
+ self.patch_size = patch_size
1090
+ self.num_channels = num_channels
1091
+ self.num_patches = num_patches
1092
+
1093
+ self.projection = nn.Conv2d(
1094
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
1095
+ )
1096
+
1097
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1098
+ num_channels = pixel_values.shape[1]
1099
+ if num_channels != self.num_channels:
1100
+ raise ValueError(
1101
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
1102
+ f" Expected {self.num_channels} but got {num_channels}."
1103
+ )
1104
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
1105
+ return embeddings
1106
+
1107
+
1108
+ class CustomEmbeddings(nn.Module):
1109
+ """
1110
+ Construct the CLS token, mask token, position and patch embeddings.
1111
+ """
1112
+
1113
+ def __init__(
1114
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1115
+ ) -> None:
1116
+ super().__init__()
1117
+
1118
+ self.image_size = image_size
1119
+ self.patch_size = patch_size
1120
+ self.num_channels = num_channels
1121
+ self.hidden_size = hidden_size
1122
+
1123
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
1124
+
1125
+ self.patch_embeddings = CustomPatchEmbeddings(
1126
+ image_size, patch_size, num_channels, hidden_size
1127
+ )
1128
+ num_patches = self.patch_embeddings.num_patches
1129
+ self.position_embeddings = nn.Parameter(
1130
+ torch.randn(1, num_patches + 1, self.hidden_size)
1131
+ )
1132
+
1133
+ def interpolate_pos_encoding(
1134
+ self, embeddings: torch.Tensor, height: int, width: int
1135
+ ) -> torch.Tensor:
1136
+ """
1137
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
1138
+ resolution images.
1139
+
1140
+ Source:
1141
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1142
+ """
1143
+
1144
+ num_patches = embeddings.shape[1] - 1
1145
+ num_positions = self.position_embeddings.shape[1] - 1
1146
+ if num_patches == num_positions and height == width:
1147
+ return self.position_embeddings
1148
+ class_pos_embed = self.position_embeddings[:, 0]
1149
+ patch_pos_embed = self.position_embeddings[:, 1:]
1150
+ dim = embeddings.shape[-1]
1151
+ height = height // self.patch_size
1152
+ width = width // self.patch_size
1153
+ # we add a small number to avoid floating point error in the interpolation
1154
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
1155
+ height, width = height + 0.1, width + 0.1
1156
+ patch_pos_embed = patch_pos_embed.reshape(
1157
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
1158
+ )
1159
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
1160
+ patch_pos_embed = nn.functional.interpolate(
1161
+ patch_pos_embed,
1162
+ scale_factor=(
1163
+ height / math.sqrt(num_positions),
1164
+ width / math.sqrt(num_positions),
1165
+ ),
1166
+ mode="bicubic",
1167
+ align_corners=False,
1168
+ )
1169
+ if (
1170
+ int(height) != patch_pos_embed.shape[-2]
1171
+ or int(width) != patch_pos_embed.shape[-1]
1172
+ ):
1173
+ raise ValueError(
1174
+ "Width or height does not match with the interpolated position embeddings"
1175
+ )
1176
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1177
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
1178
+
1179
+ def forward(
1180
+ self,
1181
+ pixel_values: torch.Tensor,
1182
+ ) -> torch.Tensor:
1183
+ batch_size, _, height, width = pixel_values.shape
1184
+ patch_embeddings = self.patch_embeddings(pixel_values)
1185
+ embeddings = patch_embeddings
1186
+
1187
+ # add the [CLS] token to the embedded patch tokens
1188
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1189
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
1190
+
1191
+ # add positional encoding to each token
1192
+ embeddings = embeddings + self.interpolate_pos_encoding(
1193
+ embeddings, height, width
1194
+ )
1195
+
1196
+ return embeddings
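The position-encoding interpolation used by `Dinov2Embeddings` (and repeated in `CustomEmbeddings`) can be exercised on its own. The sketch below reproduces the reshape, bicubic resize, and flatten steps with plain torch; the grid size and width are example values, not tied to any particular checkpoint.

```python
import math
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Standalone sketch of interpolate_pos_encoding: (1, 1 + N, D) -> new grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = patch_pos.shape[-1]
    old_grid = int(math.sqrt(patch_pos.shape[1]))
    grid = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    # the +0.1 offset mirrors the floating-point workaround in the code above
    grid = F.interpolate(
        grid,
        scale_factor=((new_grid + 0.1) / old_grid, (new_grid + 0.1) / old_grid),
        mode="bicubic",
        align_corners=False,
    )
    assert grid.shape[-2] == new_grid and grid.shape[-1] == new_grid
    return torch.cat([cls_pos, grid.permute(0, 2, 3, 1).reshape(1, -1, dim)], dim=1)

pe = torch.randn(1, 1 + 37 * 37, 1024)        # e.g. a table trained on a 37x37 grid
print(resize_patch_pos_embed(pe, 64).shape)   # torch.Size([1, 4097, 1024])
```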
spar3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,99 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from spar3d.models.tokenizers.dinov2 import Dinov2Model
11
+ from spar3d.models.transformers.attention import Modulation
12
+ from spar3d.models.utils import BaseModule
13
+
14
+
15
+ class DINOV2SingleImageTokenizer(BaseModule):
16
+ @dataclass
17
+ class Config(BaseModule.Config):
18
+ pretrained_model_name_or_path: str = "facebook/dinov2-large"
19
+ width: int = 512
20
+ height: int = 512
21
+ modulation_cond_dim: int = 768
22
+
23
+ cfg: Config
24
+
25
+ def configure(self) -> None:
26
+ self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
27
+
28
+ for p in self.model.parameters():
29
+ p.requires_grad_(False)
30
+ self.model.eval()
31
+
32
+ self.model.set_gradient_checkpointing(False)
33
+
34
+ # add modulation
35
+ modulations = []
36
+ for layer in self.model.encoder.layer:
37
+ norm1_modulation = Modulation(
38
+ self.model.config.hidden_size,
39
+ self.cfg.modulation_cond_dim,
40
+ zero_init=True,
41
+ single_layer=True,
42
+ )
43
+ norm2_modulation = Modulation(
44
+ self.model.config.hidden_size,
45
+ self.cfg.modulation_cond_dim,
46
+ zero_init=True,
47
+ single_layer=True,
48
+ )
49
+ layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
50
+ modulations += [norm1_modulation, norm2_modulation]
51
+ self.modulations = nn.ModuleList(modulations)
52
+
53
+ self.register_buffer(
54
+ "image_mean",
55
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
56
+ persistent=False,
57
+ )
58
+ self.register_buffer(
59
+ "image_std",
60
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
61
+ persistent=False,
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ images: Float[Tensor, "B *N C H W"],
67
+ modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
68
+ **kwargs,
69
+ ) -> Float[Tensor, "B *N Ct Nt"]:
70
+ model = self.model
71
+
72
+ packed = False
73
+ if images.ndim == 4:
74
+ packed = True
75
+ images = images.unsqueeze(1)
76
+ if modulation_cond is not None:
77
+ assert modulation_cond.ndim == 2
78
+ modulation_cond = modulation_cond.unsqueeze(1)
79
+
80
+ batch_size, n_input_views = images.shape[:2]
81
+ images = (images - self.image_mean) / self.image_std
82
+ out = model(
83
+ rearrange(images, "B N C H W -> (B N) C H W"),
84
+ modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
85
+ if modulation_cond is not None
86
+ else None,
87
+ )
88
+ local_features = out.last_hidden_state
89
+ local_features = local_features.permute(0, 2, 1)
90
+ local_features = rearrange(
91
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
92
+ )
93
+ if packed:
94
+ local_features = local_features.squeeze(1)
95
+
96
+ return local_features
97
+
98
+ def detokenize(self, *args, **kwargs):
99
+ raise NotImplementedError
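For orientation, the image tokenizer freezes a DINOv2 backbone, injects per-layer AdaLN-style modulation conditioned on an external vector, and returns one feature sequence per image with channels first. Below is a rough usage sketch, not from the repo: it assumes `BaseModule` can be constructed from a plain config dict (threestudio-style) and that `facebook/dinov2-large` uses a patch size of 14.

```python
# Sketch only: expected call pattern and shapes for DINOV2SingleImageTokenizer.
# Constructor style (dict config) is an assumption, not confirmed by this diff.
import torch

from spar3d.models.tokenizers.image import DINOV2SingleImageTokenizer

tokenizer = DINOV2SingleImageTokenizer(
    {"pretrained_model_name_or_path": "facebook/dinov2-large"}
)

images = torch.rand(2, 3, 512, 512)  # packed input: B C H W
cond = torch.rand(2, 768)            # per-image modulation vector (modulation_cond_dim)
tokens = tokenizer(images, modulation_cond=cond)
# -> roughly (2, 1024, Nt): Ct = hidden size of dinov2-large,
#    Nt = 1 CLS token + (512 // 14) ** 2 patch tokens
```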
spar3d/models/tokenizers/point.py ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from jaxtyping import Float
+from torch import Tensor
+
+from spar3d.models.transformers.transformer_1d import Transformer1D
+from spar3d.models.utils import BaseModule
+
+
+class TransformerPointTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        num_attention_heads: int = 16
+        attention_head_dim: int = 64
+        in_channels: Optional[int] = 6
+        out_channels: Optional[int] = 1024
+        num_layers: int = 16
+        norm_num_groups: int = 32
+        attention_bias: bool = False
+        activation_fn: str = "geglu"
+        norm_elementwise_affine: bool = True
+
+    cfg: Config
+
+    def configure(self) -> None:
+        transformer_cfg = dict(self.cfg.copy())
+        # remove the non-transformer configs
+        transformer_cfg["in_channels"] = (
+            self.cfg.num_attention_heads * self.cfg.attention_head_dim
+        )
+        self.model = Transformer1D(transformer_cfg)
+        self.linear_in = torch.nn.Linear(
+            self.cfg.in_channels, transformer_cfg["in_channels"]
+        )
+        self.linear_out = torch.nn.Linear(
+            transformer_cfg["in_channels"], self.cfg.out_channels
+        )
+
+    def forward(
+        self, points: Float[Tensor, "B N Ci"], **kwargs
+    ) -> Float[Tensor, "B N Cp"]:
+        assert points.ndim == 3
+        inputs = self.linear_in(points).permute(0, 2, 1)  # B N Ci -> B Ci N
+        out = self.model(inputs).permute(0, 2, 1)  # B Ci N -> B N Ci
+        out = self.linear_out(out)  # B N Ci -> B N Co
+        return out
+
+    def detokenize(self, *args, **kwargs):
+        raise NotImplementedError
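The point tokenizer keeps one token per input point: each point (xyz + rgb, hence the default `in_channels=6`) is linearly lifted to the transformer width (16 heads × 64 dims = 1024), passed through `Transformer1D`, and projected to `out_channels`. A shape-only sketch, under the same dict-config constructor assumption as above:

```python
# Sketch only: shape contract of TransformerPointTokenizer
# (xyz + rgb in, one 1024-dim token per point out).
import torch

from spar3d.models.tokenizers.point import TransformerPointTokenizer

point_tokenizer = TransformerPointTokenizer({})  # use the Config defaults
points = torch.rand(2, 512, 6)                   # B=2 clouds, N=512 points, xyz+rgb
point_tokens = point_tokenizer(points)
print(point_tokens.shape)                        # expected: torch.Size([2, 512, 1024])
```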
spar3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from jaxtyping import Float
+from torch import Tensor
+
+from spar3d.models.utils import BaseModule
+
+
+class TriplaneLearnablePositionalEmbedding(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        plane_size: int = 96
+        num_channels: int = 1024
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.embeddings = nn.Parameter(
+            torch.randn(
+                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
+                dtype=torch.float32,
+            )
+            * 1
+            / math.sqrt(self.cfg.num_channels)
+        )
+
+    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
+        return rearrange(
+            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
+            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
+        )
+
+    def detokenize(
+        self, tokens: Float[Tensor, "B Ct Nt"]
+    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
+        batch_size, Ct, Nt = tokens.shape
+        assert Nt == self.cfg.plane_size**2 * 3
+        assert Ct == self.cfg.num_channels
+        return rearrange(
+            tokens,
+            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
+            Np=3,
+            Hp=self.cfg.plane_size,
+            Wp=self.cfg.plane_size,
+        )
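This "tokenizer" is simply a learnable query grid: `forward` repeats the three `plane_size × plane_size` planes per batch element and flattens them into one token sequence, and `detokenize` folds that sequence back into the three planes. A quick round-trip shape check (a sketch, with the same dict-config constructor assumption as the earlier examples):

```python
# Sketch only: round-trip shape check for the learnable triplane tokens.
import torch

from spar3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

triplane = TriplaneLearnablePositionalEmbedding(
    {"plane_size": 96, "num_channels": 1024}
)
tokens = triplane(batch_size=2)
print(tokens.shape)   # torch.Size([2, 1024, 27648]) == (B, Ct, 3 * 96 * 96)

planes = triplane.detokenize(tokens)
print(planes.shape)   # torch.Size([2, 3, 1024, 96, 96]) == (B, Np, Ct, Hp, Wp)
```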