KevinQu7 committed on
Commit
09c3706
1 Parent(s): 641fe65

initial commit

Files changed (7)
  1. .gitignore +7 -0
  2. LICENSE.txt +177 -0
  3. app.py +639 -0
  4. marigold_iid_appearance.py +544 -0
  5. marigold_iid_residual.py +552 -0
  6. requirements.txt +126 -0
  7. requirements_min.txt +16 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ .idea
2
+ .DS_Store
3
+ __pycache__
4
+ gradio_cached_examples
5
+ Marigold
6
+ *.sh
7
+ script/
LICENSE.txt ADDED
@@ -0,0 +1,177 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
app.py ADDED
@@ -0,0 +1,639 @@
1
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+ from __future__ import annotations
20
+
21
+ import functools
22
+ import os
23
+ import tempfile
24
+ import warnings
25
+
26
+ import spaces
27
+ import gradio as gr
28
+ import numpy as np
29
+ import torch as torch
30
+ from PIL import Image
31
+ from diffusers import UNet2DConditionModel
32
+
33
+ from gradio_imageslider import ImageSlider
34
+ from huggingface_hub import login
35
+
36
+ from gradio_patches.examples import Examples
37
+ from gradio_patches.flagging import HuggingFaceDatasetSaver, FlagMethod
38
+ from marigold_iid_appearance import MarigoldIIDAppearancePipeline
39
+ from marigold_iid_residual import MarigoldIIDResidualPipeline
40
+
41
+ warnings.filterwarnings(
42
+ "ignore", message=".*LoginButton created outside of a Blocks context.*"
43
+ )
44
+
45
+ default_seed = 2024
46
+
47
+ default_image_denoise_steps = 4
48
+ default_image_ensemble_size = 1
49
+ default_image_processing_res = 768
50
+ default_image_reproducible = True
51
+ default_model_type = "appearance"
52
+
53
+ default_share_always_show_hf_logout_btn = True
54
+ default_share_always_show_accordion = False
55
+
56
+ loaded_pipelines = {} # Cache to store loaded pipelines
57
+ def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, processing_res, model_type):
58
+
59
+ # Load and cache the pipeline based on the model type.
60
+ if model_type not in loaded_pipelines:
61
+ auth_token = os.environ.get("KEV_TOKEN")
62
+ if model_type == "appearance":
63
+ loaded_pipelines[model_type] = MarigoldIIDAppearancePipeline.from_pretrained(
64
+ "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
65
+ )
66
+ elif model_type == "residual":
67
+ loaded_pipelines[model_type] = MarigoldIIDResidualPipeline.from_pretrained(
68
+ "prs-eth/marigold-iid-residual-v1-1", token=auth_token
69
+ )
70
+
71
+ # Move the pipeline to GPU if available
72
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+ loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device)
74
+
75
+ pipe = loaded_pipelines[model_type]
76
+
77
+ # Process the image using the preloaded pipeline.
78
+ return process_image(
79
+ pipe=pipe,
80
+ path_input=image_path,
81
+ denoise_steps=denoise_steps,
82
+ ensemble_size=ensemble_size,
83
+ processing_res=processing_res,
84
+ model_type=model_type,
85
+ )
86
+
87
+ def process_image_check(path_input):
88
+ if path_input is None:
89
+ raise gr.Error(
90
+ "Missing image in the first pane: upload a file or use one from the gallery below."
91
+ )
92
+
93
+ def process_image(
94
+ pipe,
95
+ path_input,
96
+ denoise_steps=default_image_denoise_steps,
97
+ ensemble_size=default_image_ensemble_size,
98
+ processing_res=default_image_processing_res,
99
+ model_type=default_model_type,
100
+ ):
101
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
102
+ print(f"Processing image {name_base}{name_ext}")
103
+
104
+ path_output_dir = tempfile.mkdtemp()
105
+
106
+ input_image = Image.open(path_input)
107
+
108
+
109
+ pipe_out = pipe(
110
+ input_image,
111
+ denoising_steps=denoise_steps,
112
+ ensemble_size=ensemble_size,
113
+ processing_res=processing_res,
114
+ batch_size=1 if processing_res == 0 else 0, # TODO: do we abuse "batch size" notation here?
115
+ seed=default_seed,
116
+ show_progress_bar=True,
117
+ )
118
+
119
+ path_output_dir = os.path.splitext(path_input)[0] + "_output"
120
+ os.makedirs(path_output_dir, exist_ok=True)
121
+
122
+ path_albedo_out = os.path.join(path_output_dir, f"{name_base}_albedo_fp32.npy")
123
+ path_albedo_out_vis = os.path.join(path_output_dir, f"{name_base}_albedo.png")
124
+
125
+ albedo = pipe_out.albedo
126
+ albedo_colored = pipe_out.albedo_colored
127
+
128
+ np.save(path_albedo_out, albedo)
129
+ albedo_colored.save(path_albedo_out_vis)
130
+
131
+
132
+ if model_type == "appearance":
133
+ path_material_out = os.path.join(path_output_dir, f"{name_base}_material_fp32.npy")
134
+ path_material_out_vis = os.path.join(path_output_dir, f"{name_base}_material.png")
135
+
136
+ material = pipe_out.material
137
+ material_colored = pipe_out.material_colored
138
+
139
+ np.save(path_material_out, material)
140
+ material_colored.save(path_material_out_vis)
141
+
142
+ return (
143
+ [path_input, path_albedo_out_vis],
144
+ [path_input, path_material_out_vis],
145
+ None,
146
+ [path_albedo_out_vis, path_material_out_vis, path_albedo_out, path_material_out],
147
+ )
148
+
149
+ elif model_type == "residual":
150
+ path_shading_out = os.path.join(path_output_dir, f"{name_base}_shading_fp32.npy")
151
+ path_shading_out_vis = os.path.join(path_output_dir, f"{name_base}_shading.png")
152
+ path_residual_out = os.path.join(path_output_dir, f"{name_base}_residual_fp32.npy")
153
+ path_residual_out_vis = os.path.join(path_output_dir, f"{name_base}_residual.png")
154
+
155
+ shading = pipe_out.shading
156
+ shading_colored = pipe_out.shading_colored
157
+ residual = pipe_out.residual
158
+ residual_colored = pipe_out.residual_colored
159
+
160
+ np.save(path_shading_out, shading)
161
+ shading_colored.save(path_shading_out_vis)
162
+ np.save(path_residual_out, residual)
163
+ residual_colored.save(path_residual_out_vis)
164
+
165
+ return (
166
+ [path_input, path_albedo_out_vis],
167
+ [path_input, path_shading_out_vis],
168
+ [path_input, path_residual_out_vis],
169
+ [path_albedo_out_vis, path_shading_out_vis, path_residual_out_vis, path_albedo_out, path_shading_out, path_residual_out],
170
+ )
171
+
172
+
173
+ def run_demo_server(hf_writer=None):
174
+ process_pipe_image = spaces.GPU(functools.partial(process_with_loaded_pipeline), duration=120)
175
+ gradio_theme = gr.themes.Default()
176
+
177
+ with gr.Blocks(
178
+ theme=gradio_theme,
179
+ title="Marigold Intrinsic Image Decomposition (Marigold-IID)",
180
+ css="""
181
+ #download {
182
+ height: 118px;
183
+ }
184
+ .slider .inner {
185
+ width: 5px;
186
+ background: #FFF;
187
+ }
188
+ .viewport {
189
+ aspect-ratio: 4/3;
190
+ }
191
+ .tabs button.selected {
192
+ font-size: 20px !important;
193
+ color: crimson !important;
194
+ }
195
+ h1 {
196
+ text-align: center;
197
+ display: block;
198
+ }
199
+ h2 {
200
+ text-align: center;
201
+ display: block;
202
+ }
203
+ h3 {
204
+ text-align: center;
205
+ display: block;
206
+ }
207
+ .md_feedback li {
208
+ margin-bottom: 0px !important;
209
+ }
210
+ """,
211
+ head="""
212
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
213
+ <script>
214
+ window.dataLayer = window.dataLayer || [];
215
+ function gtag() {dataLayer.push(arguments);}
216
+ gtag('js', new Date());
217
+ gtag('config', 'G-1FWSVCGZTG');
218
+ </script>
219
+ """,
220
+ ) as demo:
221
+ if hf_writer is not None:
222
+ print("Creating login button")
223
+ share_login_btn = gr.LoginButton(size="sm", scale=1, render=False)
224
+ print("Created login button")
225
+ share_login_btn.activate()
226
+ print("Activated login button")
227
+
228
+ gr.Markdown(
229
+ """
230
+ # Marigold Intrinsic Image Decomposition (Marigold-IID)
231
+
232
+ <p align="center">
233
+ <a title="Website" href="https://marigoldcomputervision.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
234
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
235
+ </a>
236
+ <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
237
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
238
+ </a>
239
+ <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
240
+ <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
241
+ </a>
242
+ <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
243
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
244
+ </a>
245
+ </p>
246
+ """
247
+ )
248
+
249
+ def get_share_instructions(is_full):
250
+ out = (
251
+ "### Help us improve Marigold! If the output is not what you expected, "
252
+ "you can help us by sharing it with us privately.\n"
253
+ )
254
+ if is_full:
255
+ out += (
256
+ "1. Sign into your Hugging Face account using the button below.\n"
257
+ "1. Signing in may reset the demo and results; in that case, process the image again.\n"
258
+ )
259
+ out += "1. Review and agree to the terms of usage and enter an optional message to us.\n"
260
+ out += "1. Click the 'Share' button to submit the image to us privately.\n"
261
+ return out
262
+
263
+ def get_share_conditioned_on_login(profile: gr.OAuthProfile | None):
264
+ state_logged_out = profile is None
265
+ return get_share_instructions(is_full=state_logged_out), gr.Button(
266
+ visible=(state_logged_out or default_share_always_show_hf_logout_btn)
267
+ )
268
+
269
+ with gr.Row():
270
+ with gr.Column():
271
+ image_input = gr.Image(
272
+ label="Input Image",
273
+ type="filepath",
274
+ )
275
+ model_type = gr.Radio(
276
+ [
277
+ ("Appearance (Albedo & Material)", "appearance"),
278
+ ("Residual (Albedo, Shading & Residual)", "residual"),
279
+ ],
280
+ label="Model Type",
281
+ value=default_model_type,
282
+ )
283
+
284
+ with gr.Accordion("Advanced options", open=True):
285
+ image_ensemble_size = gr.Slider(
286
+ label="Ensemble size",
287
+ minimum=1,
288
+ maximum=10,
289
+ step=1,
290
+ value=default_image_ensemble_size,
291
+ )
292
+ image_denoise_steps = gr.Slider(
293
+ label="Number of denoising steps",
294
+ minimum=1,
295
+ maximum=20,
296
+ step=1,
297
+ value=default_image_denoise_steps,
298
+ )
299
+ image_processing_res = gr.Radio(
300
+ [
301
+ ("Native", 0),
302
+ ("Recommended", 768),
303
+ ],
304
+ label="Processing resolution",
305
+ value=default_image_processing_res,
306
+ )
307
+ with gr.Row():
308
+ image_submit_btn = gr.Button(value="Compute Decomposition", variant="primary")
309
+ image_reset_btn = gr.Button(value="Reset")
310
+ with gr.Column():
311
+ image_output_slider1 = ImageSlider(
312
+ label="Predicted Albedo",
313
+ type="filepath",
314
+ show_download_button=True,
315
+ show_share_button=True,
316
+ interactive=False,
317
+ elem_classes="slider",
318
+ position=0.25,
319
+ visible=True
320
+ )
321
+ image_output_slider2 = ImageSlider(
322
+ label="Predicted Material",
323
+ type="filepath",
324
+ show_download_button=True,
325
+ show_share_button=True,
326
+ interactive=False,
327
+ elem_classes="slider",
328
+ position=0.25,
329
+ visible=True
330
+ )
331
+ image_output_slider3 = ImageSlider(
332
+ label="Predicted Residual",
333
+ type="filepath",
334
+ show_download_button=True,
335
+ show_share_button=True,
336
+ interactive=False,
337
+ elem_classes="slider",
338
+ position=0.25,
339
+ visible=False
340
+ )
341
+ image_output_files = gr.Files(
342
+ label="Output files",
343
+ elem_id="download",
344
+ interactive=False,
345
+ )
346
+
347
+ if hf_writer is not None:
348
+ with gr.Accordion(
349
+ "Feedback",
350
+ open=False,
351
+ visible=default_share_always_show_accordion,
352
+ ) as share_box:
353
+ share_instructions = gr.Markdown(
354
+ get_share_instructions(is_full=True),
355
+ elem_classes="md_feedback",
356
+ )
357
+ share_transfer_of_rights = gr.Checkbox(
358
+ label="(Optional) I own or hold necessary rights to the submitted image. By "
359
+ "checking this box, I grant an irrevocable, non-exclusive, transferable, "
360
+ "royalty-free, worldwide license to use the uploaded image, including for "
361
+ "publishing, reproducing, and model training. [transfer_of_rights]",
362
+ scale=1,
363
+ )
364
+ share_content_is_legal = gr.Checkbox(
365
+ label="By checking this box, I acknowledge that my uploaded content is legal and "
366
+ "safe, and that I am solely responsible for ensuring it complies with all "
367
+ "applicable laws and regulations. Additionally, I am aware that my Hugging Face "
368
+ "username is collected. [content_is_legal]",
369
+ scale=1,
370
+ )
371
+ share_reason = gr.Textbox(
372
+ label="(Optional) Reason for feedback",
373
+ max_lines=1,
374
+ interactive=True,
375
+ )
376
+ with gr.Row():
377
+ share_login_btn.render()
378
+ share_share_btn = gr.Button(
379
+ "Share", variant="stop", scale=1
380
+ )
381
+
382
+ # Function to toggle visibility and set dynamic labels
383
+ def toggle_sliders_and_labels(model_type):
384
+ if model_type == "appearance":
385
+ return (
386
+ gr.update(visible=True, label="Predicted Albedo"),
387
+ gr.update(visible=True, label="Predicted Material"),
388
+ gr.update(visible=False), # Hide third slider
389
+ )
390
+ elif model_type == "residual":
391
+ return (
392
+ gr.update(visible=True, label="Predicted Albedo"),
393
+ gr.update(visible=True, label="Predicted Shading"),
394
+ gr.update(visible=True, label="Predicted Residual"),
395
+ )
396
+
397
+ # Attach the change event to update sliders
398
+ model_type.change(
399
+ fn=toggle_sliders_and_labels,
400
+ inputs=[model_type],
401
+ outputs=[image_output_slider1, image_output_slider2, image_output_slider3],
402
+ show_progress=False,
403
+ )
404
+
405
+ Examples(
406
+ fn=process_pipe_image,
407
+ examples=[
408
+ os.path.join("files", "image", name)
409
+ for name in [
410
+ "berries.jpeg",
411
+ "costumes.png",
412
+ "cat.jpg",
413
+ "einstein.jpg",
414
+ "food.jpeg",
415
+ "food_counter.png",
416
+ "puzzle.jpeg",
417
+ "rocket.png",
418
+ "scientists.jpg",
419
+ "cat2.png",
420
+ "screw.png",
421
+ "statues.png",
422
+ "swings.jpg"
423
+ ]
424
+ ],
425
+ inputs=[image_input],
426
+ outputs= [
427
+ image_output_slider1,
428
+ image_output_slider2,
429
+ image_output_slider3,
430
+ image_output_files
431
+ ],
432
+ cache_examples=False, # TODO: toggle later
433
+ directory_name="examples_image",
434
+ )
435
+
436
+ ### Image tab
437
+
438
+ if hf_writer is not None:
439
+ image_submit_btn.click(
440
+ fn=process_image_check,
441
+ inputs=image_input,
442
+ outputs=None,
443
+ preprocess=False,
444
+ queue=False,
445
+ ).success(
446
+ get_share_conditioned_on_login,
447
+ None,
448
+ [share_instructions, share_login_btn],
449
+ queue=False,
450
+ ).then(
451
+ lambda: (
452
+ gr.Button(value="Share", interactive=True),
453
+ gr.Accordion(visible=True),
454
+ False,
455
+ False,
456
+ "",
457
+ ),
458
+ None,
459
+ [
460
+ share_share_btn,
461
+ share_box,
462
+ share_transfer_of_rights,
463
+ share_content_is_legal,
464
+ share_reason,
465
+ ],
466
+ queue=False,
467
+ ).then(
468
+ fn=process_pipe_image,
469
+ inputs=[
470
+ image_input,
471
+ image_denoise_steps,
472
+ image_ensemble_size,
473
+ image_processing_res,
474
+ model_type
475
+ ],
476
+ outputs= [
477
+ image_output_slider1,
478
+ image_output_slider2,
479
+ image_output_slider3,
480
+ image_output_files
481
+ ],
482
+ concurrency_limit=1,
483
+ )
484
+ else:
485
+ image_submit_btn.click(
486
+ fn=process_image_check,
487
+ inputs=image_input,
488
+ outputs=None,
489
+ preprocess=False,
490
+ queue=False,
491
+ ).success(
492
+ fn=process_pipe_image,
493
+ inputs=[
494
+ image_input,
495
+ image_denoise_steps,
496
+ image_ensemble_size,
497
+ image_processing_res,
498
+ model_type
499
+ ],
500
+ outputs= [
501
+ image_output_slider1,
502
+ image_output_slider2,
503
+ image_output_slider3,
504
+ image_output_files
505
+ ],
506
+ concurrency_limit=1,
507
+ )
508
+
509
+ image_reset_btn.click(
510
+ fn=lambda: (
511
+ None,
512
+ None,
513
+ None,
514
+ default_image_ensemble_size,
515
+ default_image_denoise_steps,
516
+ default_image_processing_res,
517
+ ),
518
+ inputs=[],
519
+ outputs=[
520
+ image_input,
521
+ image_output_slider1,
522
+ image_output_slider2,
523
+ image_output_slider3,
524
+ image_output_files,
525
+ image_ensemble_size,
526
+ image_denoise_steps,
527
+ image_processing_res,
528
+ ],
529
+ queue=False,
530
+ )
531
+
532
+ if hf_writer is not None:
533
+ image_reset_btn.click(
534
+ fn=lambda: (
535
+ gr.Button(value="Share", interactive=True),
536
+ gr.Accordion(visible=default_share_always_show_accordion),
537
+ ),
538
+ inputs=[],
539
+ outputs=[
540
+ share_share_btn,
541
+ share_box,
542
+ ],
543
+ queue=False,
544
+ )
545
+
546
+ ### Share functionality
547
+
548
+ if hf_writer is not None:
549
+ share_components = [
550
+ image_input,
551
+ image_denoise_steps,
552
+ image_ensemble_size,
553
+ image_processing_res,
554
+ image_output_slider1,
555
+ image_output_slider2,
556
+ image_output_slider3,
557
+ share_content_is_legal,
558
+ share_transfer_of_rights,
559
+ share_reason,
560
+ ]
561
+
562
+ hf_writer.setup(share_components, "shared_data")
563
+ share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True)
564
+
565
+ def share_precheck(
566
+ hf_content_is_legal,
567
+ image_output_slider,
568
+ profile: gr.OAuthProfile | None,
569
+ ):
570
+ if profile is None:
571
+ raise gr.Error(
572
+ "Log into the Space with your Hugging Face account first."
573
+ )
574
+ if image_output_slider is None or image_output_slider[0] is None:
575
+ raise gr.Error("No output detected; process the image first.")
576
+ if not hf_content_is_legal:
577
+ raise gr.Error(
578
+ "You must consent that the uploaded content is legal."
579
+ )
580
+ return gr.Button(value="Sharing in progress", interactive=False)
581
+
582
+ share_share_btn.click(
583
+ share_precheck,
584
+ [share_content_is_legal, image_output_slider1],
585
+ share_share_btn,
586
+ preprocess=False,
587
+ queue=False,
588
+ ).success(
589
+ share_callback,
590
+ inputs=share_components,
591
+ outputs=share_share_btn,
592
+ preprocess=False,
593
+ queue=False,
594
+ )
595
+
596
+ demo.queue(
597
+ api_open=False,
598
+ ).launch(
599
+ server_name="0.0.0.0",
600
+ server_port=7860,
601
+ )
602
+
603
+
604
+ def main():
605
+ CHECKPOINT = "prs-eth/marigold-iid-appearance-v1-1"
606
+ CROWD_DATA = "crowddata-marigold-iid-appearance-v1-1-space-v1-1"
607
+
608
+ os.system("pip freeze")
609
+
610
+ if "HF_TOKEN_LOGIN" in os.environ:
611
+ login(token=os.environ["HF_TOKEN_LOGIN"])
612
+
613
+ auth_token = os.environ.get("KEV_TOKEN")
614
+ pipe = MarigoldIIDAppearancePipeline.from_pretrained(CHECKPOINT, token=auth_token)
615
+ try:
616
+ import xformers
617
+
618
+ pipe.enable_xformers_memory_efficient_attention()
619
+ except Exception:
620
+ pass # run without xformers
621
+
622
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
623
+ pipe = pipe.to(device)
624
+
625
+ hf_writer = None
626
+ if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ:
627
+ hf_writer = HuggingFaceDatasetSaver(
628
+ os.getenv("HF_TOKEN_LOGIN_WRITE_CROWD"),
629
+ CROWD_DATA,
630
+ private=True,
631
+ info_filename="dataset_info.json",
632
+ separate_dirs=True,
633
+ )
634
+
635
+ run_demo_server(hf_writer)
636
+
637
+
638
+ if __name__ == "__main__":
639
+ main()
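
For reference, a minimal local-launch sketch (not part of the commit). It assumes the dependencies from requirements.txt are installed; the environment variable names below are the ones read by app.py, and the token values are placeholders.

import os

os.environ["KEV_TOKEN"] = "<hf_token_with_checkpoint_access>"   # placeholder; read by the from_pretrained calls in app.py
# os.environ["HF_TOKEN_LOGIN"] = "<hf_token>"                   # optional: logs into the Hub at startup
# os.environ["HF_TOKEN_LOGIN_WRITE_CROWD"] = "<hf_write_token>" # optional: enables the feedback dataset writer

from app import main

main()  # starts the Gradio demo on 0.0.0.0:7860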
marigold_iid_appearance.py ADDED
@@ -0,0 +1,544 @@
1
+ # Copyright 2024 Anton Obukhov, Bingxin Ke, Bo Li & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldcomputervision.github.io
18
+ # --------------------------------------------------------------------------
19
+ import logging
20
+ import math
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from diffusers import (
26
+ AutoencoderKL,
27
+ DDIMScheduler,
28
+ DiffusionPipeline,
29
+ UNet2DConditionModel,
30
+ )
31
+ from diffusers.utils import BaseOutput, check_min_version
32
+ from PIL import Image
33
+ from PIL.Image import Resampling
34
+ from torch.utils.data import DataLoader, TensorDataset
35
+ from tqdm.auto import tqdm
36
+ from transformers import CLIPTextModel, CLIPTokenizer
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
39
+ check_min_version("0.27.0.dev0")
40
+
41
+ class MarigoldIIDAppearanceOutput(BaseOutput):
42
+ """
43
+ Output class for Marigold IID Appearance pipeline.
44
+
45
+ Args:
46
+ albedo (`np.ndarray`):
47
+ Predicted albedo map with the shape of [3, H, W] and values in the range of [0, 1].
48
+ albedo_colored (`PIL.Image.Image`):
49
+ Colorized albedo map with the shape of [H, W, 3].
50
+ material (`np.ndarray`):
51
+ Predicted material map with the shape of [3, H, W] and values in [0, 1].
52
+ 1st channel (Red) is roughness
53
+ 2nd channel (Green) is metallicity
54
+ 3rd channel (Blue) is empty (zero)
55
+ material_colored (`PIL.Image.Image`):
56
+ Colorized material map with the shape of [H, W, 3].
57
+ 1st channel (Red) is roughness
58
+ 2nd channel (Green) is metallicity
59
+ 3rd channel (Blue) is empty (zero)
60
+ """
61
+
62
+ albedo: np.ndarray
63
+ albedo_colored: Image.Image
64
+ material: np.ndarray
65
+ material_colored: Image.Image
66
+
67
+ class MarigoldIIDAppearancePipeline(DiffusionPipeline):
68
+ """
69
+ Pipeline for Intrinsic Image Decomposition (Albedo and Material) using Marigold: https://marigoldcomputervision.github.io.
70
+
71
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
72
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
73
+
74
+ Args:
75
+ unet (`UNet2DConditionModel`):
76
+ Conditional U-Net to denoise the target latents, conditioned on the image latent.
77
+ vae (`AutoencoderKL`):
78
+ Variational Auto-Encoder (VAE) Model to encode and decode images and target maps
79
+ to and from latent representations.
80
+ scheduler (`DDIMScheduler`):
81
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
82
+ text_encoder (`CLIPTextModel`):
83
+ Text-encoder, for empty text embedding.
84
+ tokenizer (`CLIPTokenizer`):
85
+ CLIP tokenizer.
86
+ """
87
+
88
+ latent_scale_factor = 0.18215
89
+
90
+ def __init__(
91
+ self,
92
+ unet: UNet2DConditionModel,
93
+ vae: AutoencoderKL,
94
+ scheduler: DDIMScheduler,
95
+ text_encoder: CLIPTextModel,
96
+ tokenizer: CLIPTokenizer,
97
+ ):
98
+ super().__init__()
99
+
100
+ self.register_modules(
101
+ unet=unet,
102
+ vae=vae,
103
+ scheduler=scheduler,
104
+ text_encoder=text_encoder,
105
+ tokenizer=tokenizer,
106
+ )
107
+
108
+ self.empty_text_embed = None
109
+
110
+ self.n_targets = 2 # Albedo and material
111
+
112
+ @torch.no_grad()
113
+ def __call__(
114
+ self,
115
+ input_image: Image,
116
+ denoising_steps: int = 4,
117
+ ensemble_size: int = 10,
118
+ processing_res: int = 768,
119
+ match_input_res: bool = True,
120
+ resample_method: str = "bilinear",
121
+ batch_size: int = 0,
122
+ save_memory: bool = False,
123
+ seed: Union[int, None] = None,
124
+ color_map: str = "Spectral", # TODO change colorization api based on modality
125
+ show_progress_bar: bool = True,
126
+ **kwargs,
127
+ ) -> MarigoldIIDAppearanceOutput:
128
+ """
129
+ Function invoked when calling the pipeline.
130
+
131
+ Args:
132
+ input_image (`Image`):
133
+ Input RGB (or gray-scale) image.
134
+ denoising_steps (`int`, *optional*, defaults to `4`):
135
+ Number of diffusion denoising steps (DDIM) during inference.
136
+ ensemble_size (`int`, *optional*, defaults to `10`):
137
+ Number of predictions to be ensembled.
138
+ processing_res (`int`, *optional*, defaults to `768`):
139
+ Maximum resolution of processing.
140
+ If set to 0: will not resize at all.
141
+ match_input_res (`bool`, *optional*, defaults to `True`):
142
+ Resize the prediction to match the input resolution.
143
+ Only has an effect when `processing_res` is not 0.
144
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
145
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
146
+ batch_size (`int`, *optional*, defaults to `0`):
147
+ Inference batch size, no bigger than `num_ensemble`.
148
+ If set to 0, the script will automatically decide the proper batch size.
149
+ save_memory (`bool`, defaults to `False`):
150
+ Extra steps to save memory at the cost of performance.
151
+ seed (`int`, *optional*, defaults to `None`)
152
+ Reproducibility seed.
153
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized output generation):
154
+ Colormap used to colorize the output maps.
155
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
156
+ Display a progress bar of diffusion denoising.
157
+ Returns:
158
+ `MarigoldIIDAppearanceOutput`: Output class for Marigold monocular intrinsic image decomposition (appearance) prediction pipeline, including:
159
+ - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] and values in the range of [0, 1]
160
+ - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [H, W, 3] and values in the range of [0, 1]
161
+ - **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
162
+ - **material_colored** (`PIL.Image.Image`) Colorized material map with the shape of [H, W, 3] and values in [0, 1]
163
+ """
164
+
165
+ if not match_input_res:
166
+ assert processing_res is not None
167
+ assert processing_res >= 0
168
+ assert denoising_steps >= 1
169
+ assert ensemble_size >= 1
170
+
171
+ # Check if denoising step is reasonable
172
+ self.check_inference_step(denoising_steps)
173
+
174
+ resample_method: Resampling = self.get_pil_resample_method(resample_method)
175
+
176
+ W, H = input_image.size
177
+
178
+ if processing_res > 0:
179
+ input_image = self.resize_max_res(
180
+ input_image, max_edge_resolution=processing_res, resample_method=resample_method,
181
+ )
182
+ input_image = input_image.convert("RGB")
183
+ image = np.asarray(input_image)
184
+
185
+ rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
186
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
187
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
188
+ rgb_norm = rgb_norm.to(self.device)
189
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
190
+
191
+ def ensemble(
192
+ targets: torch.Tensor, return_uncertainty: bool = False, reduction = "median",
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
194
+ uncertainty = None
195
+ if reduction == "mean":
196
+ prediction = torch.mean(targets, dim=0, keepdim=True)
197
+ if return_uncertainty:
198
+ uncertainty = torch.std(targets, dim=0, keepdim=True)
199
+ elif reduction == "median":
200
+ prediction = torch.median(targets, dim=0, keepdim=True).values
201
+ if return_uncertainty:
202
+ uncertainty = torch.median(
203
+ torch.abs(targets - prediction), dim=0, keepdim=True
204
+ ).values
205
+ else:
206
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
207
+ return prediction, uncertainty
208
+
209
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
210
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
211
+
212
+ if batch_size <= 0:
213
+ batch_size = self.find_batch_size(
214
+ ensemble_size=ensemble_size,
215
+ input_res=max(rgb_norm.shape[1:]),
216
+ dtype=self.dtype,
217
+ )
218
+
219
+ single_rgb_loader = DataLoader(
220
+ single_rgb_dataset, batch_size=batch_size, shuffle=False
221
+ )
222
+
223
+ target_pred_ls = []
224
+ iterable = single_rgb_loader
225
+ if show_progress_bar:
226
+ iterable = tqdm(
227
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
228
+ )
229
+
230
+ for batch in iterable:
231
+ (batched_img,) = batch
232
+ target_pred = self.single_infer(
233
+ rgb_in=batched_img,
234
+ num_inference_steps=denoising_steps,
235
+ seed=seed,
236
+ show_pbar=show_progress_bar,
237
+ )
238
+ target_pred = target_pred.detach()
239
+ if save_memory:
240
+ target_pred = target_pred.cpu()
241
+ target_pred_ls.append(target_pred.detach())
242
+
243
+ target_preds = torch.concat(target_pred_ls, dim=0)
244
+ pred_uncert = None
245
+
246
+ if save_memory:
247
+ torch.cuda.empty_cache()
248
+
249
+ if ensemble_size > 1:
250
+ final_pred, pred_uncert = ensemble(
251
+ target_preds,
252
+ reduction = "median",
253
+ return_uncertainty=False
254
+ )
255
+ else:
256
+ final_pred = target_preds
257
+ pred_uncert = None
258
+
259
+ if match_input_res:
260
+ final_pred = torch.nn.functional.interpolate(
261
+ final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
262
+ ) # [1,3,H,W]
263
+
264
+ if pred_uncert is not None:
265
+ pred_uncert = torch.nn.functional.interpolate(
266
+ pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
267
+ ).squeeze(
268
+ 1
269
+ ) # [1,H,W]
270
+
271
+ # Convert to numpy
272
+ final_pred = final_pred.squeeze()
273
+ final_pred = final_pred.cpu().numpy()
274
+
275
+ albedo = final_pred[0:3, :, :]
276
+ material = np.stack(
277
+ (final_pred[3, :, :], final_pred[4, :, :], final_pred[5, :, :]), axis=0
278
+ )
279
+
280
+ albedo_colored = (albedo + 1.0) * 0.5
281
+ albedo_colored = (albedo_colored * 255).astype(np.uint8)
282
+ albedo_colored = self.chw2hwc(albedo_colored)
283
+ albedo_colored_img = Image.fromarray(albedo_colored)
284
+
285
+ material_colored = (material + 1.0) * 0.5
286
+ material_colored = (material_colored * 255).astype(np.uint8)
287
+ material_colored = self.chw2hwc(material_colored)
288
+ material_colored_img = Image.fromarray(material_colored)
289
+
290
+ out = MarigoldIIDAppearanceOutput(
291
+ albedo=albedo,
292
+ albedo_colored=albedo_colored_img,
293
+ material=material,
294
+ material_colored=material_colored_img
295
+ )
296
+
297
+ return out
298
+
299
+ def check_inference_step(self, n_step: int):
300
+ """
301
+ Check if denoising step is reasonable
302
+ Args:
303
+ n_step (`int`): denoising steps
304
+ """
305
+ assert n_step >= 1
306
+
307
+ if isinstance(self.scheduler, DDIMScheduler):
308
+ pass
309
+ else:
310
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
311
+
312
+ def encode_empty_text(self):
313
+ """
314
+ Encode text embedding for empty prompt.
315
+ """
316
+ prompt = ""
317
+ text_inputs = self.tokenizer(
318
+ prompt,
319
+ padding="do_not_pad",
320
+ max_length=self.tokenizer.model_max_length,
321
+ truncation=True,
322
+ return_tensors="pt",
323
+ )
324
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
325
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
326
+
327
+ @torch.no_grad()
328
+ def single_infer(
329
+ self,
330
+ rgb_in: torch.Tensor,
331
+ num_inference_steps: int,
332
+ seed: Union[int, None],
333
+ show_pbar: bool,
334
+ ) -> torch.Tensor:
335
+ """
336
+ Perform an individual iid prediction without ensembling.
337
+ """
338
+ device = rgb_in.device
339
+
340
+ # Set timesteps
341
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
342
+ timesteps = self.scheduler.timesteps # [T]
343
+
344
+ # Encode image
345
+ rgb_latent = self.encode_rgb(rgb_in)
346
+
347
+ target_latent_shape = list(rgb_latent.shape)
348
+ target_latent_shape[1] *= (
349
+ 2 # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
350
+ )
351
+
352
+ # Initialize prediction latent with noise
353
+ if seed is None:
354
+ rand_num_generator = None
355
+ else:
356
+ rand_num_generator = torch.Generator(device=device)
357
+ rand_num_generator.manual_seed(seed)
358
+ target_latents = torch.randn(
359
+ target_latent_shape,
360
+ device=device,
361
+ dtype=self.dtype,
362
+ generator=rand_num_generator,
363
+ ) # [B, 4*n_targets, h, w]
364
+
365
+ # Batched empty text embedding
366
+ if self.empty_text_embed is None:
367
+ self.encode_empty_text()
368
+ batch_empty_text_embed = self.empty_text_embed.repeat(
369
+ (rgb_latent.shape[0], 1, 1)
370
+ ) # [B, 2, 1024]
371
+
372
+ # Denoising loop
373
+ if show_pbar:
374
+ iterable = tqdm(
375
+ enumerate(timesteps),
376
+ total=len(timesteps),
377
+ leave=False,
378
+ desc=" " * 4 + "Diffusion denoising",
379
+ )
380
+ else:
381
+ iterable = enumerate(timesteps)
382
+
383
+ for i, t in iterable:
384
+ unet_input = torch.cat(
385
+ [rgb_latent, target_latents], dim=1
386
+ ) # this order is important
387
+
388
+ # predict the noise residual
389
+ noise_pred = self.unet(
390
+ unet_input, t, encoder_hidden_states=batch_empty_text_embed
391
+ ).sample # [B, 4, h, w]
392
+
393
+ # compute the previous noisy sample x_t -> x_t-1
394
+ target_latents = self.scheduler.step(
395
+ noise_pred, t, target_latents, generator=rand_num_generator
396
+ ).prev_sample
397
+
398
+ # torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
399
+
400
+ targets = self.decode_targets(target_latents) # [B, 3*n_targets, H, W]
401
+ targets = torch.clip(targets, -1.0, 1.0)
402
+
403
+ return targets
404
+
405
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
406
+ """
407
+ Encode RGB image into latent.
408
+
409
+ Args:
410
+ rgb_in (`torch.Tensor`):
411
+ Input RGB image to be encoded.
412
+
413
+ Returns:
414
+ `torch.Tensor`: Image latent.
415
+ """
416
+ # encode
417
+ h = self.vae.encoder(rgb_in)
418
+ moments = self.vae.quant_conv(h)
419
+ mean, logvar = torch.chunk(moments, 2, dim=1)
420
+ # scale latent
421
+ rgb_latent = mean * self.latent_scale_factor
422
+ return rgb_latent
423
+
424
+ def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
425
+ """
426
+ Decode target latent into target map.
427
+
428
+ Args:
429
+ target_latents (`torch.Tensor`):
430
+ Target latent to be decoded.
431
+
432
+ Returns:
433
+ `torch.Tensor`: Decoded target map.
434
+ """
435
+
436
+ assert target_latents.shape[1] == 8 # self.n_targets * 4
437
+
438
+ # scale latent
439
+ target_latents = target_latents / self.latent_scale_factor
440
+ # decode
441
+ targets = []
442
+ for i in range(self.n_targets):
443
+ latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
444
+ z = self.vae.post_quant_conv(latent)
445
+ stacked = self.vae.decoder(z)
446
+
447
+ targets.append(stacked)
448
+
449
+ return torch.cat(targets, dim=1)
450
+
451
+ @staticmethod
452
+ def get_pil_resample_method(method_str: str) -> Resampling:
453
+ resample_method_dic = {
454
+ "bilinear": Resampling.BILINEAR,
455
+ "bicubic": Resampling.BICUBIC,
456
+ "nearest": Resampling.NEAREST,
457
+ }
458
+ resample_method = resample_method_dic.get(method_str, None)
459
+ if resample_method is None:
460
+ raise ValueError(f"Unknown resampling method: {method_str}")
461
+ else:
462
+ return resample_method
463
+
464
+ @staticmethod
465
+ def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
466
+ """
467
+ Resize image to limit maximum edge length while keeping aspect ratio.
468
+ """
469
+ original_width, original_height = img.size
470
+ downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
471
+
472
+ new_width = int(original_width * downscale_factor)
473
+ new_height = int(original_height * downscale_factor)
474
+
475
+ resized_img = img.resize((new_width, new_height), resample=resample_method)
476
+ return resized_img
477
+
478
+ @staticmethod
479
+ def chw2hwc(chw):
480
+ assert 3 == len(chw.shape)
481
+ if isinstance(chw, torch.Tensor):
482
+ hwc = torch.permute(chw, (1, 2, 0))
483
+ elif isinstance(chw, np.ndarray):
484
+ hwc = np.moveaxis(chw, 0, -1)
485
+ return hwc
486
+
487
+ @staticmethod
488
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
489
+ """
490
+ Automatically search for suitable operating batch size.
491
+
492
+ Args:
493
+ ensemble_size (`int`):
494
+ Number of predictions to be ensembled.
495
+ input_res (`int`):
496
+ Operating resolution of the input image.
497
+
498
+ Returns:
499
+ `int`: Operating batch size.
500
+ """
501
+ # Search table for suggested max. inference batch size
502
+ bs_search_table = [
503
+ # tested on A100-PCIE-80GB
504
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
505
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
506
+ # tested on A100-PCIE-40GB
507
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
508
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
509
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
510
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
511
+ # tested on RTX3090, RTX4090
512
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
513
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
514
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
515
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
516
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
517
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
518
+ # tested on GTX1080Ti
519
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
520
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
521
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
522
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
523
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
524
+ ]
525
+
526
+ if not torch.cuda.is_available():
527
+ return 1
528
+
529
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
530
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
531
+ for settings in sorted(
532
+ filtered_bs_search_table,
533
+ key=lambda k: (k["res"], -k["total_vram"]),
534
+ ):
535
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
536
+ bs = settings["bs"]
537
+ if bs > ensemble_size:
538
+ bs = ensemble_size
539
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
540
+ bs = math.ceil(ensemble_size / 2)
541
+ return bs
542
+
543
+ return 1
544
+
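
For reference, a minimal standalone usage sketch of this pipeline (not part of the commit). It mirrors the call made in app.py; the checkpoint name and arguments are taken from app.py, while the token and image path are placeholders and access to the checkpoint is assumed.

import torch
from PIL import Image
from marigold_iid_appearance import MarigoldIIDAppearancePipeline

pipe = MarigoldIIDAppearancePipeline.from_pretrained(
    "prs-eth/marigold-iid-appearance-v1-1", token="<hf_token>"  # placeholder token
)
pipe = pipe.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

out = pipe(
    Image.open("files/image/berries.jpeg"),  # example path from the demo gallery
    denoising_steps=4,
    ensemble_size=1,
    processing_res=768,
    seed=2024,
    show_progress_bar=True,
)
out.albedo_colored.save("albedo.png")      # colorized albedo, [H, W, 3]
out.material_colored.save("material.png")  # roughness in R, metallicity in G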
marigold_iid_residual.py ADDED
@@ -0,0 +1,552 @@
1
+ # Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldcomputervision.github.io
18
+ # --------------------------------------------------------------------------
19
+ import logging
20
+ import math
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from diffusers import (
26
+ AutoencoderKL,
27
+ DDIMScheduler,
28
+ DiffusionPipeline,
29
+ UNet2DConditionModel,
30
+ )
31
+ from diffusers.utils import BaseOutput, check_min_version
32
+ from PIL import Image
33
+ from PIL.Image import Resampling
34
+ from torch.utils.data import DataLoader, TensorDataset
35
+ from tqdm.auto import tqdm
36
+ from transformers import CLIPTextModel, CLIPTokenizer
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
39
+ check_min_version("0.27.0.dev0")
40
+
41
+ class MarigoldIIDResidualOutput(BaseOutput):
42
+ """
43
+ Output class for Marigold IID Residual pipeline.
44
+
45
+ Args:
46
+ albedo (`np.ndarray`):
47
+ Predicted albedo map with the shape of [3, H, W] and values in the range of [0, 1].
48
+ albedo_colored (`PIL.Image.Image`):
49
+ Colorized albedo map with the shape of [H, W, 3].
50
+ shading (`np.ndarray`):
51
+ Predicted diffuse shading map with the shape of [3, H, W] and values in the range of [0, 1].
52
+ shading_colored (`PIL.Image.Image`):
53
+ Colorized diffuse shading map with the shape of [H, W, 3].
54
+ residual (`np.ndarray`):
55
+ Predicted non-diffuse residual map with the shape of [3, H, W] and values in the range of [0, 1].
56
+ residual_colored (`PIL.Image.Image`):
57
+ Colorized non-diffuse residual map with the shape of [H, W, 3].
58
+
59
+ """
60
+
61
+ albedo: np.ndarray
62
+ albedo_colored: Image.Image
63
+ shading: np.ndarray
64
+ shading_colored: Image.Image
65
+ residual: np.ndarray
66
+ residual_colored: Image.Image
67
+
68
+ class MarigoldIIDResidualPipeline(DiffusionPipeline):
69
+ """
70
+ Pipeline for Intrinsic Image Decomposition (Albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.
71
+
72
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
73
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
74
+
75
+ Args:
76
+ unet (`UNet2DConditionModel`):
77
+ Conditional U-Net to denoise the target latents (albedo, diffuse shading, non-diffuse residual), conditioned on the image latent.
78
+ vae (`AutoencoderKL`):
79
+ Variational Auto-Encoder (VAE) Model to encode and decode images and prediction maps
80
+ to and from latent representations.
81
+ scheduler (`DDIMScheduler`):
82
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
83
+ text_encoder (`CLIPTextModel`):
84
+ Text-encoder, for empty text embedding.
85
+ tokenizer (`CLIPTokenizer`):
86
+ CLIP tokenizer.
87
+ """
88
+
89
+ latent_scale_factor = 0.18215
90
+
91
+ def __init__(
92
+ self,
93
+ unet: UNet2DConditionModel,
94
+ vae: AutoencoderKL,
95
+ scheduler: DDIMScheduler,
96
+ text_encoder: CLIPTextModel,
97
+ tokenizer: CLIPTokenizer,
98
+ ):
99
+ super().__init__()
100
+
101
+ self.register_modules(
102
+ unet=unet,
103
+ vae=vae,
104
+ scheduler=scheduler,
105
+ text_encoder=text_encoder,
106
+ tokenizer=tokenizer,
107
+ )
108
+
109
+ self.empty_text_embed = None
110
+ self.n_targets = 3 # Albedo, shading, residual
111
+
112
+ @torch.no_grad()
113
+ def __call__(
114
+ self,
115
+ input_image: Image.Image,
116
+ denoising_steps: int = 4,
117
+ ensemble_size: int = 10,
118
+ processing_res: int = 768,
119
+ match_input_res: bool = True,
120
+ resample_method: str = "bilinear",
121
+ batch_size: int = 0,
122
+ save_memory: bool = False,
123
+ seed: Union[int, None] = None,
124
+ color_map: str = "Spectral", # TODO change colorization api based on modality
125
+ show_progress_bar: bool = True,
126
+ **kwargs,
127
+ ) -> MarigoldIIDResidualOutput:
128
+ """
129
+ Function invoked when calling the pipeline.
130
+
131
+ Args:
132
+ input_image (`Image`):
133
+ Input RGB (or gray-scale) image.
134
+ denoising_steps (`int`, *optional*, defaults to `4`):
135
+ Number of diffusion denoising steps (DDIM) during inference.
136
+ ensemble_size (`int`, *optional*, defaults to `10`):
137
+ Number of predictions to be ensembled.
138
+ processing_res (`int`, *optional*, defaults to `768`):
139
+ Maximum resolution of processing.
140
+ If set to 0, the input is not resized at all.
141
+ match_input_res (`bool`, *optional*, defaults to `True`):
142
+ Resize predictions to match the input resolution.
143
+ Only relevant if `processing_res` is not 0 (otherwise the prediction already matches the input resolution).
144
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
145
+ Resampling method used to resize the input image and the predictions. This can be one of `bilinear`, `bicubic` or `nearest`; defaults to `bilinear`.
146
+ batch_size (`int`, *optional*, defaults to `0`):
147
+ Inference batch size, no bigger than `ensemble_size`.
148
+ If set to 0, a suitable batch size is chosen automatically.
149
+ save_memory (`bool`, defaults to `False`):
150
+ Extra steps to save memory at the cost of performance.
151
+ seed (`int`, *optional*, defaults to `None`):
152
+ Reproducibility seed.
153
+ color_map (`str`, *optional*, defaults to `"Spectral"`):
154
+ Colormap intended for colorizing the predictions; currently unused by this pipeline (see the TODO at the argument).
155
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
156
+ Display a progress bar of diffusion denoising.
157
+ Returns:
158
+ `MarigoldIIDResidualOutput`: Output class for Marigold monocular intrinsic image decomposition (Residual) prediction pipeline, including:
159
+ - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
160
+ - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [H, W, 3]
161
+ - **shading** (`np.ndarray`) Predicted diffuse shading map with the shape of [3, H, W] and values in [0, 1]; **shading_colored** (`PIL.Image.Image`) its colorized version with the shape of [H, W, 3]
162
+ - **residual** (`np.ndarray`) Predicted non-diffuse residual map with the shape of [3, H, W] and values in [0, 1]; **residual_colored** (`PIL.Image.Image`) its colorized version with the shape of [H, W, 3]
163
+ """
164
+
165
+ if not match_input_res:
166
+ assert processing_res is not None
167
+ assert processing_res >= 0
168
+ assert denoising_steps >= 1
169
+ assert ensemble_size >= 1
170
+
171
+ # Check if denoising step is reasonable
172
+ self.check_inference_step(denoising_steps)
173
+
174
+ resample_method: Resampling = self.get_pil_resample_method(resample_method)
175
+
176
+ W, H = input_image.size
177
+
178
+ if processing_res > 0:
179
+ input_image = self.resize_max_res(
180
+ input_image, max_edge_resolution=processing_res, resample_method=resample_method,
181
+ )
182
+ input_image = input_image.convert("RGB")
183
+ image = np.asarray(input_image)
184
+
185
+ rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
186
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
187
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
188
+ rgb_norm = rgb_norm.to(self.device)
189
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
190
+
191
+ def ensemble(
192
+ targets: torch.Tensor, return_uncertainty: bool = False, reduction: str = "median",
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
194
+ uncertainty = None
195
+ if reduction == "mean":
196
+ prediction = torch.mean(targets, dim=0, keepdim=True)
197
+ if return_uncertainty:
198
+ uncertainty = torch.std(targets, dim=0, keepdim=True)
199
+ elif reduction == "median":
200
+ prediction = torch.median(targets, dim=0, keepdim=True).values
201
+ if return_uncertainty:
202
+ uncertainty = torch.median(
203
+ torch.abs(targets - prediction), dim=0, keepdim=True
204
+ ).values
205
+ else:
206
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
207
+ return prediction, uncertainty
208
+
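+ # Note: with reduction="median", the uncertainty computed above is the per-pixel
+ # median absolute deviation of the ensemble members from the median prediction.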
209
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
210
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
211
+
212
+ if batch_size <= 0:
213
+ batch_size = self.find_batch_size(
214
+ ensemble_size=ensemble_size,
215
+ input_res=max(rgb_norm.shape[1:]),
216
+ dtype=self.dtype,
217
+ )
218
+
219
+ single_rgb_loader = DataLoader(
220
+ single_rgb_dataset, batch_size=batch_size, shuffle=False
221
+ )
222
+
223
+ target_pred_ls = []
224
+ iterable = single_rgb_loader
225
+ if show_progress_bar:
226
+ iterable = tqdm(
227
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
228
+ )
229
+
230
+ for batch in iterable:
231
+ (batched_img,) = batch
232
+ target_pred = self.single_infer(
233
+ rgb_in=batched_img,
234
+ num_inference_steps=denoising_steps,
235
+ seed=seed,
236
+ show_pbar=show_progress_bar,
237
+ )
238
+ target_pred = target_pred.detach()
239
+ if save_memory:
240
+ target_pred = target_pred.cpu()
241
+ target_pred_ls.append(target_pred.detach())
242
+
243
+ target_preds = torch.concat(target_pred_ls, dim=0)
244
+ pred_uncert = None
245
+
246
+ if save_memory:
247
+ torch.cuda.empty_cache()
248
+
249
+ if ensemble_size > 1:
250
+ final_pred, pred_uncert = ensemble(
251
+ target_preds,
252
+ reduction="median",
253
+ return_uncertainty=False
254
+ )
255
+ else:
256
+ final_pred = target_preds
257
+ pred_uncert = None
258
+
259
+ if match_input_res:
260
+ final_pred = torch.nn.functional.interpolate(
261
+ final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
262
+ ) # [1, 3*n_targets, H, W]
263
+
264
+ if pred_uncert is not None:
265
+ pred_uncert = torch.nn.functional.interpolate(
266
+ pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
267
+ ).squeeze(
268
+ 1
269
+ ) # [1,H,W]
270
+
271
+ # Convert to numpy
272
+ final_pred = final_pred.squeeze()
273
+ final_pred = final_pred.cpu().numpy()
274
+
275
+ albedo = final_pred[0:3, :, :]
276
+ shading = final_pred[3:6, :, :]
277
+ residual = final_pred[6:, :, :]
278
+
279
+ albedo_colored = (albedo + 1.0) * 0.5
280
+ albedo_colored = (albedo_colored * 255).astype(np.uint8)
281
+ albedo_colored = self.chw2hwc(albedo_colored)
282
+ albedo_colored_img = Image.fromarray(albedo_colored)
283
+
284
+ shading_colored = (shading + 1.0) * 0.5
285
+ shading_colored = shading_colored / shading_colored.max() # rescale for better visualization
286
+ shading_colored = (shading_colored * 255).astype(np.uint8)
287
+ shading_colored = self.chw2hwc(shading_colored)
288
+ shading_colored_img = Image.fromarray(shading_colored)
289
+
290
+ residual_colored = (residual + 1.0) * 0.5
291
+ residual_colored = residual_colored / residual_colored.max() # rescale for better visualization
292
+ residual_colored = (residual_colored * 255).astype(np.uint8)
293
+ residual_colored = self.chw2hwc(residual_colored)
294
+ residual_colored_img = Image.fromarray(residual_colored)
295
+
296
+ out = MarigoldIIDResidualOutput(
297
+ albedo=albedo,
298
+ albedo_colored=albedo_colored_img,
299
+ shading=shading,
300
+ shading_colored=shading_colored_img,
301
+ residual=residual,
302
+ residual_colored=residual_colored_img
303
+ )
304
+
305
+ return out
306
+
307
+ def check_inference_step(self, n_step: int):
308
+ """
309
+ Check if denoising step is reasonable
310
+ Args:
311
+ n_step (`int`): denoising steps
312
+ """
313
+ assert n_step >= 1
314
+
315
+ if isinstance(self.scheduler, DDIMScheduler):
316
+ pass
317
+ else:
318
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
319
+
320
+ def encode_empty_text(self):
321
+ """
322
+ Encode text embedding for empty prompt.
323
+ """
324
+ prompt = ""
325
+ text_inputs = self.tokenizer(
326
+ prompt,
327
+ padding="do_not_pad",
328
+ max_length=self.tokenizer.model_max_length,
329
+ truncation=True,
330
+ return_tensors="pt",
331
+ )
332
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
333
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
334
+
335
+ @torch.no_grad()
336
+ def single_infer(
337
+ self,
338
+ rgb_in: torch.Tensor,
339
+ num_inference_steps: int,
340
+ seed: Union[int, None],
341
+ show_pbar: bool,
342
+ ) -> torch.Tensor:
343
+ """
344
+ Perform an individual IID prediction without ensembling.
345
+ """
346
+ device = rgb_in.device
347
+
348
+ # Set timesteps
349
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
350
+ timesteps = self.scheduler.timesteps # [T]
351
+
352
+ # Encode image
353
+ rgb_latent = self.encode_rgb(rgb_in)
354
+
355
+ target_latent_shape = list(rgb_latent.shape)
356
+ target_latent_shape[1] *= (
357
+ self.n_targets # 4 latent channels per target -> (B, 4*n_targets, h, w)
358
+ )
359
+
360
+ # Initialize prediction latent with noise
361
+ if seed is None:
362
+ rand_num_generator = None
363
+ else:
364
+ rand_num_generator = torch.Generator(device=device)
365
+ rand_num_generator.manual_seed(seed)
366
+ target_latents = torch.randn(
367
+ target_latent_shape,
368
+ device=device,
369
+ dtype=self.dtype,
370
+ generator=rand_num_generator,
371
+ ) # [B, 4*n_targets, h, w]
372
+
373
+ # Batched empty text embedding
374
+ if self.empty_text_embed is None:
375
+ self.encode_empty_text()
376
+ batch_empty_text_embed = self.empty_text_embed.repeat(
377
+ (rgb_latent.shape[0], 1, 1)
378
+ ) # [B, 2, 1024]
379
+
380
+ # Denoising loop
381
+ if show_pbar:
382
+ iterable = tqdm(
383
+ enumerate(timesteps),
384
+ total=len(timesteps),
385
+ leave=False,
386
+ desc=" " * 4 + "Diffusion denoising",
387
+ )
388
+ else:
389
+ iterable = enumerate(timesteps)
390
+
391
+ for i, t in iterable:
392
+ unet_input = torch.cat(
393
+ [rgb_latent, target_latents], dim=1
394
+ ) # this order is important
395
+
396
+ # predict the noise residual
397
+ noise_pred = self.unet(
398
+ unet_input, t, encoder_hidden_states=batch_empty_text_embed
399
+ ).sample # [B, 4*n_targets, h, w]
400
+
401
+ # compute the previous noisy sample x_t -> x_t-1
402
+ target_latents = self.scheduler.step(
403
+ noise_pred, t, target_latents, generator=rand_num_generator
404
+ ).prev_sample
405
+
406
+ # torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
407
+
408
+ targets = self.decode_targets(target_latents) # [B, 3*n_targets, H, W]
409
+ targets = torch.clip(targets, -1.0, 1.0)
410
+
411
+ return targets
412
+
413
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
414
+ """
415
+ Encode RGB image into latent.
416
+
417
+ Args:
418
+ rgb_in (`torch.Tensor`):
419
+ Input RGB image to be encoded.
420
+
421
+ Returns:
422
+ `torch.Tensor`: Image latent.
423
+ """
424
+ # encode
425
+ h = self.vae.encoder(rgb_in)
426
+ moments = self.vae.quant_conv(h)
427
+ mean, logvar = torch.chunk(moments, 2, dim=1)
428
+ # scale latent
429
+ rgb_latent = mean * self.latent_scale_factor
430
+ return rgb_latent
431
+
432
+ def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
433
+ """
434
+ Decode target latent into target map.
435
+
436
+ Args:
437
+ target_latents (`torch.Tensor`):
438
+ Target latent to be decoded.
439
+
440
+ Returns:
441
+ `torch.Tensor`: Decoded target map.
442
+ """
443
+
444
+ assert target_latents.shape[1] == self.n_targets * 4 # 4 latent channels per target
445
+
446
+ # scale latent
447
+ target_latents = target_latents / self.latent_scale_factor
448
+ # decode
449
+ targets = []
450
+ for i in range(self.n_targets):
451
+ latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
452
+ z = self.vae.post_quant_conv(latent)
453
+ stacked = self.vae.decoder(z)
454
+
455
+ targets.append(stacked)
456
+
457
+ return torch.cat(targets, dim=1)
458
+
459
+ @staticmethod
460
+ def get_pil_resample_method(method_str: str) -> Resampling:
461
+ resample_method_dic = {
462
+ "bilinear": Resampling.BILINEAR,
463
+ "bicubic": Resampling.BICUBIC,
464
+ "nearest": Resampling.NEAREST,
465
+ }
466
+ resample_method = resample_method_dic.get(method_str, None)
467
+ if resample_method is None:
468
+ raise ValueError(f"Unknown resampling method: {method_str}")
469
+ else:
470
+ return resample_method
471
+
472
+ @staticmethod
473
+ def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
474
+ """
475
+ Resize image to limit maximum edge length while keeping aspect ratio.
476
+ """
477
+ original_width, original_height = img.size
478
+ downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
479
+
480
+ new_width = int(original_width * downscale_factor)
481
+ new_height = int(original_height * downscale_factor)
482
+
483
+ resized_img = img.resize((new_width, new_height), resample=resample_method)
484
+ return resized_img
485
+
486
+ @staticmethod
487
+ def chw2hwc(chw):
488
+ assert 3 == len(chw.shape)
489
+ if isinstance(chw, torch.Tensor):
490
+ hwc = torch.permute(chw, (1, 2, 0))
491
+ elif isinstance(chw, np.ndarray):
492
+ hwc = np.moveaxis(chw, 0, -1)
493
+ return hwc
494
+
495
+ @staticmethod
496
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
497
+ """
498
+ Automatically search for suitable operating batch size.
499
+
500
+ Args:
501
+ ensemble_size (`int`):
502
+ Number of predictions to be ensembled.
503
+ input_res (`int`):
504
+ Operating resolution of the input image.
505
+
506
+ Returns:
507
+ `int`: Operating batch size.
508
+ """
509
+ # Search table for suggested max. inference batch size
510
+ bs_search_table = [
511
+ # tested on A100-PCIE-80GB
512
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
513
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
514
+ # tested on A100-PCIE-40GB
515
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
516
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
517
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
518
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
519
+ # tested on RTX3090, RTX4090
520
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
521
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
522
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
523
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
524
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
525
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
526
+ # tested on GTX1080Ti
527
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
528
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
529
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
530
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
531
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
532
+ ]
533
+
534
+ if not torch.cuda.is_available():
535
+ return 1
536
+
537
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
538
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
539
+ for settings in sorted(
540
+ filtered_bs_search_table,
541
+ key=lambda k: (k["res"], -k["total_vram"]),
542
+ ):
543
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
544
+ bs = settings["bs"]
545
+ if bs > ensemble_size:
546
+ bs = ensemble_size
547
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
548
+ bs = math.ceil(ensemble_size / 2)
549
+ return bs
550
+
551
+ return 1
552
+
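For context, here is a minimal usage sketch of the pipeline defined above. The checkpoint path and image filename are placeholders (this commit does not pin an official model id); the argument values mirror the defaults documented in `__call__`.

import torch
from PIL import Image

from marigold_iid_residual import MarigoldIIDResidualPipeline  # the file added above

# Placeholder checkpoint path -- substitute the actual Marigold IID (residual) weights.
pipe = MarigoldIIDResidualPipeline.from_pretrained(
    "path/to/marigold-iid-residual-checkpoint", torch_dtype=torch.float16
).to("cuda")

image = Image.open("example.jpg")  # placeholder input image
out = pipe(
    image,
    denoising_steps=4,    # DDIM steps per prediction
    ensemble_size=1,      # >1 enables the median ensembling defined in __call__
    processing_res=768,   # maximum edge length used during inference
    seed=2024,            # for reproducibility
)

# Fields of MarigoldIIDResidualOutput, as declared above.
out.albedo_colored.save("albedo.png")
out.shading_colored.save("shading.png")
out.residual_colored.save("residual.png")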
requirements.txt ADDED
@@ -0,0 +1,126 @@
1
+ accelerate==0.25.0
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.3
4
+ aiosignal==1.3.1
5
+ altair==5.3.0
6
+ annotated-types==0.6.0
7
+ anyio==4.3.0
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ Authlib==1.3.0
11
+ certifi==2024.2.2
12
+ cffi==1.16.0
13
+ charset-normalizer==3.3.2
14
+ click==8.0.4
15
+ cmake==3.29.0.1
16
+ contourpy==1.2.0
17
+ cryptography==42.0.5
18
+ cycler==0.12.1
19
+ dataclasses-json==0.6.4
20
+ datasets==2.18.0
21
+ Deprecated==1.2.14
22
+ diffusers==0.27.2
23
+ dill==0.3.8
24
+ exceptiongroup==1.2.0
25
+ fastapi==0.110.0
26
+ ffmpy==0.3.2
27
+ filelock==3.13.3
28
+ fonttools==4.50.0
29
+ frozenlist==1.4.1
30
+ fsspec==2024.2.0
31
+ gradio==4.21.0
32
+ gradio_client==0.12.0
33
+ gradio_imageslider==0.0.18
34
+ h11==0.14.0
35
+ httpcore==1.0.5
36
+ httpx==0.27.0
37
+ huggingface-hub==0.22.1
38
+ idna==3.6
39
+ imageio==2.34.0
40
+ imageio-ffmpeg==0.4.9
41
+ importlib_metadata==7.1.0
42
+ importlib_resources==6.4.0
43
+ itsdangerous==2.1.2
44
+ Jinja2==3.1.3
45
+ jsonschema==4.21.1
46
+ jsonschema-specifications==2023.12.1
47
+ kiwisolver==1.4.5
48
+ lit==18.1.2
49
+ markdown-it-py==3.0.0
50
+ MarkupSafe==2.1.5
51
+ marshmallow==3.21.1
52
+ matplotlib==3.8.2
53
+ mdurl==0.1.2
54
+ mpmath==1.3.0
55
+ multidict==6.0.5
56
+ multiprocess==0.70.16
57
+ mypy-extensions==1.0.0
58
+ networkx==3.2.1
59
+ numpy==1.26.4
60
+ nvidia-cublas-cu11==11.10.3.66
61
+ nvidia-cuda-cupti-cu11==11.7.101
62
+ nvidia-cuda-nvrtc-cu11==11.7.99
63
+ nvidia-cuda-runtime-cu11==11.7.99
64
+ nvidia-cudnn-cu11==8.5.0.96
65
+ nvidia-cufft-cu11==10.9.0.58
66
+ nvidia-curand-cu11==10.2.10.91
67
+ nvidia-cusolver-cu11==11.4.0.1
68
+ nvidia-cusparse-cu11==11.7.4.91
69
+ nvidia-nccl-cu11==2.14.3
70
+ nvidia-nvtx-cu11==11.7.91
71
+ orjson==3.10.0
72
+ packaging==24.0
73
+ pandas==2.2.1
74
+ pillow==10.2.0
75
+ protobuf==3.20.3
76
+ psutil==5.9.8
77
+ pyarrow==15.0.2
78
+ pyarrow-hotfix==0.6
79
+ pycparser==2.22
80
+ pydantic==2.6.4
81
+ pydantic_core==2.16.3
82
+ pydub==0.25.1
83
+ pygltflib==1.16.1
84
+ Pygments==2.17.2
85
+ pyparsing==3.1.2
86
+ python-dateutil==2.9.0.post0
87
+ python-multipart==0.0.9
88
+ pytz==2024.1
89
+ PyYAML==6.0.1
90
+ referencing==0.34.0
91
+ regex==2023.12.25
92
+ requests==2.31.0
93
+ rich==13.7.1
94
+ rpds-py==0.18.0
95
+ ruff==0.3.4
96
+ safetensors==0.4.2
97
+ scipy==1.11.4
98
+ semantic-version==2.10.0
99
+ shellingham==1.5.4
100
+ six==1.16.0
101
+ sniffio==1.3.1
102
+ spaces==0.25.0
103
+ starlette==0.36.3
104
+ sympy==1.12
105
+ tokenizers==0.15.2
106
+ tomlkit==0.12.0
107
+ toolz==0.12.1
108
+ torch==2.0.1
109
+ tqdm==4.66.2
110
+ transformers==4.36.1
111
+ trimesh==4.0.5
112
+ triton==2.0.0
113
+ typer==0.12.0
114
+ typer-cli==0.12.0
115
+ typer-slim==0.12.0
116
+ typing-inspect==0.9.0
117
+ typing_extensions==4.10.0
118
+ tzdata==2024.1
119
+ urllib3==2.2.1
120
+ uvicorn==0.29.0
121
+ websockets==11.0.3
122
+ wrapt==1.16.0
123
+ xformers==0.0.21
124
+ xxhash==3.4.1
125
+ yarl==1.9.4
126
+ zipp==3.18.1
requirements_min.txt ADDED
@@ -0,0 +1,16 @@
1
+ gradio==4.21.0
2
+ gradio-imageslider==0.0.18
3
+ pygltflib==1.16.1
4
+ trimesh==4.0.5
5
+ imageio
6
+ imageio-ffmpeg
7
+ Pillow
8
+
9
+ spaces==0.25.0
10
+ accelerate==0.25.0
11
+ diffusers==0.27.2
12
+ matplotlib==3.8.2
13
+ scipy==1.11.4
14
+ torch==2.0.1
15
+ transformers==4.36.1
16
+ xformers==0.0.21