KevinQu7 committed · Commit 09c3706 · 1 Parent(s): 641fe65

initial commit

Browse files:
- .gitignore +7 -0
- LICENSE.txt +177 -0
- app.py +639 -0
- marigold_iid_appearance.py +544 -0
- marigold_iid_residual.py +552 -0
- requirements.txt +126 -0
- requirements_min.txt +16 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
.idea
.DS_Store
__pycache__
gradio_cached_examples
Marigold
*.sh
script/
LICENSE.txt
ADDED
@@ -0,0 +1,177 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS
app.py
ADDED
@@ -0,0 +1,639 @@
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
from __future__ import annotations

import functools
import os
import tempfile
import warnings

import spaces
import gradio as gr
import numpy as np
import torch as torch
from PIL import Image
from diffusers import UNet2DConditionModel

from gradio_imageslider import ImageSlider
from huggingface_hub import login

from gradio_patches.examples import Examples
from gradio_patches.flagging import HuggingFaceDatasetSaver, FlagMethod
from marigold_iid_appearance import MarigoldIIDAppearancePipeline
from marigold_iid_residual import MarigoldIIDResidualPipeline

warnings.filterwarnings(
    "ignore", message=".*LoginButton created outside of a Blocks context.*"
)

default_seed = 2024

default_image_denoise_steps = 4
default_image_ensemble_size = 1
default_image_processing_res = 768
default_image_reproducuble = True
default_model_type = "appearance"

default_share_always_show_hf_logout_btn = True
default_share_always_show_accordion = False

loaded_pipelines = {}  # Cache to store loaded pipelines


def process_with_loaded_pipeline(image_path, denoise_steps, ensemble_size, processing_res, model_type):
    # Load and cache the pipeline based on the model type.
    if model_type not in loaded_pipelines:
        auth_token = os.environ.get("KEV_TOKEN")
        if model_type == "appearance":
            loaded_pipelines[model_type] = MarigoldIIDAppearancePipeline.from_pretrained(
                "prs-eth/marigold-iid-appearance-v1-1", token=auth_token
            )
        elif model_type == "residual":
            loaded_pipelines[model_type] = MarigoldIIDResidualPipeline.from_pretrained(
                "prs-eth/marigold-iid-residual-v1-1", token=auth_token
            )

        # Move the pipeline to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device)

    pipe = loaded_pipelines[model_type]

    # Process the image using the preloaded pipeline.
    return process_image(
        pipe=pipe,
        path_input=image_path,
        denoise_steps=denoise_steps,
        ensemble_size=ensemble_size,
        processing_res=processing_res,
        model_type=model_type,
    )


def process_image_check(path_input):
    if path_input is None:
        raise gr.Error(
            "Missing image in the first pane: upload a file or use one from the gallery below."
        )


def process_image(
    pipe,
    path_input,
    denoise_steps=default_image_denoise_steps,
    ensemble_size=default_image_ensemble_size,
    processing_res=default_image_processing_res,
    model_type=default_model_type,
):
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()

    input_image = Image.open(path_input)

    pipe_out = pipe(
        input_image,
        denoising_steps=denoise_steps,
        ensemble_size=ensemble_size,
        processing_res=processing_res,
        batch_size=1 if processing_res == 0 else 0,  # TODO: do we abuse "batch size" notation here?
        seed=default_seed,
        show_progress_bar=True,
    )

    path_output_dir = os.path.splitext(path_input)[0] + "_output"
    os.makedirs(path_output_dir, exist_ok=True)

    path_albedo_out = os.path.join(path_output_dir, f"{name_base}_albedo_fp32.npy")
    path_albedo_out_vis = os.path.join(path_output_dir, f"{name_base}_albedo.png")

    albedo = pipe_out.albedo
    albedo_colored = pipe_out.albedo_colored

    np.save(path_albedo_out, albedo)
    albedo_colored.save(path_albedo_out_vis)

    if model_type == "appearance":
        path_material_out = os.path.join(path_output_dir, f"{name_base}_material_fp32.npy")
        path_material_out_vis = os.path.join(path_output_dir, f"{name_base}_material.png")

        material = pipe_out.material
        material_colored = pipe_out.material_colored

        np.save(path_material_out, material)
        material_colored.save(path_material_out_vis)

        return (
            [path_input, path_albedo_out_vis],
            [path_input, path_material_out_vis],
            None,
            [path_albedo_out_vis, path_material_out_vis, path_albedo_out, path_material_out],
        )

    elif model_type == "residual":
        path_shading_out = os.path.join(path_output_dir, f"{name_base}_shading_fp32.npy")
        path_shading_out_vis = os.path.join(path_output_dir, f"{name_base}_shading.png")
        path_residual_out = os.path.join(path_output_dir, f"{name_base}_residual_fp32.npy")
        path_residual_out_vis = os.path.join(path_output_dir, f"{name_base}_residual.png")

        shading = pipe_out.shading
        shading_colored = pipe_out.shading_colored
        residual = pipe_out.residual
        residual_colored = pipe_out.residual_colored

        np.save(path_shading_out, shading)
        shading_colored.save(path_shading_out_vis)
        np.save(path_residual_out, residual)
        residual_colored.save(path_residual_out_vis)

        return (
            [path_input, path_albedo_out_vis],
            [path_input, path_shading_out_vis],
            [path_input, path_residual_out_vis],
            [path_albedo_out_vis, path_shading_out_vis, path_residual_out_vis, path_albedo_out, path_shading_out, path_residual_out],
        )


def run_demo_server(hf_writer=None):
    process_pipe_image = spaces.GPU(functools.partial(process_with_loaded_pipeline), duration=120)
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Marigold Intrinsic Image Decomposition (Marigold-IID)",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        if hf_writer is not None:
            print("Creating login button")
            share_login_btn = gr.LoginButton(size="sm", scale=1, render=False)
            print("Created login button")
            share_login_btn.activate()
            print("Activated login button")

        gr.Markdown(
            """
            # Marigold Intrinsic Image Decomposition (Marigold-IID)

            <p align="center">
            <a title="Website" href="https://marigoldcomputervision.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            </p>
            """
        )

        def get_share_instructions(is_full):
            out = (
                "### Help us improve Marigold! If the output is not what you expected, "
                "you can help us by sharing it with us privately.\n"
            )
            if is_full:
                out += (
                    "1. Sign into your Hugging Face account using the button below.\n"
                    "1. Signing in may reset the demo and results; in that case, process the image again.\n"
                )
            out += "1. Review and agree to the terms of usage and enter an optional message to us.\n"
            out += "1. Click the 'Share' button to submit the image to us privately.\n"
            return out

        def get_share_conditioned_on_login(profile: gr.OAuthProfile | None):
            state_logged_out = profile is None
            return get_share_instructions(is_full=state_logged_out), gr.Button(
                visible=(state_logged_out or default_share_always_show_hf_logout_btn)
            )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                model_type = gr.Radio(
                    [
                        ("Appearance (Albedo & Material)", "appearance"),
                        ("Residual (Albedo, Shading & Residual)", "residual"),
                    ],
                    label="Model Type",
                    value=default_model_type,
                )

                with gr.Accordion("Advanced options", open=True):
                    image_ensemble_size = gr.Slider(
                        label="Ensemble size",
                        minimum=1,
                        maximum=10,
                        step=1,
                        value=default_image_ensemble_size,
                    )
                    image_denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=default_image_denoise_steps,
                    )
                    image_processing_res = gr.Radio(
                        [
                            ("Native", 0),
                            ("Recommended", 768),
                        ],
                        label="Processing resolution",
                        value=default_image_processing_res,
                    )
                with gr.Row():
                    image_submit_btn = gr.Button(value="Compute Normals", variant="primary")
                    image_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                image_output_slider1 = ImageSlider(
                    label="Predicted Albedo",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider2 = ImageSlider(
                    label="Predicted Material",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider3 = ImageSlider(
                    label="Predicted Residual",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=False,
                )
                image_output_files = gr.Files(
                    label="Output files",
                    elem_id="download",
                    interactive=False,
                )

        if hf_writer is not None:
            with gr.Accordion(
                "Feedback",
                open=False,
                visible=default_share_always_show_accordion,
            ) as share_box:
                share_instructions = gr.Markdown(
                    get_share_instructions(is_full=True),
                    elem_classes="md_feedback",
                )
                share_transfer_of_rights = gr.Checkbox(
                    label="(Optional) I own or hold necessary rights to the submitted image. By "
                    "checking this box, I grant an irrevocable, non-exclusive, transferable, "
                    "royalty-free, worldwide license to use the uploaded image, including for "
                    "publishing, reproducing, and model training. [transfer_of_rights]",
                    scale=1,
                )
                share_content_is_legal = gr.Checkbox(
                    label="By checking this box, I acknowledge that my uploaded content is legal and "
                    "safe, and that I am solely responsible for ensuring it complies with all "
                    "applicable laws and regulations. Additionally, I am aware that my Hugging Face "
                    "username is collected. [content_is_legal]",
                    scale=1,
                )
                share_reason = gr.Textbox(
                    label="(Optional) Reason for feedback",
                    max_lines=1,
                    interactive=True,
                )
                with gr.Row():
                    share_login_btn.render()
                    share_share_btn = gr.Button(
                        "Share", variant="stop", scale=1
                    )

        # Function to toggle visibility and set dynamic labels
        def toggle_sliders_and_labels(model_type):
            if model_type == "appearance":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Material"),
                    gr.update(visible=False),  # Hide third slider
                )
            elif model_type == "residual":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Shading"),
                    gr.update(visible=True, label="Predicted Residual"),
                )

        # Attach the change event to update sliders
        model_type.change(
            fn=toggle_sliders_and_labels,
            inputs=[model_type],
            outputs=[image_output_slider1, image_output_slider2, image_output_slider3],
            show_progress=False,
        )

        Examples(
            fn=process_pipe_image,
            examples=[
                os.path.join("files", "image", name)
                for name in [
                    "berries.jpeg",
                    "costumes.png",
                    "cat.jpg",
                    "einstein.jpg",
                    "food.jpeg",
                    "food_counter.png",
                    "puzzle.jpeg",
                    "rocket.png",
                    "scientists.jpg",
                    "cat2.png",
                    "screw.png",
                    "statues.png",
                    "swings.jpg",
                ]
            ],
            inputs=[image_input],
            outputs=[
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
            ],
            cache_examples=False,  # TODO: toggle later
            directory_name="examples_image",
        )

        ### Image tab

        if hf_writer is not None:
            image_submit_btn.click(
                fn=process_image_check,
                inputs=image_input,
                outputs=None,
                preprocess=False,
                queue=False,
            ).success(
                get_share_conditioned_on_login,
                None,
                [share_instructions, share_login_btn],
                queue=False,
            ).then(
                lambda: (
                    gr.Button(value="Share", interactive=True),
                    gr.Accordion(visible=True),
                    False,
                    False,
                    "",
                ),
                None,
                [
                    share_share_btn,
                    share_box,
                    share_transfer_of_rights,
                    share_content_is_legal,
                    share_reason,
                ],
                queue=False,
            ).then(
                fn=process_pipe_image,
                inputs=[
                    image_input,
                    image_denoise_steps,
                    image_ensemble_size,
                    image_processing_res,
                    model_type,
                ],
                outputs=[
                    image_output_slider1,
                    image_output_slider2,
                    image_output_slider3,
                    image_output_files,
                ],
                concurrency_limit=1,
            )
        else:
            image_submit_btn.click(
                fn=process_image_check,
                inputs=image_input,
                outputs=None,
                preprocess=False,
                queue=False,
            ).success(
                fn=process_pipe_image,
                inputs=[
                    image_input,
                    image_denoise_steps,
                    image_ensemble_size,
                    image_processing_res,
                    model_type,
                ],
                outputs=[
                    image_output_slider1,
                    image_output_slider2,
                    image_output_slider3,
                    image_output_files,
                ],
                concurrency_limit=1,
            )

        image_reset_btn.click(
            fn=lambda: (
                # One value per output component below.
                None,
                None,
                None,
                None,
                None,
                default_image_ensemble_size,
                default_image_denoise_steps,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                image_input,
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
                image_ensemble_size,
                image_denoise_steps,
                image_processing_res,
            ],
            queue=False,
        )

        if hf_writer is not None:
            image_reset_btn.click(
                fn=lambda: (
                    gr.Button(value="Share", interactive=True),
                    gr.Accordion(visible=default_share_always_show_accordion),
                ),
                inputs=[],
                outputs=[
                    share_share_btn,
                    share_box,
                ],
                queue=False,
            )

        ### Share functionality

        if hf_writer is not None:
            share_components = [
                image_input,
                image_denoise_steps,
                image_ensemble_size,
                image_processing_res,
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                share_content_is_legal,
                share_transfer_of_rights,
                share_reason,
            ]

            hf_writer.setup(share_components, "shared_data")
            share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True)

            def share_precheck(
                hf_content_is_legal,
                image_output_slider,
                profile: gr.OAuthProfile | None,
            ):
                if profile is None:
                    raise gr.Error(
                        "Log into the Space with your Hugging Face account first."
                    )
                if image_output_slider is None or image_output_slider[0] is None:
                    raise gr.Error("No output detected; process the image first.")
                if not hf_content_is_legal:
                    raise gr.Error(
                        "You must consent that the uploaded content is legal."
                    )
                return gr.Button(value="Sharing in progress", interactive=False)

            share_share_btn.click(
                share_precheck,
                [share_content_is_legal, image_output_slider1],
                share_share_btn,
                preprocess=False,
                queue=False,
            ).success(
                share_callback,
                inputs=share_components,
                outputs=share_share_btn,
                preprocess=False,
                queue=False,
            )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    CHECKPOINT = "prs-eth/marigold-iid-appearance-v1-1"
    CROWD_DATA = "crowddata-marigold-iid-appearance-v1-1-space-v1-1"

    os.system("pip freeze")

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    auth_token = os.environ.get("KEV_TOKEN")
    pipe = MarigoldIIDAppearancePipeline.from_pretrained(CHECKPOINT, token=auth_token)
    try:
        import xformers

        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # run without xformers

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipe.to(device)

    hf_writer = None
    if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ:
        hf_writer = HuggingFaceDatasetSaver(
            os.getenv("HF_TOKEN_LOGIN_WRITE_CROWD"),
            CROWD_DATA,
            private=True,
            info_filename="dataset_info.json",
            separate_dirs=True,
        )

    run_demo_server(hf_writer)


if __name__ == "__main__":
    main()
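For reference, a minimal sketch of how the cached-pipeline entry point above could be exercised outside Gradio. This is not part of the commit: the example image path is one of the gallery files shipped with the Space, the token placeholder is hypothetical, and it assumes the private prs-eth/marigold-iid-* checkpoints are accessible from the environment running app.py.

# Hypothetical local smoke test for process_with_loaded_pipeline (not part of the Space).
import os

os.environ.setdefault("KEV_TOKEN", "<hf-token-with-checkpoint-access>")  # placeholder

from app import process_with_loaded_pipeline

# Run the appearance model once on a single example image from the Space gallery.
albedo_pair, material_pair, residual_pair, output_files = process_with_loaded_pipeline(
    image_path=os.path.join("files", "image", "cat.jpg"),
    denoise_steps=4,
    ensemble_size=1,
    processing_res=768,
    model_type="appearance",
)
print("Saved outputs:", output_files)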
marigold_iid_appearance.py
ADDED
@@ -0,0 +1,544 @@
1 |
+
# Copyright 2024 Anton Obukhov, Bingxin Ke, Bo Li & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------------------------
|
15 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
16 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
17 |
+
# More information about the method can be found at https://marigoldcomputervision.github.io
|
18 |
+
# --------------------------------------------------------------------------
|
19 |
+
import logging
|
20 |
+
import math
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from diffusers import (
|
26 |
+
AutoencoderKL,
|
27 |
+
DDIMScheduler,
|
28 |
+
DiffusionPipeline,
|
29 |
+
UNet2DConditionModel,
|
30 |
+
)
|
31 |
+
from diffusers.utils import BaseOutput, check_min_version
|
32 |
+
from PIL import Image
|
33 |
+
from PIL.Image import Resampling
|
34 |
+
from torch.utils.data import DataLoader, TensorDataset
|
35 |
+
from tqdm.auto import tqdm
|
36 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
37 |
+
|
38 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
39 |
+
check_min_version("0.27.0.dev0")
|
40 |
+
|
41 |
+
class MarigoldIIDAppearanceOutput(BaseOutput):
|
42 |
+
"""
|
43 |
+
Output class for Marigold IID Appearance pipeline.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
albedo (`np.ndarray`):
|
47 |
+
Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
|
48 |
+
albedo_colored (`PIL.Image.Image`):
|
49 |
+
Colorized albedo map with the shape of [H, W, 3].
|
50 |
+
material (`np.ndarray`):
|
51 |
+
Predicted material map with the shape of [3, H, W] and values in [0, 1].
|
52 |
+
1st channel (Red) is roughness
|
53 |
+
2nd channel (Green) is metallicity
|
54 |
+
3rd channel (Blue) is empty (zero)
|
55 |
+
material_colored (`PIL.Image.Image`):
|
56 |
+
Colorized material map with the shape of [H, W, 3].
|
57 |
+
1st channel (Red) is roughness
|
58 |
+
2nd channel (Green) is metallicity
|
59 |
+
3rd channel (Blue) is empty (zero)
|
60 |
+
"""
|
61 |
+
|
62 |
+
albedo: np.ndarray
|
63 |
+
albedo_colored: Image.Image
|
64 |
+
material: np.ndarray
|
65 |
+
material_colored: Image.Image
|
66 |
+
|
67 |
+
class MarigoldIIDAppearancePipeline(DiffusionPipeline):
|
68 |
+
"""
|
69 |
+
Pipeline for Intrinsic Image Decomposition (Albedo and Material) using Marigold: https://marigoldcomputervision.github.io.
|
70 |
+
|
71 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
72 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
73 |
+
|
74 |
+
Args:
|
75 |
+
unet (`UNet2DConditionModel`):
|
76 |
+
Conditional U-Net to denoise the normals latent, conditioned on image latent.
|
77 |
+
vae (`AutoencoderKL`):
|
78 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images and normals maps
|
79 |
+
to and from latent representations.
|
80 |
+
scheduler (`DDIMScheduler`):
|
81 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
82 |
+
text_encoder (`CLIPTextModel`):
|
83 |
+
Text-encoder, for empty text embedding.
|
84 |
+
tokenizer (`CLIPTokenizer`):
|
85 |
+
CLIP tokenizer.
|
86 |
+
"""
|
87 |
+
|
88 |
+
latent_scale_factor = 0.18215
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
unet: UNet2DConditionModel,
|
93 |
+
vae: AutoencoderKL,
|
94 |
+
scheduler: DDIMScheduler,
|
95 |
+
text_encoder: CLIPTextModel,
|
96 |
+
tokenizer: CLIPTokenizer,
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.register_modules(
|
101 |
+
unet=unet,
|
102 |
+
vae=vae,
|
103 |
+
scheduler=scheduler,
|
104 |
+
text_encoder=text_encoder,
|
105 |
+
tokenizer=tokenizer,
|
106 |
+
)
|
107 |
+
|
108 |
+
self.empty_text_embed = None
|
109 |
+
|
110 |
+
self.n_targets = 2 # Albedo and material
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def __call__(
|
114 |
+
self,
|
115 |
+
input_image: Image,
|
116 |
+
denoising_steps: int = 4,
|
117 |
+
ensemble_size: int = 10,
|
118 |
+
processing_res: int = 768,
|
119 |
+
match_input_res: bool = True,
|
120 |
+
resample_method: str = "bilinear",
|
121 |
+
batch_size: int = 0,
|
122 |
+
save_memory: bool = False,
|
123 |
+
seed: Union[int, None] = None,
|
124 |
+
color_map: str = "Spectral", # TODO change colorization api based on modality
|
125 |
+
show_progress_bar: bool = True,
|
126 |
+
**kwargs,
|
127 |
+
) -> MarigoldIIDAppearanceOutput:
|
128 |
+
"""
|
129 |
+
Function invoked when calling the pipeline.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
input_image (`Image`):
|
133 |
+
Input RGB (or gray-scale) image.
|
134 |
+
denoising_steps (`int`, *optional*, defaults to `10`):
|
135 |
+
Number of diffusion denoising steps (DDIM) during inference.
|
136 |
+
ensemble_size (`int`, *optional*, defaults to `10`):
|
137 |
+
Number of predictions to be ensembled.
|
138 |
+
processing_res (`int`, *optional*, defaults to `768`):
|
139 |
+
Maximum resolution of processing.
|
140 |
+
If set to 0: will not resize at all.
|
141 |
+
match_input_res (`bool`, *optional*, defaults to `True`):
|
142 |
+
Resize normals prediction to match input resolution.
|
143 |
+
Only valid if `limit_input_res` is not None.
|
144 |
+
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
145 |
+
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
146 |
+
batch_size (`int`, *optional*, defaults to `0`):
|
147 |
+
Inference batch size, no bigger than `num_ensemble`.
|
148 |
+
If set to 0, the script will automatically decide the proper batch size.
|
149 |
+
save_memory (`bool`, defaults to `False`):
|
150 |
+
Extra steps to save memory at the cost of perforance.
|
151 |
+
seed (`int`, *optional*, defaults to `None`)
|
152 |
+
Reproducibility seed.
|
153 |
+
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized normals map generation):
|
154 |
+
Colormap used to colorize the normals map.
|
155 |
+
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
156 |
+
Display a progress bar of diffusion denoising.
|
157 |
+
Returns:
|
158 |
+
`MarigoldIIDAppearanceOutput`: Output class for Marigold monocular intrinsic image decomposition (appearance) prediction pipeline, including:
|
159 |
+
- **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
160 |
+
- **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
161 |
+
- **material** (`np.ndarray`) Predicted material map with the shape of [3, H, W] and values in [0, 1]
|
162 |
+
- **material_colored** (`PIL.Image.Image`) Colorized material map with the shape of [3, H, W] and values in [0, 1]
|
163 |
+
"""
|
164 |
+
|
165 |
+
if not match_input_res:
|
166 |
+
assert processing_res is not None
|
167 |
+
assert processing_res >= 0
|
168 |
+
assert denoising_steps >= 1
|
169 |
+
assert ensemble_size >= 1
|
170 |
+
|
171 |
+
# Check if denoising step is reasonable
|
172 |
+
self.check_inference_step(denoising_steps)
|
173 |
+
|
174 |
+
resample_method: Resampling = self.get_pil_resample_method(resample_method)
|
175 |
+
|
176 |
+
W, H = input_image.size
|
177 |
+
|
178 |
+
if processing_res > 0:
|
179 |
+
input_image = self.resize_max_res(
|
180 |
+
input_image, max_edge_resolution=processing_res, resample_method=resample_method,
|
181 |
+
)
|
182 |
+
input_image = input_image.convert("RGB")
|
183 |
+
image = np.asarray(input_image)
|
184 |
+
|
185 |
+
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
186 |
+
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
|
187 |
+
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
188 |
+
rgb_norm = rgb_norm.to(self.device)
|
189 |
+
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
|
190 |
+
|
191 |
+
def ensemble(
|
192 |
+
targets: torch.Tensor, return_uncertainty: bool = False, reduction = "median",
|
193 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
194 |
+
uncertainty = None
|
195 |
+
if reduction == "mean":
|
196 |
+
prediction = torch.mean(targets, dim=0, keepdim=True)
|
197 |
+
if return_uncertainty:
|
198 |
+
uncertainty = torch.std(targets, dim=0, keepdim=True)
|
199 |
+
elif reduction == "median":
|
200 |
+
prediction = torch.median(targets, dim=0, keepdim=True).values
|
201 |
+
if return_uncertainty:
|
202 |
+
uncertainty = torch.median(
|
203 |
+
torch.abs(targets - prediction), dim=0, keepdim=True
|
204 |
+
).values
|
205 |
+
else:
|
206 |
+
raise ValueError(f"Unrecognized reduction method: {reduction}.")
|
207 |
+
return prediction, uncertainty
|
208 |
+
|
209 |
+
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
210 |
+
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
211 |
+
|
212 |
+
if batch_size <= 0:
|
213 |
+
batch_size = self.find_batch_size(
|
214 |
+
ensemble_size=ensemble_size,
|
215 |
+
input_res=max(rgb_norm.shape[1:]),
|
216 |
+
dtype=self.dtype,
|
217 |
+
)
|
218 |
+
|
219 |
+
single_rgb_loader = DataLoader(
|
220 |
+
single_rgb_dataset, batch_size=batch_size, shuffle=False
|
221 |
+
)
|
222 |
+
|
223 |
+
target_pred_ls = []
|
224 |
+
iterable = single_rgb_loader
|
225 |
+
if show_progress_bar:
|
226 |
+
iterable = tqdm(
|
227 |
+
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
|
228 |
+
)
|
229 |
+
|
230 |
+
for batch in iterable:
|
231 |
+
(batched_img,) = batch
|
232 |
+
target_pred = self.single_infer(
|
233 |
+
rgb_in=batched_img,
|
234 |
+
num_inference_steps=denoising_steps,
|
235 |
+
seed=seed,
|
236 |
+
show_pbar=show_progress_bar,
|
237 |
+
)
|
238 |
+
target_pred = target_pred.detach()
|
239 |
+
if save_memory:
|
240 |
+
target_pred = target_pred.cpu()
|
241 |
+
target_pred_ls.append(target_pred.detach())
|
242 |
+
|
243 |
+
target_preds = torch.concat(target_pred_ls, dim=0)
|
244 |
+
pred_uncert = None
|
245 |
+
|
246 |
+
if save_memory:
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
|
249 |
+
if ensemble_size > 1:
|
250 |
+
final_pred, pred_uncert = ensemble(
|
251 |
+
target_preds,
|
252 |
+
reduction = "median",
|
253 |
+
return_uncertainty=False
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
final_pred = target_preds
|
257 |
+
pred_uncert = None
|
258 |
+
|
259 |
+
if match_input_res:
|
260 |
+
final_pred = torch.nn.functional.interpolate(
|
261 |
+
final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
|
262 |
+
) # [1,3,H,W]
|
263 |
+
|
264 |
+
if pred_uncert is not None:
|
265 |
+
pred_uncert = torch.nn.functional.interpolate(
|
266 |
+
pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
|
267 |
+
).squeeze(
|
268 |
+
1
|
269 |
+
) # [1,H,W]
|
270 |
+
|
271 |
+
# Convert to numpy
|
272 |
+
final_pred = final_pred.squeeze()
|
273 |
+
final_pred = final_pred.cpu().numpy()
|
274 |
+
|
275 |
+
albedo = final_pred[0:3, :, :]
|
276 |
+
material = np.stack(
|
277 |
+
(final_pred[3, :, :], final_pred[4, :, :], final_pred[5, :, :]), axis=0
|
278 |
+
)
|
279 |
+
|
280 |
+
albedo_colored = (albedo + 1.0) * 0.5
|
281 |
+
albedo_colored = (albedo_colored * 255).to(np.uint8)
|
282 |
+
albedo_colored = self.chw2hwc(albedo_colored)
|
283 |
+
albedo_colored_img = Image.fromarray(albedo_colored)
|
284 |
+
|
285 |
+
material_colored = (material + 1.0) * 0.5
|
286 |
+
material_colored = (material_colored * 255).to(np.uint8)
|
287 |
+
material_colored = self.chw2hwc(material_colored)
|
288 |
+
material_colored_img = Image.fromarray(material_colored)
|
289 |
+
|
290 |
+
out = MarigoldIIDAppearanceOutput(
|
291 |
+
albedo=albedo,
|
292 |
+
albedo_colored=albedo_colored_img,
|
293 |
+
material=material,
|
294 |
+
material_colored=material_colored_img
|
295 |
+
)
|
296 |
+
|
297 |
+
return out
|
298 |
+
|
299 |
+
def check_inference_step(self, n_step: int):
|
300 |
+
"""
|
301 |
+
Check if denoising step is reasonable
|
302 |
+
Args:
|
303 |
+
n_step (`int`): denoising steps
|
304 |
+
"""
|
305 |
+
assert n_step >= 1
|
306 |
+
|
307 |
+
if isinstance(self.scheduler, DDIMScheduler):
|
308 |
+
pass
|
309 |
+
else:
|
310 |
+
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
311 |
+
|
312 |
+
def encode_empty_text(self):
|
313 |
+
"""
|
314 |
+
Encode text embedding for empty prompt.
|
315 |
+
"""
|
316 |
+
prompt = ""
|
317 |
+
text_inputs = self.tokenizer(
|
318 |
+
prompt,
|
319 |
+
padding="do_not_pad",
|
320 |
+
max_length=self.tokenizer.model_max_length,
|
321 |
+
truncation=True,
|
322 |
+
return_tensors="pt",
|
323 |
+
)
|
324 |
+
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
325 |
+
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
326 |
+
|
327 |
+
@torch.no_grad()
|
328 |
+
def single_infer(
|
329 |
+
self,
|
330 |
+
rgb_in: torch.Tensor,
|
331 |
+
num_inference_steps: int,
|
332 |
+
seed: Union[int, None],
|
333 |
+
show_pbar: bool,
|
334 |
+
) -> torch.Tensor:
|
335 |
+
"""
|
336 |
+
Perform an individual iid prediction without ensembling.
|
337 |
+
"""
|
338 |
+
device = rgb_in.device
|
339 |
+
|
340 |
+
# Set timesteps
|
341 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
342 |
+
timesteps = self.scheduler.timesteps # [T]
|
343 |
+
|
344 |
+
# Encode image
|
345 |
+
rgb_latent = self.encode_rgb(rgb_in)
|
346 |
+
|
347 |
+
target_latent_shape = list(rgb_latent.shape)
|
348 |
+
target_latent_shape[1] *= (
|
349 |
+
2 # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
|
350 |
+
)
|
351 |
+
|
352 |
+
# Initialize prediction latent with noise
|
353 |
+
if seed is None:
|
354 |
+
rand_num_generator = None
|
355 |
+
else:
|
356 |
+
rand_num_generator = torch.Generator(device=device)
|
357 |
+
rand_num_generator.manual_seed(seed)
|
358 |
+
target_latents = torch.randn(
|
359 |
+
target_latent_shape,
|
360 |
+
device=device,
|
361 |
+
dtype=self.dtype,
|
362 |
+
generator=rand_num_generator,
|
363 |
+
) # [B, 4, h, w]
|
364 |
+
|
365 |
+
# Batched empty text embedding
|
366 |
+
if self.empty_text_embed is None:
|
367 |
+
self.encode_empty_text()
|
368 |
+
batch_empty_text_embed = self.empty_text_embed.repeat(
|
369 |
+
(rgb_latent.shape[0], 1, 1)
|
370 |
+
) # [B, 2, 1024]
|
371 |
+
|
372 |
+
# Denoising loop
|
373 |
+
if show_pbar:
|
374 |
+
iterable = tqdm(
|
375 |
+
enumerate(timesteps),
|
376 |
+
total=len(timesteps),
|
377 |
+
leave=False,
|
378 |
+
desc=" " * 4 + "Diffusion denoising",
|
379 |
+
)
|
380 |
+
else:
|
381 |
+
iterable = enumerate(timesteps)
|
382 |
+
|
383 |
+
for i, t in iterable:
|
384 |
+
unet_input = torch.cat(
|
385 |
+
[rgb_latent, target_latents], dim=1
|
386 |
+
) # this order is important
|
387 |
+
|
388 |
+
# predict the noise residual
|
389 |
+
noise_pred = self.unet(
|
390 |
+
unet_input, t, encoder_hidden_states=batch_empty_text_embed
|
391 |
+
).sample # [B, 4, h, w]
|
392 |
+
|
393 |
+
# compute the previous noisy sample x_t -> x_t-1
|
394 |
+
target_latents = self.scheduler.step(
|
395 |
+
noise_pred, t, target_latents, generator=rand_num_generator
|
396 |
+
).prev_sample
|
397 |
+
|
398 |
+
# torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
|
399 |
+
|
400 |
+
targets = self.decode_targets(target_latents) # [B, 3, H, W]
|
401 |
+
targets = torch.clip(targets, -1.0, 1.0)
|
402 |
+
|
403 |
+
return targets
|
404 |
+
|
405 |
+
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
406 |
+
"""
|
407 |
+
Encode RGB image into latent.
|
408 |
+
|
409 |
+
Args:
|
410 |
+
rgb_in (`torch.Tensor`):
|
411 |
+
Input RGB image to be encoded.
|
412 |
+
|
413 |
+
Returns:
|
414 |
+
`torch.Tensor`: Image latent.
|
415 |
+
"""
|
416 |
+
# encode
|
417 |
+
h = self.vae.encoder(rgb_in)
|
418 |
+
moments = self.vae.quant_conv(h)
|
419 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
420 |
+
# scale latent
|
421 |
+
rgb_latent = mean * self.latent_scale_factor
|
422 |
+
return rgb_latent
|
423 |
+
|
424 |
+
def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
|
425 |
+
"""
|
426 |
+
Decode target latent into target map.
|
427 |
+
|
428 |
+
Args:
|
429 |
+
target_latents (`torch.Tensor`):
|
430 |
+
Target latent to be decoded.
|
431 |
+
|
432 |
+
Returns:
|
433 |
+
`torch.Tensor`: Decoded target map.
|
434 |
+
"""
|
435 |
+
|
436 |
+
assert target_latents.shape[1] == 8 # self.n_targets * 4
|
437 |
+
|
438 |
+
# scale latent
|
439 |
+
target_latents = target_latents / self.rgb_latent_scale_factor
|
440 |
+
# decode
|
441 |
+
targets = []
|
442 |
+
for i in range(self.n_targets):
|
443 |
+
latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
|
444 |
+
z = self.vae.post_quant_conv(latent)
|
445 |
+
stacked = self.vae.decoder(z)
|
446 |
+
|
447 |
+
targets.append(stacked)
|
448 |
+
|
449 |
+
return torch.cat(targets, dim=1)
|
450 |
+
|
451 |
+
@staticmethod
|
452 |
+
def get_pil_resample_method(method_str: str) -> Resampling:
|
453 |
+
resample_method_dic = {
|
454 |
+
"bilinear": Resampling.BILINEAR,
|
455 |
+
"bicubic": Resampling.BICUBIC,
|
456 |
+
"nearest": Resampling.NEAREST,
|
457 |
+
}
|
458 |
+
resample_method = resample_method_dic.get(method_str, None)
|
459 |
+
if resample_method is None:
|
460 |
+
raise ValueError(f"Unknown resampling method: {resample_method}")
|
461 |
+
else:
|
462 |
+
return resample_method
|
463 |
+
|
464 |
+
@staticmethod
|
465 |
+
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
|
466 |
+
"""
|
467 |
+
Resize image to limit maximum edge length while keeping aspect ratio.
|
468 |
+
"""
|
469 |
+
original_width, original_height = img.size
|
470 |
+
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
|
471 |
+
|
472 |
+
new_width = int(original_width * downscale_factor)
|
473 |
+
new_height = int(original_height * downscale_factor)
|
474 |
+
|
475 |
+
resized_img = img.resize((new_width, new_height), resample=resample_method)
|
476 |
+
return resized_img
|
477 |
+
|
478 |
+
@staticmethod
|
479 |
+
def chw2hwc(chw):
|
480 |
+
assert 3 == len(chw.shape)
|
481 |
+
if isinstance(chw, torch.Tensor):
|
482 |
+
hwc = torch.permute(chw, (1, 2, 0))
|
483 |
+
elif isinstance(chw, np.ndarray):
|
484 |
+
hwc = np.moveaxis(chw, 0, -1)
|
485 |
+
return hwc
|
486 |
+
|
487 |
+
@staticmethod
|
488 |
+
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
489 |
+
"""
|
490 |
+
Automatically search for suitable operating batch size.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
ensemble_size (`int`):
|
494 |
+
Number of predictions to be ensembled.
|
495 |
+
input_res (`int`):
|
496 |
+
Operating resolution of the input image.
|
497 |
+
|
498 |
+
Returns:
|
499 |
+
`int`: Operating batch size.
|
500 |
+
"""
|
501 |
+
# Search table for suggested max. inference batch size
|
502 |
+
bs_search_table = [
|
503 |
+
# tested on A100-PCIE-80GB
|
504 |
+
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
505 |
+
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
506 |
+
# tested on A100-PCIE-40GB
|
507 |
+
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
508 |
+
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
509 |
+
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
510 |
+
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
511 |
+
# tested on RTX3090, RTX4090
|
512 |
+
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
513 |
+
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
514 |
+
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
515 |
+
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
516 |
+
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
517 |
+
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
518 |
+
# tested on GTX1080Ti
|
519 |
+
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
520 |
+
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
521 |
+
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
522 |
+
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
523 |
+
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
524 |
+
]
|
525 |
+
|
526 |
+
if not torch.cuda.is_available():
|
527 |
+
return 1
|
528 |
+
|
529 |
+
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
530 |
+
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
531 |
+
for settings in sorted(
|
532 |
+
filtered_bs_search_table,
|
533 |
+
key=lambda k: (k["res"], -k["total_vram"]),
|
534 |
+
):
|
535 |
+
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
536 |
+
bs = settings["bs"]
|
537 |
+
if bs > ensemble_size:
|
538 |
+
bs = ensemble_size
|
539 |
+
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
540 |
+
bs = math.ceil(ensemble_size / 2)
|
541 |
+
return bs
|
542 |
+
|
543 |
+
return 1
|
544 |
+
|
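Usage sketch (not part of the committed diff): a minimal example of how the appearance pipeline in this file could be driven. The class name `MarigoldIIDAppearancePipeline`, the checkpoint id `prs-eth/marigold-iid-appearance-v1-1`, and the `albedo_colored` output field are assumptions inferred from this Space, not confirmed by the lines above.

import torch
from PIL import Image
from marigold_iid_appearance import MarigoldIIDAppearancePipeline  # assumed class name

device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint id is an assumption; substitute whatever weights app.py actually loads.
pipe = MarigoldIIDAppearancePipeline.from_pretrained(
    "prs-eth/marigold-iid-appearance-v1-1", torch_dtype=torch.float16
).to(device)

out = pipe(
    Image.open("input.jpg"),
    denoising_steps=4,    # DDIM steps, matching the __call__ signature above
    ensemble_size=3,      # predictions are median-ensembled
    processing_res=768,   # longer image edge is resized to this before inference
    seed=2024,
)
out.albedo_colored.save("albedo.png")  # assumed output field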
marigold_iid_residual.py
ADDED
@@ -0,0 +1,552 @@
1 |
+
# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------------------------
|
15 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
16 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
17 |
+
# More information about the method can be found at https://marigoldcomputervision.github.io
|
18 |
+
# --------------------------------------------------------------------------
|
19 |
+
import logging
|
20 |
+
import math
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from diffusers import (
|
26 |
+
AutoencoderKL,
|
27 |
+
DDIMScheduler,
|
28 |
+
DiffusionPipeline,
|
29 |
+
UNet2DConditionModel,
|
30 |
+
)
|
31 |
+
from diffusers.utils import BaseOutput, check_min_version
|
32 |
+
from PIL import Image
|
33 |
+
from PIL.Image import Resampling
|
34 |
+
from torch.utils.data import DataLoader, TensorDataset
|
35 |
+
from tqdm.auto import tqdm
|
36 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
37 |
+
|
38 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
|
39 |
+
check_min_version("0.27.0.dev0")
|
40 |
+
|
41 |
+
class MarigoldIIDResidualOutput(BaseOutput):
|
42 |
+
"""
|
43 |
+
Output class for Marigold IID Residual pipeline.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
albedo (`np.ndarray`):
|
47 |
+
Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1].
|
48 |
+
albedo_colored (`PIL.Image.Image`):
|
49 |
+
Colorized albedo map with the shape of [H, W, 3].
|
50 |
+
shading (`np.ndarray`):
|
51 |
+
Predicted diffuse shading map with the shape of [3, H, W] values in the range of [0, 1].
|
52 |
+
shading_colored (`PIL.Image.Image`):
|
53 |
+
Colorized diffuse shading map with the shape of [H, W, 3].
|
54 |
+
residual (`np.ndarray`):
|
55 |
+
Predicted non-diffuse residual map with the shape of [3, H, W] values in the range of [0, 1].
|
56 |
+
residual_colored (`PIL.Image.Image`):
|
57 |
+
Colorized non-diffuse residual map with the shape of [H, W, 3].
|
58 |
+
|
59 |
+
"""
|
60 |
+
|
61 |
+
albedo: np.ndarray
|
62 |
+
albedo_colored: Image.Image
|
63 |
+
shading: np.ndarray
|
64 |
+
shading_colored: Image.Image
|
65 |
+
residual: np.ndarray
|
66 |
+
residual_colored: Image.Image
|
67 |
+
|
68 |
+
class MarigoldIIDResidualPipeline(DiffusionPipeline):
|
69 |
+
"""
|
70 |
+
Pipeline for Intrinsic Image Decomposition (Albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.
|
71 |
+
|
72 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
73 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
74 |
+
|
75 |
+
Args:
|
76 |
+
unet (`UNet2DConditionModel`):
|
77 |
+
Conditional U-Net to denoise the target latents, conditioned on the image latent.
|
78 |
+
vae (`AutoencoderKL`):
|
79 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images and target maps
|
80 |
+
to and from latent representations.
|
81 |
+
scheduler (`DDIMScheduler`):
|
82 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
83 |
+
text_encoder (`CLIPTextModel`):
|
84 |
+
Text-encoder, for empty text embedding.
|
85 |
+
tokenizer (`CLIPTokenizer`):
|
86 |
+
CLIP tokenizer.
|
87 |
+
"""
|
88 |
+
|
89 |
+
latent_scale_factor = 0.18215
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
unet: UNet2DConditionModel,
|
94 |
+
vae: AutoencoderKL,
|
95 |
+
scheduler: DDIMScheduler,
|
96 |
+
text_encoder: CLIPTextModel,
|
97 |
+
tokenizer: CLIPTokenizer,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.register_modules(
|
102 |
+
unet=unet,
|
103 |
+
vae=vae,
|
104 |
+
scheduler=scheduler,
|
105 |
+
text_encoder=text_encoder,
|
106 |
+
tokenizer=tokenizer,
|
107 |
+
)
|
108 |
+
|
109 |
+
self.empty_text_embed = None
|
110 |
+
self.n_targets = 3 # Albedo, shading, residual
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def __call__(
|
114 |
+
self,
|
115 |
+
input_image: Image,
|
116 |
+
denoising_steps: int = 4,
|
117 |
+
ensemble_size: int = 10,
|
118 |
+
processing_res: int = 768,
|
119 |
+
match_input_res: bool = True,
|
120 |
+
resample_method: str = "bilinear",
|
121 |
+
batch_size: int = 0,
|
122 |
+
save_memory: bool = False,
|
123 |
+
seed: Union[int, None] = None,
|
124 |
+
color_map: str = "Spectral", # TODO change colorization api based on modality
|
125 |
+
show_progress_bar: bool = True,
|
126 |
+
**kwargs,
|
127 |
+
) -> MarigoldIIDResidualOutput:
|
128 |
+
"""
|
129 |
+
Function invoked when calling the pipeline.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
input_image (`Image`):
|
133 |
+
Input RGB (or gray-scale) image.
|
134 |
+
denoising_steps (`int`, *optional*, defaults to `4`):
|
135 |
+
Number of diffusion denoising steps (DDIM) during inference.
|
136 |
+
ensemble_size (`int`, *optional*, defaults to `10`):
|
137 |
+
Number of predictions to be ensembled.
|
138 |
+
processing_res (`int`, *optional*, defaults to `768`):
|
139 |
+
Maximum resolution of processing.
|
140 |
+
If set to 0: will not resize at all.
|
141 |
+
match_input_res (`bool`, *optional*, defaults to `True`):
|
142 |
+
Resize predictions to match input resolution.
|
143 |
+
Only valid if `processing_res` is not 0.
|
144 |
+
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
145 |
+
Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or `nearest`; defaults to `bilinear`.
|
146 |
+
batch_size (`int`, *optional*, defaults to `0`):
|
147 |
+
Inference batch size, no bigger than `ensemble_size`.
|
148 |
+
If set to 0, the script will automatically decide the proper batch size.
|
149 |
+
save_memory (`bool`, defaults to `False`):
|
150 |
+
Extra steps to save memory at the cost of performance.
|
151 |
+
seed (`int`, *optional*, defaults to `None`)
|
152 |
+
Reproducibility seed.
|
153 |
+
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized map generation):
|
154 |
+
Colormap used to colorize the predicted maps.
|
155 |
+
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
156 |
+
Display a progress bar of diffusion denoising.
|
157 |
+
Returns:
|
158 |
+
`MarigoldIIDResidualOutput`: Output class for Marigold monocular intrinsic image decomposition (Residual) prediction pipeline, including:
|
159 |
+
- **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W] values in the range of [0, 1]
|
160 |
+
- **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with the shape of [H, W, 3]
|
161 |
+
- **shading** (`np.ndarray`) Predicted diffuse shading map with the shape of [3, H, W] and values in [0, 1]
|
162 |
+
- **shading_colored** (`PIL.Image.Image`) Colorized diffuse shading map with the shape of [H, W, 3]; **residual** and **residual_colored** follow the same conventions
|
163 |
+
"""
|
164 |
+
|
165 |
+
if not match_input_res:
|
166 |
+
assert processing_res is not None
|
167 |
+
assert processing_res >= 0
|
168 |
+
assert denoising_steps >= 1
|
169 |
+
assert ensemble_size >= 1
|
170 |
+
|
171 |
+
# Check if denoising step is reasonable
|
172 |
+
self.check_inference_step(denoising_steps)
|
173 |
+
|
174 |
+
resample_method: Resampling = self.get_pil_resample_method(resample_method)
|
175 |
+
|
176 |
+
W, H = input_image.size
|
177 |
+
|
178 |
+
if processing_res > 0:
|
179 |
+
input_image = self.resize_max_res(
|
180 |
+
input_image, max_edge_resolution=processing_res, resample_method=resample_method,
|
181 |
+
)
|
182 |
+
input_image = input_image.convert("RGB")
|
183 |
+
image = np.asarray(input_image)
|
184 |
+
|
185 |
+
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
186 |
+
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
|
187 |
+
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
188 |
+
rgb_norm = rgb_norm.to(self.device)
|
189 |
+
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
|
190 |
+
|
191 |
+
def ensemble(
|
192 |
+
targets: torch.Tensor, return_uncertainty: bool = False, reduction = "median",
|
193 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
194 |
+
uncertainty = None
|
195 |
+
if reduction == "mean":
|
196 |
+
prediction = torch.mean(targets, dim=0, keepdim=True)
|
197 |
+
if return_uncertainty:
|
198 |
+
uncertainty = torch.std(targets, dim=0, keepdim=True)
|
199 |
+
elif reduction == "median":
|
200 |
+
prediction = torch.median(targets, dim=0, keepdim=True).values
|
201 |
+
if return_uncertainty:
|
202 |
+
uncertainty = torch.median(
|
203 |
+
torch.abs(targets - prediction), dim=0, keepdim=True
|
204 |
+
).values
|
205 |
+
else:
|
206 |
+
raise ValueError(f"Unrecognized reduction method: {reduction}.")
|
207 |
+
return prediction, uncertainty
|
208 |
+
|
209 |
+
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
210 |
+
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
211 |
+
|
212 |
+
if batch_size <= 0:
|
213 |
+
batch_size = self.find_batch_size(
|
214 |
+
ensemble_size=ensemble_size,
|
215 |
+
input_res=max(rgb_norm.shape[1:]),
|
216 |
+
dtype=self.dtype,
|
217 |
+
)
|
218 |
+
|
219 |
+
single_rgb_loader = DataLoader(
|
220 |
+
single_rgb_dataset, batch_size=batch_size, shuffle=False
|
221 |
+
)
|
222 |
+
|
223 |
+
target_pred_ls = []
|
224 |
+
iterable = single_rgb_loader
|
225 |
+
if show_progress_bar:
|
226 |
+
iterable = tqdm(
|
227 |
+
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
|
228 |
+
)
|
229 |
+
|
230 |
+
for batch in iterable:
|
231 |
+
(batched_img,) = batch
|
232 |
+
target_pred = self.single_infer(
|
233 |
+
rgb_in=batched_img,
|
234 |
+
num_inference_steps=denoising_steps,
|
235 |
+
seed=seed,
|
236 |
+
show_pbar=show_progress_bar,
|
237 |
+
)
|
238 |
+
target_pred = target_pred.detach()
|
239 |
+
if save_memory:
|
240 |
+
target_pred = target_pred.cpu()
|
241 |
+
target_pred_ls.append(target_pred.detach())
|
242 |
+
|
243 |
+
target_preds = torch.concat(target_pred_ls, dim=0)
|
244 |
+
pred_uncert = None
|
245 |
+
|
246 |
+
if save_memory:
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
|
249 |
+
if ensemble_size > 1:
|
250 |
+
final_pred, pred_uncert = ensemble(
|
251 |
+
target_preds,
|
252 |
+
reduction = "median",
|
253 |
+
return_uncertainty=False
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
final_pred = target_preds
|
257 |
+
pred_uncert = None
|
258 |
+
|
259 |
+
if match_input_res:
|
260 |
+
final_pred = torch.nn.functional.interpolate(
|
261 |
+
final_pred, (H, W), mode="bilinear" # TODO: parameterize this method
|
262 |
+
) # [1,3,H,W]
|
263 |
+
|
264 |
+
if pred_uncert is not None:
|
265 |
+
pred_uncert = torch.nn.functional.interpolate(
|
266 |
+
pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
|
267 |
+
).squeeze(
|
268 |
+
1
|
269 |
+
) # [1,H,W]
|
270 |
+
|
271 |
+
# Convert to numpy
|
272 |
+
final_pred = final_pred.squeeze()
|
273 |
+
final_pred = final_pred.cpu().numpy()
|
274 |
+
|
275 |
+
albedo = final_pred[0:3, :, :]
|
276 |
+
shading = final_pred[3:6, :, :]
|
277 |
+
residual = final_pred[6:, :, :]
|
278 |
+
|
279 |
+
albedo_colored = (albedo + 1.0) * 0.5
|
280 |
+
albedo_colored = (albedo_colored * 255).astype(np.uint8)
|
281 |
+
albedo_colored = self.chw2hwc(albedo_colored)
|
282 |
+
albedo_colored_img = Image.fromarray(albedo_colored)
|
283 |
+
|
284 |
+
shading_colored = (shading + 1.0) * 0.5
|
285 |
+
shading_colored = shading_colored / shading_colored.max() # rescale for better visualization
|
286 |
+
shading_colored = (shading_colored * 255).astype(np.uint8)
|
287 |
+
shading_colored = self.chw2hwc(shading_colored)
|
288 |
+
shading_colored_img = Image.fromarray(shading_colored)
|
289 |
+
|
290 |
+
residual_colored = (residual + 1.0) * 0.5
|
291 |
+
residual_colored = residual_colored / residual_colored.max() # rescale for better visualization
|
292 |
+
residual_colored = (residual_colored * 255).astype(np.uint8)
|
293 |
+
residual_colored = self.chw2hwc(residual_colored)
|
294 |
+
residual_colored_img = Image.fromarray(residual_colored)
|
295 |
+
|
296 |
+
out = MarigoldIIDResidualOutput(
|
297 |
+
albedo=albedo,
|
298 |
+
albedo_colored=albedo_colored_img,
|
299 |
+
shading=shading,
|
300 |
+
shading_colored=shading_colored_img,
|
301 |
+
residual=residual,
|
302 |
+
residual_colored=residual_colored_img
|
303 |
+
)
|
304 |
+
|
305 |
+
return out
|
306 |
+
|
307 |
+
def check_inference_step(self, n_step: int):
|
308 |
+
"""
|
309 |
+
Check if denoising step is reasonable
|
310 |
+
Args:
|
311 |
+
n_step (`int`): denoising steps
|
312 |
+
"""
|
313 |
+
assert n_step >= 1
|
314 |
+
|
315 |
+
if isinstance(self.scheduler, DDIMScheduler):
|
316 |
+
pass
|
317 |
+
else:
|
318 |
+
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
319 |
+
|
320 |
+
def encode_empty_text(self):
|
321 |
+
"""
|
322 |
+
Encode text embedding for empty prompt.
|
323 |
+
"""
|
324 |
+
prompt = ""
|
325 |
+
text_inputs = self.tokenizer(
|
326 |
+
prompt,
|
327 |
+
padding="do_not_pad",
|
328 |
+
max_length=self.tokenizer.model_max_length,
|
329 |
+
truncation=True,
|
330 |
+
return_tensors="pt",
|
331 |
+
)
|
332 |
+
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
333 |
+
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
334 |
+
|
335 |
+
@torch.no_grad()
|
336 |
+
def single_infer(
|
337 |
+
self,
|
338 |
+
rgb_in: torch.Tensor,
|
339 |
+
num_inference_steps: int,
|
340 |
+
seed: Union[int, None],
|
341 |
+
show_pbar: bool,
|
342 |
+
) -> torch.Tensor:
|
343 |
+
"""
|
344 |
+
Perform an individual IID prediction without ensembling.
|
345 |
+
"""
|
346 |
+
device = rgb_in.device
|
347 |
+
|
348 |
+
# Set timesteps
|
349 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
350 |
+
timesteps = self.scheduler.timesteps # [T]
|
351 |
+
|
352 |
+
# Encode image
|
353 |
+
rgb_latent = self.encode_rgb(rgb_in)
|
354 |
+
|
355 |
+
target_latent_shape = list(rgb_latent.shape)
|
356 |
+
target_latent_shape[1] *= (
|
357 |
+
3 # TODO: no hardcoding # self.n_targets # (B, 4*n_targets, h, w)
|
358 |
+
)
|
359 |
+
|
360 |
+
# Initialize prediction latent with noise
|
361 |
+
if seed is None:
|
362 |
+
rand_num_generator = None
|
363 |
+
else:
|
364 |
+
rand_num_generator = torch.Generator(device=device)
|
365 |
+
rand_num_generator.manual_seed(seed)
|
366 |
+
target_latents = torch.randn(
|
367 |
+
target_latent_shape,
|
368 |
+
device=device,
|
369 |
+
dtype=self.dtype,
|
370 |
+
generator=rand_num_generator,
|
371 |
+
) # [B, 4*n_targets, h, w]
|
372 |
+
|
373 |
+
# Batched empty text embedding
|
374 |
+
if self.empty_text_embed is None:
|
375 |
+
self.encode_empty_text()
|
376 |
+
batch_empty_text_embed = self.empty_text_embed.repeat(
|
377 |
+
(rgb_latent.shape[0], 1, 1)
|
378 |
+
) # [B, 2, 1024]
|
379 |
+
|
380 |
+
# Denoising loop
|
381 |
+
if show_pbar:
|
382 |
+
iterable = tqdm(
|
383 |
+
enumerate(timesteps),
|
384 |
+
total=len(timesteps),
|
385 |
+
leave=False,
|
386 |
+
desc=" " * 4 + "Diffusion denoising",
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
iterable = enumerate(timesteps)
|
390 |
+
|
391 |
+
for i, t in iterable:
|
392 |
+
unet_input = torch.cat(
|
393 |
+
[rgb_latent, target_latents], dim=1
|
394 |
+
) # this order is important
|
395 |
+
|
396 |
+
# predict the noise residual
|
397 |
+
noise_pred = self.unet(
|
398 |
+
unet_input, t, encoder_hidden_states=batch_empty_text_embed
|
399 |
+
).sample # [B, 4*n_targets, h, w]
|
400 |
+
|
401 |
+
# compute the previous noisy sample x_t -> x_t-1
|
402 |
+
target_latents = self.scheduler.step(
|
403 |
+
noise_pred, t, target_latents, generator=rand_num_generator
|
404 |
+
).prev_sample
|
405 |
+
|
406 |
+
# torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
|
407 |
+
|
408 |
+
targets = self.decode_targets(target_latents) # [B, 3*n_targets, H, W]
|
409 |
+
targets = torch.clip(targets, -1.0, 1.0)
|
410 |
+
|
411 |
+
return targets
|
412 |
+
|
413 |
+
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
414 |
+
"""
|
415 |
+
Encode RGB image into latent.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
rgb_in (`torch.Tensor`):
|
419 |
+
Input RGB image to be encoded.
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
`torch.Tensor`: Image latent.
|
423 |
+
"""
|
424 |
+
# encode
|
425 |
+
h = self.vae.encoder(rgb_in)
|
426 |
+
moments = self.vae.quant_conv(h)
|
427 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
428 |
+
# scale latent
|
429 |
+
rgb_latent = mean * self.latent_scale_factor
|
430 |
+
return rgb_latent
|
431 |
+
|
432 |
+
def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
|
433 |
+
"""
|
434 |
+
Decode target latent into target map.
|
435 |
+
|
436 |
+
Args:
|
437 |
+
target_latents (`torch.Tensor`):
|
438 |
+
Target latent to be decoded.
|
439 |
+
|
440 |
+
Returns:
|
441 |
+
`torch.Tensor`: Decoded target map.
|
442 |
+
"""
|
443 |
+
|
444 |
+
assert target_latents.shape[1] == 12 # self.n_targets * 4
|
445 |
+
|
446 |
+
# scale latent
|
447 |
+
target_latents = target_latents / self.latent_scale_factor
|
448 |
+
# decode
|
449 |
+
targets = []
|
450 |
+
for i in range(self.n_targets):
|
451 |
+
latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
|
452 |
+
z = self.vae.post_quant_conv(latent)
|
453 |
+
stacked = self.vae.decoder(z)
|
454 |
+
|
455 |
+
targets.append(stacked)
|
456 |
+
|
457 |
+
return torch.cat(targets, dim=1)
|
458 |
+
|
459 |
+
@staticmethod
|
460 |
+
def get_pil_resample_method(method_str: str) -> Resampling:
|
461 |
+
resample_method_dic = {
|
462 |
+
"bilinear": Resampling.BILINEAR,
|
463 |
+
"bicubic": Resampling.BICUBIC,
|
464 |
+
"nearest": Resampling.NEAREST,
|
465 |
+
}
|
466 |
+
resample_method = resample_method_dic.get(method_str, None)
|
467 |
+
if resample_method is None:
|
468 |
+
raise ValueError(f"Unknown resampling method: {resample_method}")
|
469 |
+
else:
|
470 |
+
return resample_method
|
471 |
+
|
472 |
+
@staticmethod
|
473 |
+
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
|
474 |
+
"""
|
475 |
+
Resize image to limit maximum edge length while keeping aspect ratio.
|
476 |
+
"""
|
477 |
+
original_width, original_height = img.size
|
478 |
+
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
|
479 |
+
|
480 |
+
new_width = int(original_width * downscale_factor)
|
481 |
+
new_height = int(original_height * downscale_factor)
|
482 |
+
|
483 |
+
resized_img = img.resize((new_width, new_height), resample=resample_method)
|
484 |
+
return resized_img
|
485 |
+
|
486 |
+
@staticmethod
|
487 |
+
def chw2hwc(chw):
|
488 |
+
assert 3 == len(chw.shape)
|
489 |
+
if isinstance(chw, torch.Tensor):
|
490 |
+
hwc = torch.permute(chw, (1, 2, 0))
|
491 |
+
elif isinstance(chw, np.ndarray):
|
492 |
+
hwc = np.moveaxis(chw, 0, -1)
|
493 |
+
return hwc
|
494 |
+
|
495 |
+
@staticmethod
|
496 |
+
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
497 |
+
"""
|
498 |
+
Automatically search for suitable operating batch size.
|
499 |
+
|
500 |
+
Args:
|
501 |
+
ensemble_size (`int`):
|
502 |
+
Number of predictions to be ensembled.
|
503 |
+
input_res (`int`):
|
504 |
+
Operating resolution of the input image.
|
505 |
+
|
506 |
+
Returns:
|
507 |
+
`int`: Operating batch size.
|
508 |
+
"""
|
509 |
+
# Search table for suggested max. inference batch size
|
510 |
+
bs_search_table = [
|
511 |
+
# tested on A100-PCIE-80GB
|
512 |
+
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
513 |
+
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
514 |
+
# tested on A100-PCIE-40GB
|
515 |
+
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
516 |
+
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
517 |
+
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
518 |
+
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
519 |
+
# tested on RTX3090, RTX4090
|
520 |
+
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
521 |
+
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
522 |
+
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
523 |
+
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
524 |
+
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
525 |
+
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
526 |
+
# tested on GTX1080Ti
|
527 |
+
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
528 |
+
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
529 |
+
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
530 |
+
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
531 |
+
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
532 |
+
]
|
533 |
+
|
534 |
+
if not torch.cuda.is_available():
|
535 |
+
return 1
|
536 |
+
|
537 |
+
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
538 |
+
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
539 |
+
for settings in sorted(
|
540 |
+
filtered_bs_search_table,
|
541 |
+
key=lambda k: (k["res"], -k["total_vram"]),
|
542 |
+
):
|
543 |
+
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
544 |
+
bs = settings["bs"]
|
545 |
+
if bs > ensemble_size:
|
546 |
+
bs = ensemble_size
|
547 |
+
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
548 |
+
bs = math.ceil(ensemble_size / 2)
|
549 |
+
return bs
|
550 |
+
|
551 |
+
return 1
|
552 |
+
|
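Usage sketch (not part of the committed diff): running the `MarigoldIIDResidualPipeline` defined above end to end. The checkpoint id `prs-eth/marigold-iid-lighting-v1-1` is an assumption; substitute whatever weights app.py actually loads.

import torch
from PIL import Image
from marigold_iid_residual import MarigoldIIDResidualPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = MarigoldIIDResidualPipeline.from_pretrained(
    "prs-eth/marigold-iid-lighting-v1-1",  # checkpoint id is an assumption
    torch_dtype=torch.float16,
).to(device)

out = pipe(
    Image.open("input.jpg"),
    denoising_steps=4,
    ensemble_size=3,
    processing_res=768,
    match_input_res=True,
    seed=2024,
)
# Fields come from MarigoldIIDResidualOutput defined above.
out.albedo_colored.save("albedo.png")
out.shading_colored.save("shading.png")
out.residual_colored.save("residual.png")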
requirements.txt
ADDED
@@ -0,0 +1,126 @@
1 |
+
accelerate==0.25.0
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohttp==3.9.3
|
4 |
+
aiosignal==1.3.1
|
5 |
+
altair==5.3.0
|
6 |
+
annotated-types==0.6.0
|
7 |
+
anyio==4.3.0
|
8 |
+
async-timeout==4.0.3
|
9 |
+
attrs==23.2.0
|
10 |
+
Authlib==1.3.0
|
11 |
+
certifi==2024.2.2
|
12 |
+
cffi==1.16.0
|
13 |
+
charset-normalizer==3.3.2
|
14 |
+
click==8.0.4
|
15 |
+
cmake==3.29.0.1
|
16 |
+
contourpy==1.2.0
|
17 |
+
cryptography==42.0.5
|
18 |
+
cycler==0.12.1
|
19 |
+
dataclasses-json==0.6.4
|
20 |
+
datasets==2.18.0
|
21 |
+
Deprecated==1.2.14
|
22 |
+
diffusers==0.27.2
|
23 |
+
dill==0.3.8
|
24 |
+
exceptiongroup==1.2.0
|
25 |
+
fastapi==0.110.0
|
26 |
+
ffmpy==0.3.2
|
27 |
+
filelock==3.13.3
|
28 |
+
fonttools==4.50.0
|
29 |
+
frozenlist==1.4.1
|
30 |
+
fsspec==2024.2.0
|
31 |
+
gradio==4.21.0
|
32 |
+
gradio_client==0.12.0
|
33 |
+
gradio_imageslider==0.0.18
|
34 |
+
h11==0.14.0
|
35 |
+
httpcore==1.0.5
|
36 |
+
httpx==0.27.0
|
37 |
+
huggingface-hub==0.22.1
|
38 |
+
idna==3.6
|
39 |
+
imageio==2.34.0
|
40 |
+
imageio-ffmpeg==0.4.9
|
41 |
+
importlib_metadata==7.1.0
|
42 |
+
importlib_resources==6.4.0
|
43 |
+
itsdangerous==2.1.2
|
44 |
+
Jinja2==3.1.3
|
45 |
+
jsonschema==4.21.1
|
46 |
+
jsonschema-specifications==2023.12.1
|
47 |
+
kiwisolver==1.4.5
|
48 |
+
lit==18.1.2
|
49 |
+
markdown-it-py==3.0.0
|
50 |
+
MarkupSafe==2.1.5
|
51 |
+
marshmallow==3.21.1
|
52 |
+
matplotlib==3.8.2
|
53 |
+
mdurl==0.1.2
|
54 |
+
mpmath==1.3.0
|
55 |
+
multidict==6.0.5
|
56 |
+
multiprocess==0.70.16
|
57 |
+
mypy-extensions==1.0.0
|
58 |
+
networkx==3.2.1
|
59 |
+
numpy==1.26.4
|
60 |
+
nvidia-cublas-cu11==11.10.3.66
|
61 |
+
nvidia-cuda-cupti-cu11==11.7.101
|
62 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
63 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
64 |
+
nvidia-cudnn-cu11==8.5.0.96
|
65 |
+
nvidia-cufft-cu11==10.9.0.58
|
66 |
+
nvidia-curand-cu11==10.2.10.91
|
67 |
+
nvidia-cusolver-cu11==11.4.0.1
|
68 |
+
nvidia-cusparse-cu11==11.7.4.91
|
69 |
+
nvidia-nccl-cu11==2.14.3
|
70 |
+
nvidia-nvtx-cu11==11.7.91
|
71 |
+
orjson==3.10.0
|
72 |
+
packaging==24.0
|
73 |
+
pandas==2.2.1
|
74 |
+
pillow==10.2.0
|
75 |
+
protobuf==3.20.3
|
76 |
+
psutil==5.9.8
|
77 |
+
pyarrow==15.0.2
|
78 |
+
pyarrow-hotfix==0.6
|
79 |
+
pycparser==2.22
|
80 |
+
pydantic==2.6.4
|
81 |
+
pydantic_core==2.16.3
|
82 |
+
pydub==0.25.1
|
83 |
+
pygltflib==1.16.1
|
84 |
+
Pygments==2.17.2
|
85 |
+
pyparsing==3.1.2
|
86 |
+
python-dateutil==2.9.0.post0
|
87 |
+
python-multipart==0.0.9
|
88 |
+
pytz==2024.1
|
89 |
+
PyYAML==6.0.1
|
90 |
+
referencing==0.34.0
|
91 |
+
regex==2023.12.25
|
92 |
+
requests==2.31.0
|
93 |
+
rich==13.7.1
|
94 |
+
rpds-py==0.18.0
|
95 |
+
ruff==0.3.4
|
96 |
+
safetensors==0.4.2
|
97 |
+
scipy==1.11.4
|
98 |
+
semantic-version==2.10.0
|
99 |
+
shellingham==1.5.4
|
100 |
+
six==1.16.0
|
101 |
+
sniffio==1.3.1
|
102 |
+
spaces==0.25.0
|
103 |
+
starlette==0.36.3
|
104 |
+
sympy==1.12
|
105 |
+
tokenizers==0.15.2
|
106 |
+
tomlkit==0.12.0
|
107 |
+
toolz==0.12.1
|
108 |
+
torch==2.0.1
|
109 |
+
tqdm==4.66.2
|
110 |
+
transformers==4.36.1
|
111 |
+
trimesh==4.0.5
|
112 |
+
triton==2.0.0
|
113 |
+
typer==0.12.0
|
114 |
+
typer-cli==0.12.0
|
115 |
+
typer-slim==0.12.0
|
116 |
+
typing-inspect==0.9.0
|
117 |
+
typing_extensions==4.10.0
|
118 |
+
tzdata==2024.1
|
119 |
+
urllib3==2.2.1
|
120 |
+
uvicorn==0.29.0
|
121 |
+
websockets==11.0.3
|
122 |
+
wrapt==1.16.0
|
123 |
+
xformers==0.0.21
|
124 |
+
xxhash==3.4.1
|
125 |
+
yarl==1.9.4
|
126 |
+
zipp==3.18.1
|
requirements_min.txt
ADDED
@@ -0,0 +1,16 @@
1 |
+
gradio==4.21.0
|
2 |
+
gradio-imageslider==0.0.18
|
3 |
+
pygltflib==1.16.1
|
4 |
+
trimesh==4.0.5
|
5 |
+
imageio
|
6 |
+
imageio-ffmpeg
|
7 |
+
Pillow
|
8 |
+
|
9 |
+
spaces==0.25.0
|
10 |
+
accelerate==0.25.0
|
11 |
+
diffusers==0.27.2
|
12 |
+
matplotlib==3.8.2
|
13 |
+
scipy==1.11.4
|
14 |
+
torch==2.0.1
|
15 |
+
transformers==4.36.1
|
16 |
+
xformers==0.0.21
|