Spaces:
Runtime error
Runtime error
Zafaflahfksdf
commited on
Commit
•
da3eeba
1
Parent(s):
ca1233a
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +142 -0
- LICENSE +201 -0
- README.md +149 -12
- README_DEV.md +61 -0
- fast_sam/__init__.py +9 -0
- fast_sam/fast_sam_wrapper.py +90 -0
- ia_check_versions.py +74 -0
- ia_config.py +115 -0
- ia_devices.py +10 -0
- ia_file_manager.py +71 -0
- ia_get_dataset_colormap.py +416 -0
- ia_logging.py +14 -0
- ia_sam_manager.py +182 -0
- ia_threading.py +55 -0
- ia_ui_gradio.py +30 -0
- ia_ui_items.py +110 -0
- iasam_app.py +809 -0
- images/inpaint_anything_explanation_image_1.png +0 -0
- images/inpaint_anything_ui_image_1.png +0 -0
- images/sample_input_image.png +0 -0
- images/sample_mask_image.png +0 -0
- images/sample_seg_color_image.png +0 -0
- inpalib/__init__.py +18 -0
- inpalib/masklib.py +106 -0
- inpalib/samlib.py +256 -0
- javascript/inpaint-anything.js +458 -0
- lama_cleaner/__init__.py +19 -0
- lama_cleaner/benchmark.py +109 -0
- lama_cleaner/const.py +173 -0
- lama_cleaner/file_manager/__init__.py +1 -0
- lama_cleaner/file_manager/file_manager.py +265 -0
- lama_cleaner/file_manager/storage_backends.py +46 -0
- lama_cleaner/file_manager/utils.py +67 -0
- lama_cleaner/helper.py +292 -0
- lama_cleaner/installer.py +12 -0
- lama_cleaner/model/__init__.py +0 -0
- lama_cleaner/model/base.py +298 -0
- lama_cleaner/model/controlnet.py +289 -0
- lama_cleaner/model/ddim_sampler.py +193 -0
- lama_cleaner/model/fcf.py +1733 -0
- lama_cleaner/model/instruct_pix2pix.py +83 -0
- lama_cleaner/model/lama.py +51 -0
- lama_cleaner/model/ldm.py +333 -0
- lama_cleaner/model/manga.py +91 -0
- lama_cleaner/model/mat.py +1935 -0
- lama_cleaner/model/opencv2.py +28 -0
- lama_cleaner/model/paint_by_example.py +79 -0
- lama_cleaner/model/pipeline/__init__.py +3 -0
- lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py +585 -0
- lama_cleaner/model/plms_sampler.py +225 -0
.gitignore
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pth
|
2 |
+
*.pt
|
3 |
+
*.pyc
|
4 |
+
src/
|
5 |
+
outputs/
|
6 |
+
models/
|
7 |
+
models
|
8 |
+
.DS_Store
|
9 |
+
ia_config.ini
|
10 |
+
.eslintrc
|
11 |
+
.eslintrc.json
|
12 |
+
pyproject.toml
|
13 |
+
|
14 |
+
# Byte-compiled / optimized / DLL files
|
15 |
+
__pycache__/
|
16 |
+
*.py[cod]
|
17 |
+
*$py.class
|
18 |
+
|
19 |
+
# C extensions
|
20 |
+
*.so
|
21 |
+
|
22 |
+
# Distribution / packaging
|
23 |
+
.Python
|
24 |
+
build/
|
25 |
+
develop-eggs/
|
26 |
+
dist/
|
27 |
+
downloads/
|
28 |
+
eggs/
|
29 |
+
.eggs/
|
30 |
+
lib/
|
31 |
+
lib64/
|
32 |
+
parts/
|
33 |
+
sdist/
|
34 |
+
var/
|
35 |
+
wheels/
|
36 |
+
pip-wheel-metadata/
|
37 |
+
share/python-wheels/
|
38 |
+
*.egg-info/
|
39 |
+
.installed.cfg
|
40 |
+
*.egg
|
41 |
+
MANIFEST
|
42 |
+
|
43 |
+
# PyInstaller
|
44 |
+
# Usually these files are written by a python script from a template
|
45 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
46 |
+
*.manifest
|
47 |
+
*.spec
|
48 |
+
|
49 |
+
# Installer logs
|
50 |
+
pip-log.txt
|
51 |
+
pip-delete-this-directory.txt
|
52 |
+
|
53 |
+
# Unit test / coverage reports
|
54 |
+
htmlcov/
|
55 |
+
.tox/
|
56 |
+
.nox/
|
57 |
+
.coverage
|
58 |
+
.coverage.*
|
59 |
+
.cache
|
60 |
+
nosetests.xml
|
61 |
+
coverage.xml
|
62 |
+
*.cover
|
63 |
+
*.py,cover
|
64 |
+
.hypothesis/
|
65 |
+
.pytest_cache/
|
66 |
+
|
67 |
+
# Translations
|
68 |
+
*.mo
|
69 |
+
*.pot
|
70 |
+
|
71 |
+
# Django stuff:
|
72 |
+
*.log
|
73 |
+
local_settings.py
|
74 |
+
db.sqlite3
|
75 |
+
db.sqlite3-journal
|
76 |
+
|
77 |
+
# Flask stuff:
|
78 |
+
instance/
|
79 |
+
.webassets-cache
|
80 |
+
|
81 |
+
# Scrapy stuff:
|
82 |
+
.scrapy
|
83 |
+
|
84 |
+
# Sphinx documentation
|
85 |
+
docs/_build/
|
86 |
+
|
87 |
+
# PyBuilder
|
88 |
+
target/
|
89 |
+
|
90 |
+
# Jupyter Notebook
|
91 |
+
.ipynb_checkpoints
|
92 |
+
|
93 |
+
# IPython
|
94 |
+
profile_default/
|
95 |
+
ipython_config.py
|
96 |
+
|
97 |
+
# pyenv
|
98 |
+
.python-version
|
99 |
+
|
100 |
+
# pipenv
|
101 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
102 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
103 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
104 |
+
# install all needed dependencies.
|
105 |
+
#Pipfile.lock
|
106 |
+
|
107 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
108 |
+
__pypackages__/
|
109 |
+
|
110 |
+
# Celery stuff
|
111 |
+
celerybeat-schedule
|
112 |
+
celerybeat.pid
|
113 |
+
|
114 |
+
# SageMath parsed files
|
115 |
+
*.sage.py
|
116 |
+
|
117 |
+
# Environments
|
118 |
+
.env
|
119 |
+
.venv
|
120 |
+
env/
|
121 |
+
venv/
|
122 |
+
ENV/
|
123 |
+
env.bak/
|
124 |
+
venv.bak/
|
125 |
+
|
126 |
+
# Spyder project settings
|
127 |
+
.spyderproject
|
128 |
+
.spyproject
|
129 |
+
|
130 |
+
# Rope project settings
|
131 |
+
.ropeproject
|
132 |
+
|
133 |
+
# mkdocs documentation
|
134 |
+
/site
|
135 |
+
|
136 |
+
# mypy
|
137 |
+
.mypy_cache/
|
138 |
+
.dmypy.json
|
139 |
+
dmypy.json
|
140 |
+
|
141 |
+
# Pyre type checker
|
142 |
+
.pyre/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,149 @@
|
|
1 |
-
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: _
|
3 |
+
app_file: iasam_app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 3.50.2
|
6 |
+
---
|
7 |
+
# Inpaint Anything (Inpainting with Segment Anything)
|
8 |
+
|
9 |
+
Inpaint Anything performs stable diffusion inpainting on a browser UI using any mask selected from the output of [Segment Anything](https://github.com/facebookresearch/segment-anything).
|
10 |
+
|
11 |
+
|
12 |
+
Using Segment Anything enables users to specify masks by simply pointing to the desired areas, instead of manually filling them in. This can increase the efficiency and accuracy of the mask creation process, leading to potentially higher-quality inpainting results while saving time and effort.
|
13 |
+
|
14 |
+
[Extension version for AUTOMATIC1111's Web UI](https://github.com/Uminosachi/sd-webui-inpaint-anything)
|
15 |
+
|
16 |
+
![Explanation image](images/inpaint_anything_explanation_image_1.png)
|
17 |
+
|
18 |
+
## Installation
|
19 |
+
|
20 |
+
Please follow these steps to install the software:
|
21 |
+
|
22 |
+
* Create a new conda environment:
|
23 |
+
|
24 |
+
```bash
|
25 |
+
conda create -n inpaint python=3.10
|
26 |
+
conda activate inpaint
|
27 |
+
```
|
28 |
+
|
29 |
+
* Clone the software repository:
|
30 |
+
|
31 |
+
```bash
|
32 |
+
git clone https://github.com/Uminosachi/inpaint-anything.git
|
33 |
+
cd inpaint-anything
|
34 |
+
```
|
35 |
+
|
36 |
+
* For the CUDA environment, install the following packages:
|
37 |
+
|
38 |
+
```bash
|
39 |
+
pip install -r requirements.txt
|
40 |
+
```
|
41 |
+
|
42 |
+
* If you are using macOS, please install the package from the following file instead:
|
43 |
+
|
44 |
+
```bash
|
45 |
+
pip install -r requirements_mac.txt
|
46 |
+
```
|
47 |
+
|
48 |
+
## Running the application
|
49 |
+
|
50 |
+
```bash
|
51 |
+
python iasam_app.py
|
52 |
+
```
|
53 |
+
|
54 |
+
* Open http://127.0.0.1:7860/ in your browser.
|
55 |
+
* Note: If you have a privacy protection extension enabled in your web browser, such as DuckDuckGo, you may not be able to retrieve the mask from your sketch.
|
56 |
+
|
57 |
+
### Options
|
58 |
+
|
59 |
+
* `--save-seg`: Save the segmentation image generated by SAM.
|
60 |
+
* `--offline`: Execute inpainting using an offline network.
|
61 |
+
* `--sam-cpu`: Perform the Segment Anything operation on CPU.
|
62 |
+
|
63 |
+
## Downloading the Model
|
64 |
+
|
65 |
+
* Launch this application.
|
66 |
+
* Click on the `Download model` button, located next to the [Segment Anything Model ID](https://github.com/facebookresearch/segment-anything#model-checkpoints). This includes the [SAM 2](https://github.com/facebookresearch/segment-anything-2), [Segment Anything in High Quality Model ID](https://github.com/SysCV/sam-hq), [Fast Segment Anything](https://github.com/CASIA-IVA-Lab/FastSAM), and [Faster Segment Anything (MobileSAM)](https://github.com/ChaoningZhang/MobileSAM).
|
67 |
+
* Please note that the SAM is available in three sizes: Base, Large, and Huge. Remember, larger sizes consume more VRAM.
|
68 |
+
* Wait for the download to complete.
|
69 |
+
* The downloaded model file will be stored in the `models` directory of this application's repository.
|
70 |
+
|
71 |
+
## Usage
|
72 |
+
|
73 |
+
* Drag and drop your image onto the input image area.
|
74 |
+
* Outpainting can be achieved by the `Padding options`, configuring the scale and balance, and then clicking on the `Run Padding` button.
|
75 |
+
* The `Anime Style` checkbox enhances segmentation mask detection, particularly in anime style images, at the expense of a slight reduction in mask quality.
|
76 |
+
* Click on the `Run Segment Anything` button.
|
77 |
+
* Use sketching to point the area you want to inpaint. You can undo and adjust the pen size.
|
78 |
+
* Hover over either the SAM image or the mask image and press the `S` key for Fullscreen mode, or the `R` key to Reset zoom.
|
79 |
+
* Click on the `Create mask` button. The mask will appear in the selected mask image area.
|
80 |
+
|
81 |
+
### Mask Adjustment
|
82 |
+
|
83 |
+
* `Expand mask region` button: Use this to slightly expand the area of the mask for broader coverage.
|
84 |
+
* `Trim mask by sketch` button: Clicking this will exclude the sketched area from the mask.
|
85 |
+
* `Add mask by sketch` button: Clicking this will add the sketched area to the mask.
|
86 |
+
|
87 |
+
### Inpainting Tab
|
88 |
+
|
89 |
+
* Enter your desired Prompt and Negative Prompt, then choose the Inpainting Model ID.
|
90 |
+
* Click on the `Run Inpainting` button (**Please note that it may take some time to download the model for the first time**).
|
91 |
+
* In the Advanced options, you can adjust the Sampler, Sampling Steps, Guidance Scale, and Seed.
|
92 |
+
* If you enable the `Mask area Only` option, modifications will be confined to the designated mask area only.
|
93 |
+
* Adjust the iteration slider to perform inpainting multiple times with different seeds.
|
94 |
+
* The inpainting process is powered by [diffusers](https://github.com/huggingface/diffusers).
|
95 |
+
|
96 |
+
#### Tips
|
97 |
+
|
98 |
+
* You can directly drag and drop the inpainted image into the input image field on the Web UI. (useful with Chrome and Edge browsers)
|
99 |
+
|
100 |
+
#### Model Cache
|
101 |
+
* The inpainting model, which is saved in HuggingFace's cache and includes `inpaint` (case-insensitive) in its repo_id, will also be added to the Inpainting Model ID dropdown list.
|
102 |
+
* If there's a specific model you'd like to use, you can cache it in advance using the following Python commands:
|
103 |
+
```bash
|
104 |
+
python
|
105 |
+
```
|
106 |
+
```python
|
107 |
+
from diffusers import StableDiffusionInpaintPipeline
|
108 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained("Uminosachi/dreamshaper_5-inpainting")
|
109 |
+
exit()
|
110 |
+
```
|
111 |
+
* The model diffusers downloaded is typically stored in your home directory. You can find it at `/home/username/.cache/huggingface/hub` for Linux and MacOS users, or at `C:\Users\username\.cache\huggingface\hub` for Windows users.
|
112 |
+
* When executing inpainting, if the following error is output to the console, try deleting the corresponding model from the cache folder mentioned above:
|
113 |
+
```
|
114 |
+
An error occurred while trying to fetch model name...
|
115 |
+
```
|
116 |
+
|
117 |
+
### Cleaner Tab
|
118 |
+
|
119 |
+
* Choose the Cleaner Model ID.
|
120 |
+
* Click on the `Run Cleaner` button (**Please note that it may take some time to download the model for the first time**).
|
121 |
+
* Cleaner process is performed using [Lama Cleaner](https://github.com/Sanster/lama-cleaner).
|
122 |
+
|
123 |
+
### Mask only Tab
|
124 |
+
|
125 |
+
* Gives ability to just save mask without any other processing, so it's then possible to use the mask in other graphic applications.
|
126 |
+
* `Get mask as alpha of image` button: Save the mask as RGBA image, with the mask put into the alpha channel of the input image.
|
127 |
+
* `Get mask` button: Save the mask as RGB image.
|
128 |
+
|
129 |
+
![UI image](images/inpaint_anything_ui_image_1.png)
|
130 |
+
|
131 |
+
## Auto-saving images
|
132 |
+
|
133 |
+
* The inpainted image will be automatically saved in the folder that matches the current date within the `outputs` directory.
|
134 |
+
|
135 |
+
## Development
|
136 |
+
|
137 |
+
With the [Inpaint Anything library](README_DEV.md), you can perform segmentation and create masks using sketches from other applications.
|
138 |
+
|
139 |
+
## License
|
140 |
+
|
141 |
+
The source code is licensed under the [Apache 2.0 license](LICENSE).
|
142 |
+
|
143 |
+
## References
|
144 |
+
|
145 |
+
* Ravi, N., Gabeur, V., Hu, Y.-T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädel, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K. V., Carion, N., Wu, C.-Y., Girshick, R., Dollár, P., & Feichtenhofer, C. (2024). [SAM 2: Segment Anything in Images and Videos](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/). arXiv preprint.
|
146 |
+
* Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A. C., Lo, W-Y., Dollár, P., & Girshick, R. (2023). [Segment Anything](https://arxiv.org/abs/2304.02643). arXiv:2304.02643.
|
147 |
+
* Ke, L., Ye, M., Danelljan, M., Liu, Y., Tai, Y-W., Tang, C-K., & Yu, F. (2023). [Segment Anything in High Quality](https://arxiv.org/abs/2306.01567). arXiv:2306.01567.
|
148 |
+
* Zhao, X., Ding, W., An, Y., Du, Y., Yu, T., Li, M., Tang, M., & Wang, J. (2023). [Fast Segment Anything](https://arxiv.org/abs/2306.12156). arXiv:2306.12156 [cs.CV].
|
149 |
+
* Zhang, C., Han, D., Qiao, Y., Kim, J. U., Bae, S-H., Lee, S., & Hong, C. S. (2023). [Faster Segment Anything: Towards Lightweight SAM for Mobile Applications](https://arxiv.org/abs/2306.14289). arXiv:2306.14289.
|
README_DEV.md
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Usage of Inpaint Anything Library
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
|
5 |
+
The `inpalib` from the `inpaint-anything` package lets you segment images and create masks using sketches from other applications.
|
6 |
+
|
7 |
+
## Code Breakdown
|
8 |
+
|
9 |
+
### Imports and Module Initialization
|
10 |
+
|
11 |
+
```python
|
12 |
+
import importlib
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image, ImageDraw
|
16 |
+
|
17 |
+
inpalib = importlib.import_module("inpaint-anything.inpalib")
|
18 |
+
```
|
19 |
+
|
20 |
+
### Fetch Model IDs
|
21 |
+
|
22 |
+
```python
|
23 |
+
available_sam_ids = inpalib.get_available_sam_ids()
|
24 |
+
|
25 |
+
use_sam_id = "sam_hq_vit_l.pth"
|
26 |
+
# assert use_sam_id in available_sam_ids, f"Invalid SAM ID: {use_sam_id}"
|
27 |
+
```
|
28 |
+
|
29 |
+
Note: Only the models downloaded via the Inpaint Anything are available.
|
30 |
+
|
31 |
+
### Generate Segments Image
|
32 |
+
|
33 |
+
```python
|
34 |
+
input_image = np.array(Image.open("/path/to/image.png"))
|
35 |
+
|
36 |
+
sam_masks = inpalib.generate_sam_masks(input_image, use_sam_id, anime_style_chk=False)
|
37 |
+
sam_masks = inpalib.sort_masks_by_area(sam_masks)
|
38 |
+
|
39 |
+
seg_color_image = inpalib.create_seg_color_image(input_image, sam_masks)
|
40 |
+
|
41 |
+
Image.fromarray(seg_color_image).save("/path/to/seg_color_image.png")
|
42 |
+
```
|
43 |
+
|
44 |
+
<img src="images/sample_input_image.png" alt="drawing" width="256"/> <img src="images/sample_seg_color_image.png" alt="drawing" width="256"/>
|
45 |
+
|
46 |
+
### Create Mask from Sketch
|
47 |
+
|
48 |
+
```python
|
49 |
+
sketch_image = Image.fromarray(np.zeros_like(input_image))
|
50 |
+
|
51 |
+
draw = ImageDraw.Draw(sketch_image)
|
52 |
+
draw.point((input_image.shape[1] // 2, input_image.shape[0] // 2), fill=(255, 255, 255))
|
53 |
+
|
54 |
+
mask_image = inpalib.create_mask_image(np.array(sketch_image), sam_masks, ignore_black_chk=True)
|
55 |
+
|
56 |
+
Image.fromarray(mask_image).save("/path/to/mask_image.png")
|
57 |
+
```
|
58 |
+
|
59 |
+
<img src="images/sample_mask_image.png" alt="drawing" width="256"/>
|
60 |
+
|
61 |
+
Note: Ensure you adjust the file paths before executing the code.
|
fast_sam/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .fast_sam_wrapper import FastSAM
|
2 |
+
from .fast_sam_wrapper import FastSamAutomaticMaskGenerator
|
3 |
+
|
4 |
+
fast_sam_model_registry = {
|
5 |
+
"FastSAM-x": FastSAM,
|
6 |
+
"FastSAM-s": FastSAM,
|
7 |
+
}
|
8 |
+
|
9 |
+
__all__ = ["FastSAM", "FastSamAutomaticMaskGenerator", "fast_sam_model_registry"]
|
fast_sam/fast_sam_wrapper.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import math
|
3 |
+
from typing import Any, Dict, List
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import ultralytics
|
9 |
+
|
10 |
+
if hasattr(ultralytics, "FastSAM"):
|
11 |
+
from ultralytics import FastSAM as YOLO
|
12 |
+
else:
|
13 |
+
from ultralytics import YOLO
|
14 |
+
|
15 |
+
|
16 |
+
class FastSAM:
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
checkpoint: str,
|
20 |
+
) -> None:
|
21 |
+
self.model_path = checkpoint
|
22 |
+
self.model = YOLO(self.model_path)
|
23 |
+
|
24 |
+
if not hasattr(torch.nn.Upsample, "recompute_scale_factor"):
|
25 |
+
torch.nn.Upsample.recompute_scale_factor = None
|
26 |
+
|
27 |
+
def to(self, device) -> None:
|
28 |
+
self.model.to(device)
|
29 |
+
|
30 |
+
@property
|
31 |
+
def device(self) -> Any:
|
32 |
+
return self.model.device
|
33 |
+
|
34 |
+
def __call__(self, source=None, stream=False, **kwargs) -> Any:
|
35 |
+
return self.model(source=source, stream=stream, **kwargs)
|
36 |
+
|
37 |
+
|
38 |
+
class FastSamAutomaticMaskGenerator:
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
model: FastSAM,
|
42 |
+
points_per_batch: int = None,
|
43 |
+
pred_iou_thresh: float = None,
|
44 |
+
stability_score_thresh: float = None,
|
45 |
+
) -> None:
|
46 |
+
self.model = model
|
47 |
+
self.points_per_batch = points_per_batch
|
48 |
+
self.pred_iou_thresh = pred_iou_thresh
|
49 |
+
self.stability_score_thresh = stability_score_thresh
|
50 |
+
self.conf = 0.25 if stability_score_thresh >= 0.95 else 0.15
|
51 |
+
|
52 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
53 |
+
height, width = image.shape[:2]
|
54 |
+
new_height = math.ceil(height / 32) * 32
|
55 |
+
new_width = math.ceil(width / 32) * 32
|
56 |
+
resize_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
|
57 |
+
|
58 |
+
backup_nn_dict = {}
|
59 |
+
for key, _ in torch.nn.__dict__.copy().items():
|
60 |
+
if not inspect.isclass(torch.nn.__dict__.get(key)) and "Norm" in key:
|
61 |
+
backup_nn_dict[key] = torch.nn.__dict__.pop(key)
|
62 |
+
|
63 |
+
results = self.model(
|
64 |
+
source=resize_image,
|
65 |
+
stream=False,
|
66 |
+
imgsz=max(new_height, new_width),
|
67 |
+
device=self.model.device,
|
68 |
+
retina_masks=True,
|
69 |
+
iou=0.7,
|
70 |
+
conf=self.conf,
|
71 |
+
max_det=256)
|
72 |
+
|
73 |
+
for key, value in backup_nn_dict.items():
|
74 |
+
setattr(torch.nn, key, value)
|
75 |
+
# assert backup_nn_dict[key] == torch.nn.__dict__[key]
|
76 |
+
|
77 |
+
annotations = results[0].masks.data
|
78 |
+
|
79 |
+
if isinstance(annotations[0], torch.Tensor):
|
80 |
+
annotations = np.array(annotations.cpu())
|
81 |
+
|
82 |
+
annotations_list = []
|
83 |
+
for mask in annotations:
|
84 |
+
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
85 |
+
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((7, 7), np.uint8))
|
86 |
+
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA)
|
87 |
+
|
88 |
+
annotations_list.append(dict(segmentation=mask.astype(bool)))
|
89 |
+
|
90 |
+
return annotations_list
|
ia_check_versions.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cached_property
|
2 |
+
from importlib.metadata import version
|
3 |
+
from importlib.util import find_spec
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from packaging.version import parse
|
7 |
+
|
8 |
+
|
9 |
+
def get_module_version(module_name):
|
10 |
+
try:
|
11 |
+
module_version = version(module_name)
|
12 |
+
except Exception:
|
13 |
+
module_version = None
|
14 |
+
return module_version
|
15 |
+
|
16 |
+
|
17 |
+
def compare_version(version1, version2):
|
18 |
+
if not isinstance(version1, str) or not isinstance(version2, str):
|
19 |
+
return None
|
20 |
+
|
21 |
+
if parse(version1) > parse(version2):
|
22 |
+
return 1
|
23 |
+
elif parse(version1) < parse(version2):
|
24 |
+
return -1
|
25 |
+
else:
|
26 |
+
return 0
|
27 |
+
|
28 |
+
|
29 |
+
def compare_module_version(module_name, version_string):
|
30 |
+
module_version = get_module_version(module_name)
|
31 |
+
|
32 |
+
result = compare_version(module_version, version_string)
|
33 |
+
return result if result is not None else -2
|
34 |
+
|
35 |
+
|
36 |
+
class IACheckVersions:
|
37 |
+
@cached_property
|
38 |
+
def diffusers_enable_cpu_offload(self):
|
39 |
+
if (find_spec("diffusers") is not None and compare_module_version("diffusers", "0.15.0") >= 0 and
|
40 |
+
find_spec("accelerate") is not None and compare_module_version("accelerate", "0.17.0") >= 0 and
|
41 |
+
torch.cuda.is_available()):
|
42 |
+
return True
|
43 |
+
else:
|
44 |
+
return False
|
45 |
+
|
46 |
+
@cached_property
|
47 |
+
def torch_mps_is_available(self):
|
48 |
+
if compare_module_version("torch", "2.0.1") < 0:
|
49 |
+
if not getattr(torch, "has_mps", False):
|
50 |
+
return False
|
51 |
+
try:
|
52 |
+
torch.zeros(1).to(torch.device("mps"))
|
53 |
+
return True
|
54 |
+
except Exception:
|
55 |
+
return False
|
56 |
+
else:
|
57 |
+
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
58 |
+
|
59 |
+
@cached_property
|
60 |
+
def torch_on_amd_rocm(self):
|
61 |
+
if find_spec("torch") is not None and "rocm" in version("torch"):
|
62 |
+
return True
|
63 |
+
else:
|
64 |
+
return False
|
65 |
+
|
66 |
+
@cached_property
|
67 |
+
def gradio_version_is_old(self):
|
68 |
+
if find_spec("gradio") is not None and compare_module_version("gradio", "3.34.0") <= 0:
|
69 |
+
return True
|
70 |
+
else:
|
71 |
+
return False
|
72 |
+
|
73 |
+
|
74 |
+
ia_check_versions = IACheckVersions()
|
ia_config.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import configparser
|
2 |
+
# import json
|
3 |
+
import os
|
4 |
+
from types import SimpleNamespace
|
5 |
+
|
6 |
+
from ia_ui_items import get_inp_model_ids, get_sam_model_ids
|
7 |
+
|
8 |
+
|
9 |
+
class IAConfig:
|
10 |
+
SECTIONS = SimpleNamespace(
|
11 |
+
DEFAULT=configparser.DEFAULTSECT,
|
12 |
+
USER="USER",
|
13 |
+
)
|
14 |
+
|
15 |
+
KEYS = SimpleNamespace(
|
16 |
+
SAM_MODEL_ID="sam_model_id",
|
17 |
+
INP_MODEL_ID="inp_model_id",
|
18 |
+
)
|
19 |
+
|
20 |
+
PATHS = SimpleNamespace(
|
21 |
+
INI=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ia_config.ini"),
|
22 |
+
)
|
23 |
+
|
24 |
+
global_args = {}
|
25 |
+
|
26 |
+
def __init__(self):
|
27 |
+
self.ids_dict = {}
|
28 |
+
self.ids_dict[IAConfig.KEYS.SAM_MODEL_ID] = {
|
29 |
+
"list": get_sam_model_ids(),
|
30 |
+
"index": 1,
|
31 |
+
}
|
32 |
+
self.ids_dict[IAConfig.KEYS.INP_MODEL_ID] = {
|
33 |
+
"list": get_inp_model_ids(),
|
34 |
+
"index": 0,
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
ia_config = IAConfig()
|
39 |
+
|
40 |
+
|
41 |
+
def setup_ia_config_ini():
|
42 |
+
ia_config_ini = configparser.ConfigParser(defaults={})
|
43 |
+
if os.path.isfile(IAConfig.PATHS.INI):
|
44 |
+
ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
|
45 |
+
|
46 |
+
changed = False
|
47 |
+
for key, ids_info in ia_config.ids_dict.items():
|
48 |
+
if not ia_config_ini.has_option(IAConfig.SECTIONS.DEFAULT, key):
|
49 |
+
if len(ids_info["list"]) > ids_info["index"]:
|
50 |
+
ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
|
51 |
+
changed = True
|
52 |
+
else:
|
53 |
+
if len(ids_info["list"]) > ids_info["index"] and ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] != ids_info["list"][ids_info["index"]]:
|
54 |
+
ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
|
55 |
+
changed = True
|
56 |
+
|
57 |
+
if changed:
|
58 |
+
with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
|
59 |
+
ia_config_ini.write(f)
|
60 |
+
|
61 |
+
|
62 |
+
def get_ia_config(key, section=IAConfig.SECTIONS.DEFAULT):
|
63 |
+
setup_ia_config_ini()
|
64 |
+
|
65 |
+
ia_config_ini = configparser.ConfigParser(defaults={})
|
66 |
+
ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
|
67 |
+
|
68 |
+
if ia_config_ini.has_option(section, key):
|
69 |
+
return ia_config_ini[section][key]
|
70 |
+
|
71 |
+
section = IAConfig.SECTIONS.DEFAULT
|
72 |
+
if ia_config_ini.has_option(section, key):
|
73 |
+
return ia_config_ini[section][key]
|
74 |
+
|
75 |
+
return None
|
76 |
+
|
77 |
+
|
78 |
+
def get_ia_config_index(key, section=IAConfig.SECTIONS.DEFAULT):
|
79 |
+
value = get_ia_config(key, section)
|
80 |
+
|
81 |
+
ids_dict = ia_config.ids_dict
|
82 |
+
if value is None:
|
83 |
+
if key in ids_dict.keys():
|
84 |
+
ids_info = ids_dict[key]
|
85 |
+
return ids_info["index"]
|
86 |
+
else:
|
87 |
+
return 0
|
88 |
+
else:
|
89 |
+
if key in ids_dict.keys():
|
90 |
+
ids_info = ids_dict[key]
|
91 |
+
return ids_info["list"].index(value) if value in ids_info["list"] else ids_info["index"]
|
92 |
+
else:
|
93 |
+
return 0
|
94 |
+
|
95 |
+
|
96 |
+
def set_ia_config(key, value, section=IAConfig.SECTIONS.DEFAULT):
|
97 |
+
setup_ia_config_ini()
|
98 |
+
|
99 |
+
ia_config_ini = configparser.ConfigParser(defaults={})
|
100 |
+
ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
|
101 |
+
|
102 |
+
if ia_config_ini.has_option(section, key) and ia_config_ini[section][key] == value:
|
103 |
+
return
|
104 |
+
|
105 |
+
if section != IAConfig.SECTIONS.DEFAULT and not ia_config_ini.has_section(section):
|
106 |
+
ia_config_ini[section] = {}
|
107 |
+
|
108 |
+
try:
|
109 |
+
ia_config_ini[section][key] = value
|
110 |
+
except Exception:
|
111 |
+
ia_config_ini[section] = {}
|
112 |
+
ia_config_ini[section][key] = value
|
113 |
+
|
114 |
+
with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
|
115 |
+
ia_config_ini.write(f)
|
ia_devices.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class TorchDevices:
|
5 |
+
def __init__(self):
|
6 |
+
self.cpu = torch.device("cpu")
|
7 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else self.cpu
|
8 |
+
|
9 |
+
|
10 |
+
devices = TorchDevices()
|
ia_file_manager.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
from ia_logging import ia_logging
|
5 |
+
|
6 |
+
|
7 |
+
class IAFileManager:
|
8 |
+
DOWNLOAD_COMPLETE = "Download complete"
|
9 |
+
|
10 |
+
def __init__(self) -> None:
|
11 |
+
self._ia_outputs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
12 |
+
"outputs",
|
13 |
+
datetime.now().strftime("%Y-%m-%d"))
|
14 |
+
|
15 |
+
self._ia_models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
16 |
+
|
17 |
+
@property
|
18 |
+
def outputs_dir(self) -> str:
|
19 |
+
"""Get inpaint-anything outputs directory.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
str: inpaint-anything outputs directory
|
23 |
+
"""
|
24 |
+
if not os.path.isdir(self._ia_outputs_dir):
|
25 |
+
os.makedirs(self._ia_outputs_dir, exist_ok=True)
|
26 |
+
return self._ia_outputs_dir
|
27 |
+
|
28 |
+
@property
|
29 |
+
def models_dir(self) -> str:
|
30 |
+
"""Get inpaint-anything models directory.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: inpaint-anything models directory
|
34 |
+
"""
|
35 |
+
if not os.path.isdir(self._ia_models_dir):
|
36 |
+
os.makedirs(self._ia_models_dir, exist_ok=True)
|
37 |
+
return self._ia_models_dir
|
38 |
+
|
39 |
+
@property
|
40 |
+
def savename_prefix(self) -> str:
|
41 |
+
"""Get inpaint-anything savename prefix.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
str: inpaint-anything savename prefix
|
45 |
+
"""
|
46 |
+
return datetime.now().strftime("%Y%m%d-%H%M%S")
|
47 |
+
|
48 |
+
|
49 |
+
ia_file_manager = IAFileManager()
|
50 |
+
|
51 |
+
|
52 |
+
def download_model_from_hf(hf_model_id, local_files_only=False):
|
53 |
+
"""Download model from HuggingFace Hub.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sam_model_id (str): HuggingFace model id
|
57 |
+
local_files_only (bool, optional): If True, use only local files. Defaults to False.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
str: download status
|
61 |
+
"""
|
62 |
+
if not local_files_only:
|
63 |
+
ia_logging.info(f"Downloading {hf_model_id}")
|
64 |
+
try:
|
65 |
+
snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only)
|
66 |
+
except FileNotFoundError:
|
67 |
+
return f"{hf_model_id} not found, please download"
|
68 |
+
except Exception as e:
|
69 |
+
return str(e)
|
70 |
+
|
71 |
+
return IAFileManager.DOWNLOAD_COMPLETE
|
ia_get_dataset_colormap.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Lint as: python2, python3
|
2 |
+
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
"""Visualizes the segmentation results via specified color map.
|
17 |
+
|
18 |
+
Visualizes the semantic segmentation results by the color map
|
19 |
+
defined by the different datasets. Supported colormaps are:
|
20 |
+
|
21 |
+
* ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/).
|
22 |
+
|
23 |
+
* Cityscapes dataset (https://www.cityscapes-dataset.com).
|
24 |
+
|
25 |
+
* Mapillary Vistas (https://research.mapillary.com).
|
26 |
+
|
27 |
+
* PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
|
28 |
+
"""
|
29 |
+
|
30 |
+
from __future__ import absolute_import, division, print_function
|
31 |
+
|
32 |
+
import numpy as np
|
33 |
+
|
34 |
+
# from six.moves import range
|
35 |
+
|
36 |
+
# Dataset names.
|
37 |
+
_ADE20K = 'ade20k'
|
38 |
+
_CITYSCAPES = 'cityscapes'
|
39 |
+
_MAPILLARY_VISTAS = 'mapillary_vistas'
|
40 |
+
_PASCAL = 'pascal'
|
41 |
+
|
42 |
+
# Max number of entries in the colormap for each dataset.
|
43 |
+
_DATASET_MAX_ENTRIES = {
|
44 |
+
_ADE20K: 151,
|
45 |
+
_CITYSCAPES: 256,
|
46 |
+
_MAPILLARY_VISTAS: 66,
|
47 |
+
_PASCAL: 512,
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
def create_ade20k_label_colormap():
|
52 |
+
"""Creates a label colormap used in ADE20K segmentation benchmark.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
A colormap for visualizing segmentation results.
|
56 |
+
"""
|
57 |
+
return np.asarray([
|
58 |
+
[0, 0, 0],
|
59 |
+
[120, 120, 120],
|
60 |
+
[180, 120, 120],
|
61 |
+
[6, 230, 230],
|
62 |
+
[80, 50, 50],
|
63 |
+
[4, 200, 3],
|
64 |
+
[120, 120, 80],
|
65 |
+
[140, 140, 140],
|
66 |
+
[204, 5, 255],
|
67 |
+
[230, 230, 230],
|
68 |
+
[4, 250, 7],
|
69 |
+
[224, 5, 255],
|
70 |
+
[235, 255, 7],
|
71 |
+
[150, 5, 61],
|
72 |
+
[120, 120, 70],
|
73 |
+
[8, 255, 51],
|
74 |
+
[255, 6, 82],
|
75 |
+
[143, 255, 140],
|
76 |
+
[204, 255, 4],
|
77 |
+
[255, 51, 7],
|
78 |
+
[204, 70, 3],
|
79 |
+
[0, 102, 200],
|
80 |
+
[61, 230, 250],
|
81 |
+
[255, 6, 51],
|
82 |
+
[11, 102, 255],
|
83 |
+
[255, 7, 71],
|
84 |
+
[255, 9, 224],
|
85 |
+
[9, 7, 230],
|
86 |
+
[220, 220, 220],
|
87 |
+
[255, 9, 92],
|
88 |
+
[112, 9, 255],
|
89 |
+
[8, 255, 214],
|
90 |
+
[7, 255, 224],
|
91 |
+
[255, 184, 6],
|
92 |
+
[10, 255, 71],
|
93 |
+
[255, 41, 10],
|
94 |
+
[7, 255, 255],
|
95 |
+
[224, 255, 8],
|
96 |
+
[102, 8, 255],
|
97 |
+
[255, 61, 6],
|
98 |
+
[255, 194, 7],
|
99 |
+
[255, 122, 8],
|
100 |
+
[0, 255, 20],
|
101 |
+
[255, 8, 41],
|
102 |
+
[255, 5, 153],
|
103 |
+
[6, 51, 255],
|
104 |
+
[235, 12, 255],
|
105 |
+
[160, 150, 20],
|
106 |
+
[0, 163, 255],
|
107 |
+
[140, 140, 140],
|
108 |
+
[250, 10, 15],
|
109 |
+
[20, 255, 0],
|
110 |
+
[31, 255, 0],
|
111 |
+
[255, 31, 0],
|
112 |
+
[255, 224, 0],
|
113 |
+
[153, 255, 0],
|
114 |
+
[0, 0, 255],
|
115 |
+
[255, 71, 0],
|
116 |
+
[0, 235, 255],
|
117 |
+
[0, 173, 255],
|
118 |
+
[31, 0, 255],
|
119 |
+
[11, 200, 200],
|
120 |
+
[255, 82, 0],
|
121 |
+
[0, 255, 245],
|
122 |
+
[0, 61, 255],
|
123 |
+
[0, 255, 112],
|
124 |
+
[0, 255, 133],
|
125 |
+
[255, 0, 0],
|
126 |
+
[255, 163, 0],
|
127 |
+
[255, 102, 0],
|
128 |
+
[194, 255, 0],
|
129 |
+
[0, 143, 255],
|
130 |
+
[51, 255, 0],
|
131 |
+
[0, 82, 255],
|
132 |
+
[0, 255, 41],
|
133 |
+
[0, 255, 173],
|
134 |
+
[10, 0, 255],
|
135 |
+
[173, 255, 0],
|
136 |
+
[0, 255, 153],
|
137 |
+
[255, 92, 0],
|
138 |
+
[255, 0, 255],
|
139 |
+
[255, 0, 245],
|
140 |
+
[255, 0, 102],
|
141 |
+
[255, 173, 0],
|
142 |
+
[255, 0, 20],
|
143 |
+
[255, 184, 184],
|
144 |
+
[0, 31, 255],
|
145 |
+
[0, 255, 61],
|
146 |
+
[0, 71, 255],
|
147 |
+
[255, 0, 204],
|
148 |
+
[0, 255, 194],
|
149 |
+
[0, 255, 82],
|
150 |
+
[0, 10, 255],
|
151 |
+
[0, 112, 255],
|
152 |
+
[51, 0, 255],
|
153 |
+
[0, 194, 255],
|
154 |
+
[0, 122, 255],
|
155 |
+
[0, 255, 163],
|
156 |
+
[255, 153, 0],
|
157 |
+
[0, 255, 10],
|
158 |
+
[255, 112, 0],
|
159 |
+
[143, 255, 0],
|
160 |
+
[82, 0, 255],
|
161 |
+
[163, 255, 0],
|
162 |
+
[255, 235, 0],
|
163 |
+
[8, 184, 170],
|
164 |
+
[133, 0, 255],
|
165 |
+
[0, 255, 92],
|
166 |
+
[184, 0, 255],
|
167 |
+
[255, 0, 31],
|
168 |
+
[0, 184, 255],
|
169 |
+
[0, 214, 255],
|
170 |
+
[255, 0, 112],
|
171 |
+
[92, 255, 0],
|
172 |
+
[0, 224, 255],
|
173 |
+
[112, 224, 255],
|
174 |
+
[70, 184, 160],
|
175 |
+
[163, 0, 255],
|
176 |
+
[153, 0, 255],
|
177 |
+
[71, 255, 0],
|
178 |
+
[255, 0, 163],
|
179 |
+
[255, 204, 0],
|
180 |
+
[255, 0, 143],
|
181 |
+
[0, 255, 235],
|
182 |
+
[133, 255, 0],
|
183 |
+
[255, 0, 235],
|
184 |
+
[245, 0, 255],
|
185 |
+
[255, 0, 122],
|
186 |
+
[255, 245, 0],
|
187 |
+
[10, 190, 212],
|
188 |
+
[214, 255, 0],
|
189 |
+
[0, 204, 255],
|
190 |
+
[20, 0, 255],
|
191 |
+
[255, 255, 0],
|
192 |
+
[0, 153, 255],
|
193 |
+
[0, 41, 255],
|
194 |
+
[0, 255, 204],
|
195 |
+
[41, 0, 255],
|
196 |
+
[41, 255, 0],
|
197 |
+
[173, 0, 255],
|
198 |
+
[0, 245, 255],
|
199 |
+
[71, 0, 255],
|
200 |
+
[122, 0, 255],
|
201 |
+
[0, 255, 184],
|
202 |
+
[0, 92, 255],
|
203 |
+
[184, 255, 0],
|
204 |
+
[0, 133, 255],
|
205 |
+
[255, 214, 0],
|
206 |
+
[25, 194, 194],
|
207 |
+
[102, 255, 0],
|
208 |
+
[92, 0, 255],
|
209 |
+
])
|
210 |
+
|
211 |
+
|
212 |
+
def create_cityscapes_label_colormap():
|
213 |
+
"""Creates a label colormap used in CITYSCAPES segmentation benchmark.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
A colormap for visualizing segmentation results.
|
217 |
+
"""
|
218 |
+
colormap = np.zeros((256, 3), dtype=np.uint8)
|
219 |
+
colormap[0] = [128, 64, 128]
|
220 |
+
colormap[1] = [244, 35, 232]
|
221 |
+
colormap[2] = [70, 70, 70]
|
222 |
+
colormap[3] = [102, 102, 156]
|
223 |
+
colormap[4] = [190, 153, 153]
|
224 |
+
colormap[5] = [153, 153, 153]
|
225 |
+
colormap[6] = [250, 170, 30]
|
226 |
+
colormap[7] = [220, 220, 0]
|
227 |
+
colormap[8] = [107, 142, 35]
|
228 |
+
colormap[9] = [152, 251, 152]
|
229 |
+
colormap[10] = [70, 130, 180]
|
230 |
+
colormap[11] = [220, 20, 60]
|
231 |
+
colormap[12] = [255, 0, 0]
|
232 |
+
colormap[13] = [0, 0, 142]
|
233 |
+
colormap[14] = [0, 0, 70]
|
234 |
+
colormap[15] = [0, 60, 100]
|
235 |
+
colormap[16] = [0, 80, 100]
|
236 |
+
colormap[17] = [0, 0, 230]
|
237 |
+
colormap[18] = [119, 11, 32]
|
238 |
+
return colormap
|
239 |
+
|
240 |
+
|
241 |
+
def create_mapillary_vistas_label_colormap():
|
242 |
+
"""Creates a label colormap used in Mapillary Vistas segmentation benchmark.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
A colormap for visualizing segmentation results.
|
246 |
+
"""
|
247 |
+
return np.asarray([
|
248 |
+
[165, 42, 42],
|
249 |
+
[0, 192, 0],
|
250 |
+
[196, 196, 196],
|
251 |
+
[190, 153, 153],
|
252 |
+
[180, 165, 180],
|
253 |
+
[102, 102, 156],
|
254 |
+
[102, 102, 156],
|
255 |
+
[128, 64, 255],
|
256 |
+
[140, 140, 200],
|
257 |
+
[170, 170, 170],
|
258 |
+
[250, 170, 160],
|
259 |
+
[96, 96, 96],
|
260 |
+
[230, 150, 140],
|
261 |
+
[128, 64, 128],
|
262 |
+
[110, 110, 110],
|
263 |
+
[244, 35, 232],
|
264 |
+
[150, 100, 100],
|
265 |
+
[70, 70, 70],
|
266 |
+
[150, 120, 90],
|
267 |
+
[220, 20, 60],
|
268 |
+
[255, 0, 0],
|
269 |
+
[255, 0, 0],
|
270 |
+
[255, 0, 0],
|
271 |
+
[200, 128, 128],
|
272 |
+
[255, 255, 255],
|
273 |
+
[64, 170, 64],
|
274 |
+
[128, 64, 64],
|
275 |
+
[70, 130, 180],
|
276 |
+
[255, 255, 255],
|
277 |
+
[152, 251, 152],
|
278 |
+
[107, 142, 35],
|
279 |
+
[0, 170, 30],
|
280 |
+
[255, 255, 128],
|
281 |
+
[250, 0, 30],
|
282 |
+
[0, 0, 0],
|
283 |
+
[220, 220, 220],
|
284 |
+
[170, 170, 170],
|
285 |
+
[222, 40, 40],
|
286 |
+
[100, 170, 30],
|
287 |
+
[40, 40, 40],
|
288 |
+
[33, 33, 33],
|
289 |
+
[170, 170, 170],
|
290 |
+
[0, 0, 142],
|
291 |
+
[170, 170, 170],
|
292 |
+
[210, 170, 100],
|
293 |
+
[153, 153, 153],
|
294 |
+
[128, 128, 128],
|
295 |
+
[0, 0, 142],
|
296 |
+
[250, 170, 30],
|
297 |
+
[192, 192, 192],
|
298 |
+
[220, 220, 0],
|
299 |
+
[180, 165, 180],
|
300 |
+
[119, 11, 32],
|
301 |
+
[0, 0, 142],
|
302 |
+
[0, 60, 100],
|
303 |
+
[0, 0, 142],
|
304 |
+
[0, 0, 90],
|
305 |
+
[0, 0, 230],
|
306 |
+
[0, 80, 100],
|
307 |
+
[128, 64, 64],
|
308 |
+
[0, 0, 110],
|
309 |
+
[0, 0, 70],
|
310 |
+
[0, 0, 192],
|
311 |
+
[32, 32, 32],
|
312 |
+
[0, 0, 0],
|
313 |
+
[0, 0, 0],
|
314 |
+
])
|
315 |
+
|
316 |
+
|
317 |
+
def create_pascal_label_colormap():
|
318 |
+
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
A colormap for visualizing segmentation results.
|
322 |
+
"""
|
323 |
+
colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
|
324 |
+
ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
|
325 |
+
|
326 |
+
for shift in reversed(list(range(8))):
|
327 |
+
for channel in range(3):
|
328 |
+
colormap[:, channel] |= bit_get(ind, channel) << shift
|
329 |
+
ind >>= 3
|
330 |
+
|
331 |
+
return colormap
|
332 |
+
|
333 |
+
|
334 |
+
def get_ade20k_name():
|
335 |
+
return _ADE20K
|
336 |
+
|
337 |
+
|
338 |
+
def get_cityscapes_name():
|
339 |
+
return _CITYSCAPES
|
340 |
+
|
341 |
+
|
342 |
+
def get_mapillary_vistas_name():
|
343 |
+
return _MAPILLARY_VISTAS
|
344 |
+
|
345 |
+
|
346 |
+
def get_pascal_name():
|
347 |
+
return _PASCAL
|
348 |
+
|
349 |
+
|
350 |
+
def bit_get(val, idx):
|
351 |
+
"""Gets the bit value.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
val: Input value, int or numpy int array.
|
355 |
+
idx: Which bit of the input val.
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
The "idx"-th bit of input val.
|
359 |
+
"""
|
360 |
+
return (val >> idx) & 1
|
361 |
+
|
362 |
+
|
363 |
+
def create_label_colormap(dataset=_PASCAL):
|
364 |
+
"""Creates a label colormap for the specified dataset.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
dataset: The colormap used in the dataset.
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
A numpy array of the dataset colormap.
|
371 |
+
|
372 |
+
Raises:
|
373 |
+
ValueError: If the dataset is not supported.
|
374 |
+
"""
|
375 |
+
if dataset == _ADE20K:
|
376 |
+
return create_ade20k_label_colormap()
|
377 |
+
elif dataset == _CITYSCAPES:
|
378 |
+
return create_cityscapes_label_colormap()
|
379 |
+
elif dataset == _MAPILLARY_VISTAS:
|
380 |
+
return create_mapillary_vistas_label_colormap()
|
381 |
+
elif dataset == _PASCAL:
|
382 |
+
return create_pascal_label_colormap()
|
383 |
+
else:
|
384 |
+
raise ValueError('Unsupported dataset.')
|
385 |
+
|
386 |
+
|
387 |
+
def label_to_color_image(label, dataset=_PASCAL):
|
388 |
+
"""Adds color defined by the dataset colormap to the label.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
label: A 2D array with integer type, storing the segmentation label.
|
392 |
+
dataset: The colormap used in the dataset.
|
393 |
+
|
394 |
+
Returns:
|
395 |
+
result: A 2D array with floating type. The element of the array
|
396 |
+
is the color indexed by the corresponding element in the input label
|
397 |
+
to the dataset color map.
|
398 |
+
|
399 |
+
Raises:
|
400 |
+
ValueError: If label is not of rank 2 or its value is larger than color
|
401 |
+
map maximum entry.
|
402 |
+
"""
|
403 |
+
if label.ndim != 2:
|
404 |
+
raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
|
405 |
+
|
406 |
+
if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
|
407 |
+
raise ValueError(
|
408 |
+
'label value too large: {} >= {}.'.format(
|
409 |
+
np.max(label), _DATASET_MAX_ENTRIES[dataset]))
|
410 |
+
|
411 |
+
colormap = create_label_colormap(dataset)
|
412 |
+
return colormap[label]
|
413 |
+
|
414 |
+
|
415 |
+
def get_dataset_colormap_max_entries(dataset):
|
416 |
+
return _DATASET_MAX_ENTRIES[dataset]
|
ia_logging.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
5 |
+
warnings.filterwarnings(action="ignore", category=FutureWarning, module="huggingface_hub")
|
6 |
+
|
7 |
+
ia_logging = logging.getLogger("Inpaint Anything")
|
8 |
+
ia_logging.setLevel(logging.INFO)
|
9 |
+
ia_logging.propagate = False
|
10 |
+
|
11 |
+
ia_logging_sh = logging.StreamHandler()
|
12 |
+
ia_logging_sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
13 |
+
ia_logging_sh.setLevel(logging.INFO)
|
14 |
+
ia_logging.addHandler(ia_logging_sh)
|
ia_sam_manager.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
|
8 |
+
from ia_check_versions import ia_check_versions
|
9 |
+
from ia_config import IAConfig
|
10 |
+
from ia_devices import devices
|
11 |
+
from ia_logging import ia_logging
|
12 |
+
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
|
13 |
+
from mobile_sam import SamPredictor as SamPredictorMobile
|
14 |
+
from mobile_sam import sam_model_registry as sam_model_registry_mobile
|
15 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
16 |
+
from sam2.build_sam import build_sam2
|
17 |
+
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
18 |
+
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
|
19 |
+
from segment_anything_hq import SamPredictor as SamPredictorHQ
|
20 |
+
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
|
21 |
+
|
22 |
+
|
23 |
+
def check_bfloat16_support() -> bool:
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
compute_capability = torch.cuda.get_device_capability(torch.cuda.current_device())
|
26 |
+
if compute_capability[0] >= 8:
|
27 |
+
ia_logging.debug("The CUDA device supports bfloat16")
|
28 |
+
return True
|
29 |
+
else:
|
30 |
+
ia_logging.debug("The CUDA device does not support bfloat16")
|
31 |
+
return False
|
32 |
+
else:
|
33 |
+
ia_logging.debug("CUDA is not available")
|
34 |
+
return False
|
35 |
+
|
36 |
+
|
37 |
+
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
|
38 |
+
def wrapper(*args, **kwargs):
|
39 |
+
updated_kwargs = {**fixed_kwargs, **kwargs}
|
40 |
+
return func(*args, *fixed_args, **updated_kwargs)
|
41 |
+
return wrapper
|
42 |
+
|
43 |
+
|
44 |
+
def rename_args(func, arg_map):
|
45 |
+
def wrapper(*args, **kwargs):
|
46 |
+
new_kwargs = {arg_map.get(k, k): v for k, v in kwargs.items()}
|
47 |
+
return func(*args, **new_kwargs)
|
48 |
+
return wrapper
|
49 |
+
|
50 |
+
|
51 |
+
arg_map = {"checkpoint": "ckpt_path"}
|
52 |
+
rename_build_sam2 = rename_args(build_sam2, arg_map)
|
53 |
+
end_kwargs = dict(device="cpu", mode="eval", hydra_overrides_extra=[], apply_postprocessing=False)
|
54 |
+
sam2_model_registry = {
|
55 |
+
"sam2_hiera_large": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_l.yaml"),
|
56 |
+
"sam2_hiera_base_plus": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_b+.yaml"),
|
57 |
+
"sam2_hiera_small": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_s.yaml"),
|
58 |
+
"sam2_hiera_tiny": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_t.yaml"),
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
|
63 |
+
"""Get SAM mask generator.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
sam_checkpoint (str): SAM checkpoint path
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
SamAutomaticMaskGenerator or None: SAM mask generator
|
70 |
+
"""
|
71 |
+
points_per_batch = 64
|
72 |
+
if "_hq_" in os.path.basename(sam_checkpoint):
|
73 |
+
model_type = os.path.basename(sam_checkpoint)[7:12]
|
74 |
+
sam_model_registry_local = sam_model_registry_hq
|
75 |
+
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ
|
76 |
+
points_per_batch = 32
|
77 |
+
elif "FastSAM" in os.path.basename(sam_checkpoint):
|
78 |
+
model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
|
79 |
+
sam_model_registry_local = fast_sam_model_registry
|
80 |
+
SamAutomaticMaskGeneratorLocal = FastSamAutomaticMaskGenerator
|
81 |
+
points_per_batch = None
|
82 |
+
elif "mobile_sam" in os.path.basename(sam_checkpoint):
|
83 |
+
model_type = "vit_t"
|
84 |
+
sam_model_registry_local = sam_model_registry_mobile
|
85 |
+
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorMobile
|
86 |
+
points_per_batch = 64
|
87 |
+
elif "sam2_" in os.path.basename(sam_checkpoint):
|
88 |
+
model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
|
89 |
+
sam_model_registry_local = sam2_model_registry
|
90 |
+
SamAutomaticMaskGeneratorLocal = SAM2AutomaticMaskGenerator
|
91 |
+
points_per_batch = 128
|
92 |
+
else:
|
93 |
+
model_type = os.path.basename(sam_checkpoint)[4:9]
|
94 |
+
sam_model_registry_local = sam_model_registry
|
95 |
+
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator
|
96 |
+
points_per_batch = 64
|
97 |
+
|
98 |
+
pred_iou_thresh = 0.88 if not anime_style_chk else 0.83
|
99 |
+
stability_score_thresh = 0.95 if not anime_style_chk else 0.9
|
100 |
+
|
101 |
+
if "sam2_" in model_type:
|
102 |
+
pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
|
103 |
+
stability_score_thresh = round(stability_score_thresh - 0.03, 2)
|
104 |
+
sam2_gen_kwargs = dict(
|
105 |
+
points_per_side=64,
|
106 |
+
points_per_batch=points_per_batch,
|
107 |
+
pred_iou_thresh=pred_iou_thresh,
|
108 |
+
stability_score_thresh=stability_score_thresh,
|
109 |
+
stability_score_offset=0.7,
|
110 |
+
crop_n_layers=1,
|
111 |
+
box_nms_thresh=0.7,
|
112 |
+
crop_n_points_downscale_factor=2)
|
113 |
+
if platform.system() == "Darwin":
|
114 |
+
sam2_gen_kwargs.update(dict(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1))
|
115 |
+
|
116 |
+
if os.path.isfile(sam_checkpoint):
|
117 |
+
sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
|
118 |
+
if platform.system() == "Darwin":
|
119 |
+
if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
|
120 |
+
sam.to(device=torch.device("cpu"))
|
121 |
+
else:
|
122 |
+
sam.to(device=torch.device("mps"))
|
123 |
+
else:
|
124 |
+
if IAConfig.global_args.get("sam_cpu", False):
|
125 |
+
ia_logging.info("SAM is running on CPU... (the option has been selected)")
|
126 |
+
sam.to(device=devices.cpu)
|
127 |
+
else:
|
128 |
+
sam.to(device=devices.device)
|
129 |
+
sam_gen_kwargs = dict(
|
130 |
+
model=sam, points_per_batch=points_per_batch, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
|
131 |
+
if "sam2_" in model_type:
|
132 |
+
sam_gen_kwargs.update(sam2_gen_kwargs)
|
133 |
+
sam_mask_generator = SamAutomaticMaskGeneratorLocal(**sam_gen_kwargs)
|
134 |
+
else:
|
135 |
+
sam_mask_generator = None
|
136 |
+
|
137 |
+
return sam_mask_generator
|
138 |
+
|
139 |
+
|
140 |
+
def get_sam_predictor(sam_checkpoint):
|
141 |
+
"""Get SAM predictor.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
sam_checkpoint (str): SAM checkpoint path
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
SamPredictor or None: SAM predictor
|
148 |
+
"""
|
149 |
+
# model_type = "vit_h"
|
150 |
+
if "_hq_" in os.path.basename(sam_checkpoint):
|
151 |
+
model_type = os.path.basename(sam_checkpoint)[7:12]
|
152 |
+
sam_model_registry_local = sam_model_registry_hq
|
153 |
+
SamPredictorLocal = SamPredictorHQ
|
154 |
+
elif "FastSAM" in os.path.basename(sam_checkpoint):
|
155 |
+
raise NotImplementedError("FastSAM predictor is not implemented yet.")
|
156 |
+
elif "mobile_sam" in os.path.basename(sam_checkpoint):
|
157 |
+
model_type = "vit_t"
|
158 |
+
sam_model_registry_local = sam_model_registry_mobile
|
159 |
+
SamPredictorLocal = SamPredictorMobile
|
160 |
+
else:
|
161 |
+
model_type = os.path.basename(sam_checkpoint)[4:9]
|
162 |
+
sam_model_registry_local = sam_model_registry
|
163 |
+
SamPredictorLocal = SamPredictor
|
164 |
+
|
165 |
+
if os.path.isfile(sam_checkpoint):
|
166 |
+
sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
|
167 |
+
if platform.system() == "Darwin":
|
168 |
+
if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
|
169 |
+
sam.to(device=torch.device("cpu"))
|
170 |
+
else:
|
171 |
+
sam.to(device=torch.device("mps"))
|
172 |
+
else:
|
173 |
+
if IAConfig.global_args.get("sam_cpu", False):
|
174 |
+
ia_logging.info("SAM is running on CPU... (the option has been selected)")
|
175 |
+
sam.to(device=devices.cpu)
|
176 |
+
else:
|
177 |
+
sam.to(device=devices.device)
|
178 |
+
sam_predictor = SamPredictorLocal(sam)
|
179 |
+
else:
|
180 |
+
sam_predictor = None
|
181 |
+
|
182 |
+
return sam_predictor
|
ia_threading.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import inspect
|
3 |
+
import threading
|
4 |
+
from functools import wraps
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from ia_check_versions import ia_check_versions
|
9 |
+
|
10 |
+
model_access_sem = threading.Semaphore(1)
|
11 |
+
|
12 |
+
|
13 |
+
def torch_gc():
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
torch.cuda.empty_cache()
|
16 |
+
torch.cuda.ipc_collect()
|
17 |
+
if ia_check_versions.torch_mps_is_available:
|
18 |
+
if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
|
19 |
+
torch.mps.empty_cache()
|
20 |
+
|
21 |
+
|
22 |
+
def clear_cache():
|
23 |
+
gc.collect()
|
24 |
+
torch_gc()
|
25 |
+
|
26 |
+
|
27 |
+
def post_clear_cache(sem):
|
28 |
+
with sem:
|
29 |
+
gc.collect()
|
30 |
+
torch_gc()
|
31 |
+
|
32 |
+
|
33 |
+
def async_post_clear_cache():
|
34 |
+
thread = threading.Thread(target=post_clear_cache, args=(model_access_sem,))
|
35 |
+
thread.start()
|
36 |
+
|
37 |
+
|
38 |
+
def clear_cache_decorator(func):
|
39 |
+
@wraps(func)
|
40 |
+
def yield_wrapper(*args, **kwargs):
|
41 |
+
clear_cache()
|
42 |
+
yield from func(*args, **kwargs)
|
43 |
+
clear_cache()
|
44 |
+
|
45 |
+
@wraps(func)
|
46 |
+
def wrapper(*args, **kwargs):
|
47 |
+
clear_cache()
|
48 |
+
res = func(*args, **kwargs)
|
49 |
+
clear_cache()
|
50 |
+
return res
|
51 |
+
|
52 |
+
if inspect.isgeneratorfunction(func):
|
53 |
+
return yield_wrapper
|
54 |
+
else:
|
55 |
+
return wrapper
|
ia_ui_gradio.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
6 |
+
|
7 |
+
|
8 |
+
def webpath(fn):
|
9 |
+
web_path = os.path.realpath(fn)
|
10 |
+
|
11 |
+
return f'file={web_path}?{os.path.getmtime(fn)}'
|
12 |
+
|
13 |
+
|
14 |
+
def javascript_html():
|
15 |
+
script_path = os.path.join(os.path.dirname(__file__), "javascript", "inpaint-anything.js")
|
16 |
+
head = f'<script type="text/javascript" src="{webpath(script_path)}"></script>\n'
|
17 |
+
|
18 |
+
return head
|
19 |
+
|
20 |
+
|
21 |
+
def reload_javascript():
|
22 |
+
js = javascript_html()
|
23 |
+
|
24 |
+
def template_response(*args, **kwargs):
|
25 |
+
res = GradioTemplateResponseOriginal(*args, **kwargs)
|
26 |
+
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
27 |
+
res.init_headers()
|
28 |
+
return res
|
29 |
+
|
30 |
+
gr.routes.templates.TemplateResponse = template_response
|
ia_ui_items.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import scan_cache_dir
|
2 |
+
|
3 |
+
|
4 |
+
def get_sampler_names():
|
5 |
+
"""Get sampler name list.
|
6 |
+
|
7 |
+
Returns:
|
8 |
+
list: sampler name list
|
9 |
+
"""
|
10 |
+
sampler_names = [
|
11 |
+
"DDIM",
|
12 |
+
"Euler",
|
13 |
+
"Euler a",
|
14 |
+
"DPM2 Karras",
|
15 |
+
"DPM2 a Karras",
|
16 |
+
]
|
17 |
+
return sampler_names
|
18 |
+
|
19 |
+
|
20 |
+
def get_sam_model_ids():
|
21 |
+
"""Get SAM model ids list.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
list: SAM model ids list
|
25 |
+
"""
|
26 |
+
sam_model_ids = [
|
27 |
+
"sam2_hiera_large.pt",
|
28 |
+
"sam2_hiera_base_plus.pt",
|
29 |
+
"sam2_hiera_small.pt",
|
30 |
+
"sam2_hiera_tiny.pt",
|
31 |
+
"sam_vit_h_4b8939.pth",
|
32 |
+
"sam_vit_l_0b3195.pth",
|
33 |
+
"sam_vit_b_01ec64.pth",
|
34 |
+
"sam_hq_vit_h.pth",
|
35 |
+
"sam_hq_vit_l.pth",
|
36 |
+
"sam_hq_vit_b.pth",
|
37 |
+
"FastSAM-x.pt",
|
38 |
+
"FastSAM-s.pt",
|
39 |
+
"mobile_sam.pt",
|
40 |
+
]
|
41 |
+
return sam_model_ids
|
42 |
+
|
43 |
+
|
44 |
+
inp_list_from_cache = None
|
45 |
+
|
46 |
+
|
47 |
+
def get_inp_model_ids():
|
48 |
+
"""Get inpainting model ids list.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
list: model ids list
|
52 |
+
"""
|
53 |
+
global inp_list_from_cache
|
54 |
+
model_ids = [
|
55 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
56 |
+
"Uminosachi/dreamshaper_8Inpainting",
|
57 |
+
"Uminosachi/deliberate_v3-inpainting",
|
58 |
+
"Uminosachi/realisticVisionV51_v51VAE-inpainting",
|
59 |
+
"Uminosachi/revAnimated_v121Inp-inpainting",
|
60 |
+
"runwayml/stable-diffusion-inpainting",
|
61 |
+
]
|
62 |
+
if inp_list_from_cache is not None and isinstance(inp_list_from_cache, list):
|
63 |
+
model_ids.extend(inp_list_from_cache)
|
64 |
+
return model_ids
|
65 |
+
try:
|
66 |
+
hf_cache_info = scan_cache_dir()
|
67 |
+
inpaint_repos = []
|
68 |
+
for repo in hf_cache_info.repos:
|
69 |
+
if repo.repo_type == "model" and "inpaint" in repo.repo_id.lower() and repo.repo_id not in model_ids:
|
70 |
+
inpaint_repos.append(repo.repo_id)
|
71 |
+
inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1])
|
72 |
+
model_ids.extend(inp_list_from_cache)
|
73 |
+
return model_ids
|
74 |
+
except Exception:
|
75 |
+
return model_ids
|
76 |
+
|
77 |
+
|
78 |
+
def get_cleaner_model_ids():
|
79 |
+
"""Get cleaner model ids list.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
list: model ids list
|
83 |
+
"""
|
84 |
+
model_ids = [
|
85 |
+
"lama",
|
86 |
+
"ldm",
|
87 |
+
"zits",
|
88 |
+
"mat",
|
89 |
+
"fcf",
|
90 |
+
"manga",
|
91 |
+
]
|
92 |
+
return model_ids
|
93 |
+
|
94 |
+
|
95 |
+
def get_padding_mode_names():
|
96 |
+
"""Get padding mode name list.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
list: padding mode name list
|
100 |
+
"""
|
101 |
+
padding_mode_names = [
|
102 |
+
"constant",
|
103 |
+
"edge",
|
104 |
+
"reflect",
|
105 |
+
"mean",
|
106 |
+
"median",
|
107 |
+
"maximum",
|
108 |
+
"minimum",
|
109 |
+
]
|
110 |
+
return padding_mode_names
|
iasam_app.py
ADDED
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
# import math
|
3 |
+
import gc
|
4 |
+
import os
|
5 |
+
import platform
|
6 |
+
|
7 |
+
if platform.system() == "Darwin":
|
8 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
9 |
+
|
10 |
+
if platform.system() == "Windows":
|
11 |
+
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
|
12 |
+
|
13 |
+
import random
|
14 |
+
import traceback
|
15 |
+
from importlib.util import find_spec
|
16 |
+
|
17 |
+
import cv2
|
18 |
+
import gradio as gr
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
22 |
+
KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler,
|
23 |
+
StableDiffusionInpaintPipeline)
|
24 |
+
from PIL import Image, ImageFilter
|
25 |
+
from PIL.PngImagePlugin import PngInfo
|
26 |
+
from torch.hub import download_url_to_file
|
27 |
+
from torchvision import transforms
|
28 |
+
|
29 |
+
import inpalib
|
30 |
+
from ia_check_versions import ia_check_versions
|
31 |
+
from ia_config import IAConfig, get_ia_config_index, set_ia_config, setup_ia_config_ini
|
32 |
+
from ia_devices import devices
|
33 |
+
from ia_file_manager import IAFileManager, download_model_from_hf, ia_file_manager
|
34 |
+
from ia_logging import ia_logging
|
35 |
+
from ia_threading import clear_cache_decorator
|
36 |
+
from ia_ui_gradio import reload_javascript
|
37 |
+
from ia_ui_items import (get_cleaner_model_ids, get_inp_model_ids, get_padding_mode_names,
|
38 |
+
get_sam_model_ids, get_sampler_names)
|
39 |
+
from lama_cleaner.model_manager import ModelManager
|
40 |
+
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
|
41 |
+
|
42 |
+
print("platform:", platform.system())
|
43 |
+
|
44 |
+
reload_javascript()
|
45 |
+
|
46 |
+
if find_spec("xformers") is not None:
|
47 |
+
xformers_available = True
|
48 |
+
else:
|
49 |
+
xformers_available = False
|
50 |
+
|
51 |
+
parser = argparse.ArgumentParser(description="Inpaint Anything")
|
52 |
+
parser.add_argument("--save-seg", action="store_true", help="Save the segmentation image generated by SAM.")
|
53 |
+
parser.add_argument("--offline", action="store_true", help="Execute inpainting using an offline network.")
|
54 |
+
parser.add_argument("--sam-cpu", action="store_true", help="Perform the Segment Anything operation on CPU.")
|
55 |
+
args = parser.parse_args()
|
56 |
+
IAConfig.global_args.update(args.__dict__)
|
57 |
+
|
58 |
+
|
59 |
+
@clear_cache_decorator
|
60 |
+
def download_model(sam_model_id):
|
61 |
+
"""Download SAM model.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
sam_model_id (str): SAM model id
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
str: download status
|
68 |
+
"""
|
69 |
+
if "_hq_" in sam_model_id:
|
70 |
+
url_sam = "https://huggingface.co/Uminosachi/sam-hq/resolve/main/" + sam_model_id
|
71 |
+
elif "FastSAM" in sam_model_id:
|
72 |
+
url_sam = "https://huggingface.co/Uminosachi/FastSAM/resolve/main/" + sam_model_id
|
73 |
+
elif "mobile_sam" in sam_model_id:
|
74 |
+
url_sam = "https://huggingface.co/Uminosachi/MobileSAM/resolve/main/" + sam_model_id
|
75 |
+
elif "sam2_" in sam_model_id:
|
76 |
+
url_sam = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" + sam_model_id
|
77 |
+
else:
|
78 |
+
url_sam = "https://dl.fbaipublicfiles.com/segment_anything/" + sam_model_id
|
79 |
+
|
80 |
+
sam_checkpoint = os.path.join(ia_file_manager.models_dir, sam_model_id)
|
81 |
+
if not os.path.isfile(sam_checkpoint):
|
82 |
+
try:
|
83 |
+
download_url_to_file(url_sam, sam_checkpoint)
|
84 |
+
except Exception as e:
|
85 |
+
ia_logging.error(str(e))
|
86 |
+
return str(e)
|
87 |
+
|
88 |
+
return IAFileManager.DOWNLOAD_COMPLETE
|
89 |
+
else:
|
90 |
+
return "Model already exists"
|
91 |
+
|
92 |
+
|
93 |
+
sam_dict = dict(sam_masks=None, mask_image=None, cnet=None, orig_image=None, pad_mask=None)
|
94 |
+
|
95 |
+
|
96 |
+
def save_mask_image(mask_image, save_mask_chk=False):
|
97 |
+
"""Save mask image.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
mask_image (np.ndarray): mask image
|
101 |
+
save_mask_chk (bool, optional): If True, save mask image. Defaults to False.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
None
|
105 |
+
"""
|
106 |
+
if save_mask_chk:
|
107 |
+
save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png"
|
108 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
109 |
+
Image.fromarray(mask_image).save(save_name)
|
110 |
+
|
111 |
+
|
112 |
+
@clear_cache_decorator
|
113 |
+
def input_image_upload(input_image, sam_image, sel_mask):
|
114 |
+
global sam_dict
|
115 |
+
sam_dict["orig_image"] = input_image
|
116 |
+
sam_dict["pad_mask"] = None
|
117 |
+
|
118 |
+
if (sam_dict["mask_image"] is None or not isinstance(sam_dict["mask_image"], np.ndarray) or
|
119 |
+
sam_dict["mask_image"].shape != input_image.shape):
|
120 |
+
sam_dict["mask_image"] = np.zeros_like(input_image, dtype=np.uint8)
|
121 |
+
|
122 |
+
ret_sel_image = cv2.addWeighted(input_image, 0.5, sam_dict["mask_image"], 0.5, 0)
|
123 |
+
|
124 |
+
if sam_image is None or not isinstance(sam_image, dict) or "image" not in sam_image:
|
125 |
+
sam_dict["sam_masks"] = None
|
126 |
+
ret_sam_image = np.zeros_like(input_image, dtype=np.uint8)
|
127 |
+
elif sam_image["image"].shape == input_image.shape:
|
128 |
+
ret_sam_image = gr.update()
|
129 |
+
else:
|
130 |
+
sam_dict["sam_masks"] = None
|
131 |
+
ret_sam_image = gr.update(value=np.zeros_like(input_image, dtype=np.uint8))
|
132 |
+
|
133 |
+
if sel_mask is None or not isinstance(sel_mask, dict) or "image" not in sel_mask:
|
134 |
+
ret_sel_mask = ret_sel_image
|
135 |
+
elif sel_mask["image"].shape == ret_sel_image.shape and np.all(sel_mask["image"] == ret_sel_image):
|
136 |
+
ret_sel_mask = gr.update()
|
137 |
+
else:
|
138 |
+
ret_sel_mask = gr.update(value=ret_sel_image)
|
139 |
+
|
140 |
+
return ret_sam_image, ret_sel_mask, gr.update(interactive=True)
|
141 |
+
|
142 |
+
|
143 |
+
@clear_cache_decorator
|
144 |
+
def run_padding(input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode="edge"):
|
145 |
+
global sam_dict
|
146 |
+
if input_image is None or sam_dict["orig_image"] is None:
|
147 |
+
sam_dict["orig_image"] = None
|
148 |
+
sam_dict["pad_mask"] = None
|
149 |
+
return None, "Input image not found"
|
150 |
+
|
151 |
+
orig_image = sam_dict["orig_image"]
|
152 |
+
|
153 |
+
height, width = orig_image.shape[:2]
|
154 |
+
pad_width, pad_height = (int(width * pad_scale_width), int(height * pad_scale_height))
|
155 |
+
ia_logging.info(f"resize by padding: ({height}, {width}) -> ({pad_height}, {pad_width})")
|
156 |
+
|
157 |
+
pad_size_w, pad_size_h = (pad_width - width, pad_height - height)
|
158 |
+
pad_size_l = int(pad_size_w * pad_lr_barance)
|
159 |
+
pad_size_r = pad_size_w - pad_size_l
|
160 |
+
pad_size_t = int(pad_size_h * pad_tb_barance)
|
161 |
+
pad_size_b = pad_size_h - pad_size_t
|
162 |
+
|
163 |
+
pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r), (0, 0)]
|
164 |
+
if padding_mode == "constant":
|
165 |
+
fill_value = 127
|
166 |
+
pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode, constant_values=fill_value)
|
167 |
+
else:
|
168 |
+
pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode)
|
169 |
+
|
170 |
+
mask_pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r)]
|
171 |
+
pad_mask = np.zeros((height, width), dtype=np.uint8)
|
172 |
+
pad_mask = np.pad(pad_mask, pad_width=mask_pad_width, mode="constant", constant_values=255)
|
173 |
+
sam_dict["pad_mask"] = dict(segmentation=pad_mask.astype(bool))
|
174 |
+
|
175 |
+
return pad_image, "Padding done"
|
176 |
+
|
177 |
+
|
178 |
+
@clear_cache_decorator
|
179 |
+
def run_sam(input_image, sam_model_id, sam_image, anime_style_chk=False):
|
180 |
+
global sam_dict
|
181 |
+
if not inpalib.sam_file_exists(sam_model_id):
|
182 |
+
ret_sam_image = None if sam_image is None else gr.update()
|
183 |
+
return ret_sam_image, f"{sam_model_id} not found, please download"
|
184 |
+
|
185 |
+
if input_image is None:
|
186 |
+
ret_sam_image = None if sam_image is None else gr.update()
|
187 |
+
return ret_sam_image, "Input image not found"
|
188 |
+
|
189 |
+
set_ia_config(IAConfig.KEYS.SAM_MODEL_ID, sam_model_id, IAConfig.SECTIONS.USER)
|
190 |
+
|
191 |
+
if sam_dict["sam_masks"] is not None:
|
192 |
+
sam_dict["sam_masks"] = None
|
193 |
+
gc.collect()
|
194 |
+
|
195 |
+
ia_logging.info(f"input_image: {input_image.shape} {input_image.dtype}")
|
196 |
+
|
197 |
+
try:
|
198 |
+
sam_masks = inpalib.generate_sam_masks(input_image, sam_model_id, anime_style_chk)
|
199 |
+
sam_masks = inpalib.sort_masks_by_area(sam_masks)
|
200 |
+
sam_masks = inpalib.insert_mask_to_sam_masks(sam_masks, sam_dict["pad_mask"])
|
201 |
+
|
202 |
+
seg_image = inpalib.create_seg_color_image(input_image, sam_masks)
|
203 |
+
|
204 |
+
sam_dict["sam_masks"] = sam_masks
|
205 |
+
|
206 |
+
except Exception as e:
|
207 |
+
print(traceback.format_exc())
|
208 |
+
ia_logging.error(str(e))
|
209 |
+
ret_sam_image = None if sam_image is None else gr.update()
|
210 |
+
return ret_sam_image, "Segment Anything failed"
|
211 |
+
|
212 |
+
if IAConfig.global_args.get("save_seg", False):
|
213 |
+
save_name = "_".join([ia_file_manager.savename_prefix, os.path.splitext(sam_model_id)[0]]) + ".png"
|
214 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
215 |
+
Image.fromarray(seg_image).save(save_name)
|
216 |
+
|
217 |
+
if sam_image is None:
|
218 |
+
return seg_image, "Segment Anything complete"
|
219 |
+
else:
|
220 |
+
if sam_image["image"].shape == seg_image.shape and np.all(sam_image["image"] == seg_image):
|
221 |
+
return gr.update(), "Segment Anything complete"
|
222 |
+
else:
|
223 |
+
return gr.update(value=seg_image), "Segment Anything complete"
|
224 |
+
|
225 |
+
|
226 |
+
@clear_cache_decorator
|
227 |
+
def select_mask(input_image, sam_image, invert_chk, ignore_black_chk, sel_mask):
|
228 |
+
global sam_dict
|
229 |
+
if sam_dict["sam_masks"] is None or sam_image is None:
|
230 |
+
ret_sel_mask = None if sel_mask is None else gr.update()
|
231 |
+
return ret_sel_mask
|
232 |
+
sam_masks = sam_dict["sam_masks"]
|
233 |
+
|
234 |
+
# image = sam_image["image"]
|
235 |
+
mask = sam_image["mask"][:, :, 0:1]
|
236 |
+
|
237 |
+
try:
|
238 |
+
seg_image = inpalib.create_mask_image(mask, sam_masks, ignore_black_chk)
|
239 |
+
if invert_chk:
|
240 |
+
seg_image = inpalib.invert_mask(seg_image)
|
241 |
+
|
242 |
+
sam_dict["mask_image"] = seg_image
|
243 |
+
|
244 |
+
except Exception as e:
|
245 |
+
print(traceback.format_exc())
|
246 |
+
ia_logging.error(str(e))
|
247 |
+
ret_sel_mask = None if sel_mask is None else gr.update()
|
248 |
+
return ret_sel_mask
|
249 |
+
|
250 |
+
if input_image is not None and input_image.shape == seg_image.shape:
|
251 |
+
ret_image = cv2.addWeighted(input_image, 0.5, seg_image, 0.5, 0)
|
252 |
+
else:
|
253 |
+
ret_image = seg_image
|
254 |
+
|
255 |
+
if sel_mask is None:
|
256 |
+
return ret_image
|
257 |
+
else:
|
258 |
+
if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
|
259 |
+
return gr.update()
|
260 |
+
else:
|
261 |
+
return gr.update(value=ret_image)
|
262 |
+
|
263 |
+
|
264 |
+
@clear_cache_decorator
|
265 |
+
def expand_mask(input_image, sel_mask, expand_iteration=1):
|
266 |
+
global sam_dict
|
267 |
+
if sam_dict["mask_image"] is None or sel_mask is None:
|
268 |
+
return None
|
269 |
+
|
270 |
+
new_sel_mask = sam_dict["mask_image"]
|
271 |
+
|
272 |
+
expand_iteration = int(np.clip(expand_iteration, 1, 100))
|
273 |
+
|
274 |
+
new_sel_mask = cv2.dilate(new_sel_mask, np.ones((3, 3), dtype=np.uint8), iterations=expand_iteration)
|
275 |
+
|
276 |
+
sam_dict["mask_image"] = new_sel_mask
|
277 |
+
|
278 |
+
if input_image is not None and input_image.shape == new_sel_mask.shape:
|
279 |
+
ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
|
280 |
+
else:
|
281 |
+
ret_image = new_sel_mask
|
282 |
+
|
283 |
+
if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
|
284 |
+
return gr.update()
|
285 |
+
else:
|
286 |
+
return gr.update(value=ret_image)
|
287 |
+
|
288 |
+
|
289 |
+
@clear_cache_decorator
|
290 |
+
def apply_mask(input_image, sel_mask):
|
291 |
+
global sam_dict
|
292 |
+
if sam_dict["mask_image"] is None or sel_mask is None:
|
293 |
+
return None
|
294 |
+
|
295 |
+
sel_mask_image = sam_dict["mask_image"]
|
296 |
+
sel_mask_mask = np.logical_not(sel_mask["mask"][:, :, 0:3].astype(bool)).astype(np.uint8)
|
297 |
+
new_sel_mask = sel_mask_image * sel_mask_mask
|
298 |
+
|
299 |
+
sam_dict["mask_image"] = new_sel_mask
|
300 |
+
|
301 |
+
if input_image is not None and input_image.shape == new_sel_mask.shape:
|
302 |
+
ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
|
303 |
+
else:
|
304 |
+
ret_image = new_sel_mask
|
305 |
+
|
306 |
+
if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
|
307 |
+
return gr.update()
|
308 |
+
else:
|
309 |
+
return gr.update(value=ret_image)
|
310 |
+
|
311 |
+
|
312 |
+
@clear_cache_decorator
|
313 |
+
def add_mask(input_image, sel_mask):
|
314 |
+
global sam_dict
|
315 |
+
if sam_dict["mask_image"] is None or sel_mask is None:
|
316 |
+
return None
|
317 |
+
|
318 |
+
sel_mask_image = sam_dict["mask_image"]
|
319 |
+
sel_mask_mask = sel_mask["mask"][:, :, 0:3].astype(bool).astype(np.uint8)
|
320 |
+
new_sel_mask = sel_mask_image + (sel_mask_mask * np.invert(sel_mask_image, dtype=np.uint8))
|
321 |
+
|
322 |
+
sam_dict["mask_image"] = new_sel_mask
|
323 |
+
|
324 |
+
if input_image is not None and input_image.shape == new_sel_mask.shape:
|
325 |
+
ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
|
326 |
+
else:
|
327 |
+
ret_image = new_sel_mask
|
328 |
+
|
329 |
+
if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
|
330 |
+
return gr.update()
|
331 |
+
else:
|
332 |
+
return gr.update(value=ret_image)
|
333 |
+
|
334 |
+
|
335 |
+
def auto_resize_to_pil(input_image, mask_image):
|
336 |
+
init_image = Image.fromarray(input_image).convert("RGB")
|
337 |
+
mask_image = Image.fromarray(mask_image).convert("RGB")
|
338 |
+
assert init_image.size == mask_image.size, "The sizes of the image and mask do not match"
|
339 |
+
width, height = init_image.size
|
340 |
+
|
341 |
+
new_height = (height // 8) * 8
|
342 |
+
new_width = (width // 8) * 8
|
343 |
+
if new_width < width or new_height < height:
|
344 |
+
if (new_width / width) < (new_height / height):
|
345 |
+
scale = new_height / height
|
346 |
+
else:
|
347 |
+
scale = new_width / width
|
348 |
+
resize_height = int(height*scale+0.5)
|
349 |
+
resize_width = int(width*scale+0.5)
|
350 |
+
if height != resize_height or width != resize_width:
|
351 |
+
ia_logging.info(f"resize: ({height}, {width}) -> ({resize_height}, {resize_width})")
|
352 |
+
init_image = transforms.functional.resize(init_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
|
353 |
+
mask_image = transforms.functional.resize(mask_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
|
354 |
+
if resize_height != new_height or resize_width != new_width:
|
355 |
+
ia_logging.info(f"center_crop: ({resize_height}, {resize_width}) -> ({new_height}, {new_width})")
|
356 |
+
init_image = transforms.functional.center_crop(init_image, (new_height, new_width))
|
357 |
+
mask_image = transforms.functional.center_crop(mask_image, (new_height, new_width))
|
358 |
+
|
359 |
+
return init_image, mask_image
|
360 |
+
|
361 |
+
|
362 |
+
@clear_cache_decorator
|
363 |
+
def run_inpaint(input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk,
|
364 |
+
sampler_name="DDIM", iteration_count=1):
|
365 |
+
global sam_dict
|
366 |
+
if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
|
367 |
+
ia_logging.error("The image or mask does not exist")
|
368 |
+
return
|
369 |
+
|
370 |
+
mask_image = sam_dict["mask_image"]
|
371 |
+
if input_image.shape != mask_image.shape:
|
372 |
+
ia_logging.error("The sizes of the image and mask do not match")
|
373 |
+
return
|
374 |
+
|
375 |
+
set_ia_config(IAConfig.KEYS.INP_MODEL_ID, inp_model_id, IAConfig.SECTIONS.USER)
|
376 |
+
|
377 |
+
save_mask_image(mask_image, save_mask_chk)
|
378 |
+
|
379 |
+
ia_logging.info(f"Loading model {inp_model_id}")
|
380 |
+
config_offline_inpainting = IAConfig.global_args.get("offline", False)
|
381 |
+
if config_offline_inpainting:
|
382 |
+
ia_logging.info("Run Inpainting on offline network: {}".format(str(config_offline_inpainting)))
|
383 |
+
local_files_only = False
|
384 |
+
local_file_status = download_model_from_hf(inp_model_id, local_files_only=True)
|
385 |
+
if local_file_status != IAFileManager.DOWNLOAD_COMPLETE:
|
386 |
+
if config_offline_inpainting:
|
387 |
+
ia_logging.warning(local_file_status)
|
388 |
+
return
|
389 |
+
else:
|
390 |
+
local_files_only = True
|
391 |
+
ia_logging.info("local_files_only: {}".format(str(local_files_only)))
|
392 |
+
|
393 |
+
if platform.system() == "Darwin" or devices.device == devices.cpu or ia_check_versions.torch_on_amd_rocm:
|
394 |
+
torch_dtype = torch.float32
|
395 |
+
else:
|
396 |
+
torch_dtype = torch.float16
|
397 |
+
|
398 |
+
try:
|
399 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
400 |
+
inp_model_id, torch_dtype=torch_dtype, local_files_only=local_files_only, use_safetensors=True)
|
401 |
+
except Exception as e:
|
402 |
+
ia_logging.error(str(e))
|
403 |
+
if not config_offline_inpainting:
|
404 |
+
try:
|
405 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
406 |
+
inp_model_id, torch_dtype=torch_dtype, use_safetensors=True)
|
407 |
+
except Exception as e:
|
408 |
+
ia_logging.error(str(e))
|
409 |
+
try:
|
410 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
411 |
+
inp_model_id, torch_dtype=torch_dtype, force_download=True, use_safetensors=True)
|
412 |
+
except Exception as e:
|
413 |
+
ia_logging.error(str(e))
|
414 |
+
return
|
415 |
+
else:
|
416 |
+
return
|
417 |
+
pipe.safety_checker = None
|
418 |
+
|
419 |
+
ia_logging.info(f"Using sampler {sampler_name}")
|
420 |
+
if sampler_name == "DDIM":
|
421 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
422 |
+
elif sampler_name == "Euler":
|
423 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
424 |
+
elif sampler_name == "Euler a":
|
425 |
+
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
426 |
+
elif sampler_name == "DPM2 Karras":
|
427 |
+
pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
|
428 |
+
elif sampler_name == "DPM2 a Karras":
|
429 |
+
pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
430 |
+
else:
|
431 |
+
ia_logging.info("Sampler fallback to DDIM")
|
432 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
433 |
+
|
434 |
+
if platform.system() == "Darwin":
|
435 |
+
pipe = pipe.to("mps" if ia_check_versions.torch_mps_is_available else "cpu")
|
436 |
+
pipe.enable_attention_slicing()
|
437 |
+
torch_generator = torch.Generator(devices.cpu)
|
438 |
+
else:
|
439 |
+
if ia_check_versions.diffusers_enable_cpu_offload and devices.device != devices.cpu:
|
440 |
+
ia_logging.info("Enable model cpu offload")
|
441 |
+
pipe.enable_model_cpu_offload()
|
442 |
+
else:
|
443 |
+
pipe = pipe.to(devices.device)
|
444 |
+
if xformers_available:
|
445 |
+
ia_logging.info("Enable xformers memory efficient attention")
|
446 |
+
pipe.enable_xformers_memory_efficient_attention()
|
447 |
+
else:
|
448 |
+
ia_logging.info("Enable attention slicing")
|
449 |
+
pipe.enable_attention_slicing()
|
450 |
+
if "privateuseone" in str(getattr(devices.device, "type", "")):
|
451 |
+
torch_generator = torch.Generator(devices.cpu)
|
452 |
+
else:
|
453 |
+
torch_generator = torch.Generator(devices.device)
|
454 |
+
|
455 |
+
init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
|
456 |
+
width, height = init_image.size
|
457 |
+
|
458 |
+
output_list = []
|
459 |
+
iteration_count = iteration_count if iteration_count is not None else 1
|
460 |
+
for count in range(int(iteration_count)):
|
461 |
+
gc.collect()
|
462 |
+
if seed < 0 or count > 0:
|
463 |
+
seed = random.randint(0, 2147483647)
|
464 |
+
|
465 |
+
generator = torch_generator.manual_seed(seed)
|
466 |
+
|
467 |
+
pipe_args_dict = {
|
468 |
+
"prompt": prompt,
|
469 |
+
"image": init_image,
|
470 |
+
"width": width,
|
471 |
+
"height": height,
|
472 |
+
"mask_image": mask_image,
|
473 |
+
"num_inference_steps": ddim_steps,
|
474 |
+
"guidance_scale": cfg_scale,
|
475 |
+
"negative_prompt": n_prompt,
|
476 |
+
"generator": generator,
|
477 |
+
}
|
478 |
+
|
479 |
+
output_image = pipe(**pipe_args_dict).images[0]
|
480 |
+
|
481 |
+
if composite_chk:
|
482 |
+
dilate_mask_image = Image.fromarray(cv2.dilate(np.array(mask_image), np.ones((3, 3), dtype=np.uint8), iterations=4))
|
483 |
+
output_image = Image.composite(output_image, init_image, dilate_mask_image.convert("L").filter(ImageFilter.GaussianBlur(3)))
|
484 |
+
|
485 |
+
generation_params = {
|
486 |
+
"Steps": ddim_steps,
|
487 |
+
"Sampler": sampler_name,
|
488 |
+
"CFG scale": cfg_scale,
|
489 |
+
"Seed": seed,
|
490 |
+
"Size": f"{width}x{height}",
|
491 |
+
"Model": inp_model_id,
|
492 |
+
}
|
493 |
+
|
494 |
+
generation_params_text = ", ".join([k if k == v else f"{k}: {v}" for k, v in generation_params.items() if v is not None])
|
495 |
+
prompt_text = prompt if prompt else ""
|
496 |
+
negative_prompt_text = "\nNegative prompt: " + n_prompt if n_prompt else ""
|
497 |
+
infotext = f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
498 |
+
|
499 |
+
metadata = PngInfo()
|
500 |
+
metadata.add_text("parameters", infotext)
|
501 |
+
|
502 |
+
save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(inp_model_id), str(seed)]) + ".png"
|
503 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
504 |
+
output_image.save(save_name, pnginfo=metadata)
|
505 |
+
|
506 |
+
output_list.append(output_image)
|
507 |
+
|
508 |
+
yield output_list, max([1, iteration_count - (count + 1)])
|
509 |
+
|
510 |
+
|
511 |
+
@clear_cache_decorator
|
512 |
+
def run_cleaner(input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk):
|
513 |
+
global sam_dict
|
514 |
+
if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
|
515 |
+
ia_logging.error("The image or mask does not exist")
|
516 |
+
return None
|
517 |
+
|
518 |
+
mask_image = sam_dict["mask_image"]
|
519 |
+
if input_image.shape != mask_image.shape:
|
520 |
+
ia_logging.error("The sizes of the image and mask do not match")
|
521 |
+
return None
|
522 |
+
|
523 |
+
save_mask_image(mask_image, cleaner_save_mask_chk)
|
524 |
+
|
525 |
+
ia_logging.info(f"Loading model {cleaner_model_id}")
|
526 |
+
if platform.system() == "Darwin":
|
527 |
+
model = ModelManager(name=cleaner_model_id, device=devices.cpu)
|
528 |
+
else:
|
529 |
+
model = ModelManager(name=cleaner_model_id, device=devices.device)
|
530 |
+
|
531 |
+
init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
|
532 |
+
width, height = init_image.size
|
533 |
+
|
534 |
+
init_image = np.array(init_image)
|
535 |
+
mask_image = np.array(mask_image.convert("L"))
|
536 |
+
|
537 |
+
config = Config(
|
538 |
+
ldm_steps=20,
|
539 |
+
ldm_sampler=LDMSampler.ddim,
|
540 |
+
hd_strategy=HDStrategy.ORIGINAL,
|
541 |
+
hd_strategy_crop_margin=32,
|
542 |
+
hd_strategy_crop_trigger_size=512,
|
543 |
+
hd_strategy_resize_limit=512,
|
544 |
+
prompt="",
|
545 |
+
sd_steps=20,
|
546 |
+
sd_sampler=SDSampler.ddim
|
547 |
+
)
|
548 |
+
|
549 |
+
output_image = model(image=init_image, mask=mask_image, config=config)
|
550 |
+
output_image = cv2.cvtColor(output_image.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
551 |
+
output_image = Image.fromarray(output_image)
|
552 |
+
|
553 |
+
save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(cleaner_model_id)]) + ".png"
|
554 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
555 |
+
output_image.save(save_name)
|
556 |
+
|
557 |
+
del model
|
558 |
+
return [output_image]
|
559 |
+
|
560 |
+
|
561 |
+
@clear_cache_decorator
|
562 |
+
def run_get_alpha_image(input_image, sel_mask):
|
563 |
+
global sam_dict
|
564 |
+
if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
|
565 |
+
ia_logging.error("The image or mask does not exist")
|
566 |
+
return None, ""
|
567 |
+
|
568 |
+
mask_image = sam_dict["mask_image"]
|
569 |
+
if input_image.shape != mask_image.shape:
|
570 |
+
ia_logging.error("The sizes of the image and mask do not match")
|
571 |
+
return None, ""
|
572 |
+
|
573 |
+
alpha_image = Image.fromarray(input_image).convert("RGBA")
|
574 |
+
mask_image = Image.fromarray(mask_image).convert("L")
|
575 |
+
|
576 |
+
alpha_image.putalpha(mask_image)
|
577 |
+
|
578 |
+
save_name = "_".join([ia_file_manager.savename_prefix, "rgba_image"]) + ".png"
|
579 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
580 |
+
alpha_image.save(save_name)
|
581 |
+
|
582 |
+
return alpha_image, f"saved: {save_name}"
|
583 |
+
|
584 |
+
|
585 |
+
@clear_cache_decorator
|
586 |
+
def run_get_mask(sel_mask):
|
587 |
+
global sam_dict
|
588 |
+
if sam_dict["mask_image"] is None or sel_mask is None:
|
589 |
+
return None
|
590 |
+
|
591 |
+
mask_image = sam_dict["mask_image"]
|
592 |
+
|
593 |
+
save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png"
|
594 |
+
save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
|
595 |
+
Image.fromarray(mask_image).save(save_name)
|
596 |
+
|
597 |
+
return mask_image
|
598 |
+
|
599 |
+
|
600 |
+
def on_ui_tabs():
|
601 |
+
setup_ia_config_ini()
|
602 |
+
sampler_names = get_sampler_names()
|
603 |
+
sam_model_ids = get_sam_model_ids()
|
604 |
+
sam_model_index = get_ia_config_index(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER)
|
605 |
+
inp_model_ids = get_inp_model_ids()
|
606 |
+
inp_model_index = get_ia_config_index(IAConfig.KEYS.INP_MODEL_ID, IAConfig.SECTIONS.USER)
|
607 |
+
cleaner_model_ids = get_cleaner_model_ids()
|
608 |
+
padding_mode_names = get_padding_mode_names()
|
609 |
+
|
610 |
+
out_gallery_kwargs = dict(columns=2, height=520, object_fit="contain", preview=True)
|
611 |
+
|
612 |
+
block = gr.Blocks(analytics_enabled=False).queue()
|
613 |
+
block.title = "Inpaint Anything"
|
614 |
+
with block as inpaint_anything_interface:
|
615 |
+
with gr.Row():
|
616 |
+
gr.Markdown("## Inpainting with Segment Anything")
|
617 |
+
with gr.Row():
|
618 |
+
with gr.Column():
|
619 |
+
with gr.Row():
|
620 |
+
with gr.Column():
|
621 |
+
sam_model_id = gr.Dropdown(label="Segment Anything Model ID", elem_id="sam_model_id", choices=sam_model_ids,
|
622 |
+
value=sam_model_ids[sam_model_index], show_label=True)
|
623 |
+
with gr.Column():
|
624 |
+
with gr.Row():
|
625 |
+
load_model_btn = gr.Button("Download model", elem_id="load_model_btn")
|
626 |
+
with gr.Row():
|
627 |
+
status_text = gr.Textbox(label="", elem_id="status_text", max_lines=1, show_label=False, interactive=False)
|
628 |
+
with gr.Row():
|
629 |
+
input_image = gr.Image(label="Input image", elem_id="ia_input_image", source="upload", type="numpy", interactive=True)
|
630 |
+
|
631 |
+
with gr.Row():
|
632 |
+
with gr.Accordion("Padding options", elem_id="padding_options", open=False):
|
633 |
+
with gr.Row():
|
634 |
+
with gr.Column():
|
635 |
+
pad_scale_width = gr.Slider(label="Scale Width", elem_id="pad_scale_width", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
|
636 |
+
with gr.Column():
|
637 |
+
pad_lr_barance = gr.Slider(label="Left/Right Balance", elem_id="pad_lr_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
|
638 |
+
with gr.Row():
|
639 |
+
with gr.Column():
|
640 |
+
pad_scale_height = gr.Slider(label="Scale Height", elem_id="pad_scale_height", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
|
641 |
+
with gr.Column():
|
642 |
+
pad_tb_barance = gr.Slider(label="Top/Bottom Balance", elem_id="pad_tb_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
|
643 |
+
with gr.Row():
|
644 |
+
with gr.Column():
|
645 |
+
padding_mode = gr.Dropdown(label="Padding Mode", elem_id="padding_mode", choices=padding_mode_names, value="edge")
|
646 |
+
with gr.Column():
|
647 |
+
padding_btn = gr.Button("Run Padding", elem_id="padding_btn")
|
648 |
+
|
649 |
+
with gr.Row():
|
650 |
+
with gr.Column():
|
651 |
+
anime_style_chk = gr.Checkbox(label="Anime Style (Up Detection, Down mask Quality)", elem_id="anime_style_chk",
|
652 |
+
show_label=True, interactive=True)
|
653 |
+
with gr.Column():
|
654 |
+
sam_btn = gr.Button("Run Segment Anything", elem_id="sam_btn", variant="primary", interactive=False)
|
655 |
+
|
656 |
+
with gr.Tab("Inpainting", elem_id="inpainting_tab"):
|
657 |
+
prompt = gr.Textbox(label="Inpainting Prompt", elem_id="sd_prompt")
|
658 |
+
n_prompt = gr.Textbox(label="Negative Prompt", elem_id="sd_n_prompt")
|
659 |
+
with gr.Accordion("Advanced options", elem_id="inp_advanced_options", open=False):
|
660 |
+
composite_chk = gr.Checkbox(label="Mask area Only", elem_id="composite_chk", value=True, show_label=True, interactive=True)
|
661 |
+
with gr.Row():
|
662 |
+
with gr.Column():
|
663 |
+
sampler_name = gr.Dropdown(label="Sampler", elem_id="sampler_name", choices=sampler_names,
|
664 |
+
value=sampler_names[0], show_label=True)
|
665 |
+
with gr.Column():
|
666 |
+
ddim_steps = gr.Slider(label="Sampling Steps", elem_id="ddim_steps", minimum=1, maximum=100, value=20, step=1)
|
667 |
+
cfg_scale = gr.Slider(label="Guidance Scale", elem_id="cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
668 |
+
seed = gr.Slider(
|
669 |
+
label="Seed",
|
670 |
+
elem_id="sd_seed",
|
671 |
+
minimum=-1,
|
672 |
+
maximum=2147483647,
|
673 |
+
step=1,
|
674 |
+
value=-1,
|
675 |
+
)
|
676 |
+
with gr.Row():
|
677 |
+
with gr.Column():
|
678 |
+
inp_model_id = gr.Dropdown(label="Inpainting Model ID", elem_id="inp_model_id",
|
679 |
+
choices=inp_model_ids, value=inp_model_ids[inp_model_index], show_label=True)
|
680 |
+
with gr.Column():
|
681 |
+
with gr.Row():
|
682 |
+
inpaint_btn = gr.Button("Run Inpainting", elem_id="inpaint_btn", variant="primary")
|
683 |
+
with gr.Row():
|
684 |
+
save_mask_chk = gr.Checkbox(label="Save mask", elem_id="save_mask_chk",
|
685 |
+
value=False, show_label=False, interactive=False, visible=False)
|
686 |
+
iteration_count = gr.Slider(label="Iterations", elem_id="iteration_count", minimum=1, maximum=10, value=1, step=1)
|
687 |
+
|
688 |
+
with gr.Row():
|
689 |
+
if ia_check_versions.gradio_version_is_old:
|
690 |
+
out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False
|
691 |
+
).style(**out_gallery_kwargs)
|
692 |
+
else:
|
693 |
+
out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False,
|
694 |
+
**out_gallery_kwargs)
|
695 |
+
|
696 |
+
with gr.Tab("Cleaner", elem_id="cleaner_tab"):
|
697 |
+
with gr.Row():
|
698 |
+
with gr.Column():
|
699 |
+
cleaner_model_id = gr.Dropdown(label="Cleaner Model ID", elem_id="cleaner_model_id",
|
700 |
+
choices=cleaner_model_ids, value=cleaner_model_ids[0], show_label=True)
|
701 |
+
with gr.Column():
|
702 |
+
with gr.Row():
|
703 |
+
cleaner_btn = gr.Button("Run Cleaner", elem_id="cleaner_btn", variant="primary")
|
704 |
+
with gr.Row():
|
705 |
+
cleaner_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="cleaner_save_mask_chk",
|
706 |
+
value=False, show_label=False, interactive=False, visible=False)
|
707 |
+
|
708 |
+
with gr.Row():
|
709 |
+
if ia_check_versions.gradio_version_is_old:
|
710 |
+
cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False
|
711 |
+
).style(**out_gallery_kwargs)
|
712 |
+
else:
|
713 |
+
cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False,
|
714 |
+
**out_gallery_kwargs)
|
715 |
+
|
716 |
+
with gr.Tab("Mask only", elem_id="mask_only_tab"):
|
717 |
+
with gr.Row():
|
718 |
+
with gr.Column():
|
719 |
+
get_alpha_image_btn = gr.Button("Get mask as alpha of image", elem_id="get_alpha_image_btn")
|
720 |
+
with gr.Column():
|
721 |
+
get_mask_btn = gr.Button("Get mask", elem_id="get_mask_btn")
|
722 |
+
|
723 |
+
with gr.Row():
|
724 |
+
with gr.Column():
|
725 |
+
alpha_out_image = gr.Image(label="Alpha channel image", elem_id="alpha_out_image", type="pil", image_mode="RGBA", interactive=False)
|
726 |
+
with gr.Column():
|
727 |
+
mask_out_image = gr.Image(label="Mask image", elem_id="mask_out_image", type="numpy", interactive=False)
|
728 |
+
|
729 |
+
with gr.Row():
|
730 |
+
with gr.Column():
|
731 |
+
get_alpha_status_text = gr.Textbox(label="", elem_id="get_alpha_status_text", max_lines=1, show_label=False, interactive=False)
|
732 |
+
with gr.Column():
|
733 |
+
gr.Markdown("")
|
734 |
+
|
735 |
+
with gr.Column():
|
736 |
+
with gr.Row():
|
737 |
+
gr.Markdown("Mouse over image: Press `S` key for Fullscreen mode, `R` key to Reset zoom")
|
738 |
+
with gr.Row():
|
739 |
+
if ia_check_versions.gradio_version_is_old:
|
740 |
+
sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8,
|
741 |
+
show_label=False, interactive=True).style(height=480)
|
742 |
+
else:
|
743 |
+
sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8,
|
744 |
+
show_label=False, interactive=True, height=480)
|
745 |
+
|
746 |
+
with gr.Row():
|
747 |
+
with gr.Column():
|
748 |
+
select_btn = gr.Button("Create Mask", elem_id="select_btn", variant="primary")
|
749 |
+
with gr.Column():
|
750 |
+
with gr.Row():
|
751 |
+
invert_chk = gr.Checkbox(label="Invert mask", elem_id="invert_chk", show_label=True, interactive=True)
|
752 |
+
ignore_black_chk = gr.Checkbox(label="Ignore black area", elem_id="ignore_black_chk", value=True, show_label=True, interactive=True)
|
753 |
+
|
754 |
+
with gr.Row():
|
755 |
+
if ia_check_versions.gradio_version_is_old:
|
756 |
+
sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12,
|
757 |
+
show_label=False, interactive=True).style(height=480)
|
758 |
+
else:
|
759 |
+
sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12,
|
760 |
+
show_label=False, interactive=True, height=480)
|
761 |
+
|
762 |
+
with gr.Row():
|
763 |
+
with gr.Column():
|
764 |
+
expand_mask_btn = gr.Button("Expand mask region", elem_id="expand_mask_btn")
|
765 |
+
expand_mask_iteration_count = gr.Slider(label="Expand Mask Iterations",
|
766 |
+
elem_id="expand_mask_iteration_count", minimum=1, maximum=100, value=1, step=1)
|
767 |
+
with gr.Column():
|
768 |
+
apply_mask_btn = gr.Button("Trim mask by sketch", elem_id="apply_mask_btn")
|
769 |
+
add_mask_btn = gr.Button("Add mask by sketch", elem_id="add_mask_btn")
|
770 |
+
|
771 |
+
load_model_btn.click(download_model, inputs=[sam_model_id], outputs=[status_text])
|
772 |
+
input_image.upload(input_image_upload, inputs=[input_image, sam_image, sel_mask], outputs=[sam_image, sel_mask, sam_btn]).then(
|
773 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_initSamSelMask")
|
774 |
+
padding_btn.click(run_padding, inputs=[input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode],
|
775 |
+
outputs=[input_image, status_text])
|
776 |
+
sam_btn.click(run_sam, inputs=[input_image, sam_model_id, sam_image, anime_style_chk], outputs=[sam_image, status_text]).then(
|
777 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSamMask")
|
778 |
+
select_btn.click(select_mask, inputs=[input_image, sam_image, invert_chk, ignore_black_chk, sel_mask], outputs=[sel_mask]).then(
|
779 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
|
780 |
+
expand_mask_btn.click(expand_mask, inputs=[input_image, sel_mask, expand_mask_iteration_count], outputs=[sel_mask]).then(
|
781 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
|
782 |
+
apply_mask_btn.click(apply_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
|
783 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
|
784 |
+
add_mask_btn.click(add_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
|
785 |
+
fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
|
786 |
+
|
787 |
+
inpaint_btn.click(
|
788 |
+
run_inpaint,
|
789 |
+
inputs=[input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk,
|
790 |
+
sampler_name, iteration_count],
|
791 |
+
outputs=[out_image, iteration_count])
|
792 |
+
cleaner_btn.click(
|
793 |
+
run_cleaner,
|
794 |
+
inputs=[input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk],
|
795 |
+
outputs=[cleaner_out_image])
|
796 |
+
get_alpha_image_btn.click(
|
797 |
+
run_get_alpha_image,
|
798 |
+
inputs=[input_image, sel_mask],
|
799 |
+
outputs=[alpha_out_image, get_alpha_status_text])
|
800 |
+
get_mask_btn.click(
|
801 |
+
run_get_mask,
|
802 |
+
inputs=[sel_mask],
|
803 |
+
outputs=[mask_out_image])
|
804 |
+
|
805 |
+
return [(inpaint_anything_interface, "Inpaint Anything", "inpaint_anything")]
|
806 |
+
|
807 |
+
|
808 |
+
block, _, _ = on_ui_tabs()[0]
|
809 |
+
block.launch(share=True)
|
images/inpaint_anything_explanation_image_1.png
ADDED
images/inpaint_anything_ui_image_1.png
ADDED
images/sample_input_image.png
ADDED
images/sample_mask_image.png
ADDED
images/sample_seg_color_image.png
ADDED
inpalib/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .masklib import create_mask_image, invert_mask
|
2 |
+
from .samlib import (create_seg_color_image, generate_sam_masks, get_all_sam_ids,
|
3 |
+
get_available_sam_ids, get_seg_colormap, insert_mask_to_sam_masks,
|
4 |
+
sam_file_exists, sam_file_path, sort_masks_by_area)
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"create_mask_image",
|
8 |
+
"invert_mask",
|
9 |
+
"create_seg_color_image",
|
10 |
+
"generate_sam_masks",
|
11 |
+
"get_all_sam_ids",
|
12 |
+
"get_available_sam_ids",
|
13 |
+
"get_seg_colormap",
|
14 |
+
"insert_mask_to_sam_masks",
|
15 |
+
"sam_file_exists",
|
16 |
+
"sam_file_path",
|
17 |
+
"sort_masks_by_area",
|
18 |
+
]
|
inpalib/masklib.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def invert_mask(mask: np.ndarray) -> np.ndarray:
|
8 |
+
"""Invert mask.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
mask (np.ndarray): mask
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
np.ndarray: inverted mask
|
15 |
+
"""
|
16 |
+
if mask is None or not isinstance(mask, np.ndarray):
|
17 |
+
raise ValueError("Invalid mask")
|
18 |
+
|
19 |
+
# return np.logical_not(mask.astype(bool)).astype(np.uint8) * 255
|
20 |
+
return np.invert(mask.astype(np.uint8))
|
21 |
+
|
22 |
+
|
23 |
+
def check_inputs_create_mask_image(
|
24 |
+
mask: Union[np.ndarray, Image.Image],
|
25 |
+
sam_masks: List[Dict[str, Any]],
|
26 |
+
ignore_black_chk: bool = True,
|
27 |
+
) -> None:
|
28 |
+
"""Check create mask image inputs.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
mask (Union[np.ndarray, Image.Image]): mask
|
32 |
+
sam_masks (List[Dict[str, Any]]): SAM masks
|
33 |
+
ignore_black_chk (bool): ignore black check
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
None
|
37 |
+
"""
|
38 |
+
if mask is None or not isinstance(mask, (np.ndarray, Image.Image)):
|
39 |
+
raise ValueError("Invalid mask")
|
40 |
+
|
41 |
+
if sam_masks is None or not isinstance(sam_masks, list):
|
42 |
+
raise ValueError("Invalid SAM masks")
|
43 |
+
|
44 |
+
if ignore_black_chk is None or not isinstance(ignore_black_chk, bool):
|
45 |
+
raise ValueError("Invalid ignore black check")
|
46 |
+
|
47 |
+
|
48 |
+
def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
|
49 |
+
"""Convert mask.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
mask (Union[np.ndarray, Image.Image]): mask
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
np.ndarray: converted mask
|
56 |
+
"""
|
57 |
+
if isinstance(mask, Image.Image):
|
58 |
+
mask = np.array(mask)
|
59 |
+
|
60 |
+
if mask.ndim == 2:
|
61 |
+
mask = mask[:, :, np.newaxis]
|
62 |
+
|
63 |
+
if mask.shape[2] != 1:
|
64 |
+
mask = mask[:, :, 0:1]
|
65 |
+
|
66 |
+
return mask
|
67 |
+
|
68 |
+
|
69 |
+
def create_mask_image(
|
70 |
+
mask: Union[np.ndarray, Image.Image],
|
71 |
+
sam_masks: List[Dict[str, Any]],
|
72 |
+
ignore_black_chk: bool = True,
|
73 |
+
) -> np.ndarray:
|
74 |
+
"""Create mask image.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
mask (Union[np.ndarray, Image.Image]): mask
|
78 |
+
sam_masks (List[Dict[str, Any]]): SAM masks
|
79 |
+
ignore_black_chk (bool): ignore black check
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
np.ndarray: mask image
|
83 |
+
"""
|
84 |
+
check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
|
85 |
+
mask = convert_mask(mask)
|
86 |
+
|
87 |
+
canvas_image = np.zeros(mask.shape, dtype=np.uint8)
|
88 |
+
mask_region = np.zeros(mask.shape, dtype=np.uint8)
|
89 |
+
for seg_dict in sam_masks:
|
90 |
+
seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
|
91 |
+
canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
|
92 |
+
if (seg_mask * canvas_mask * mask).astype(bool).any():
|
93 |
+
mask_region = mask_region + (seg_mask * canvas_mask)
|
94 |
+
seg_color = seg_mask * canvas_mask
|
95 |
+
canvas_image = canvas_image + seg_color
|
96 |
+
|
97 |
+
if not ignore_black_chk:
|
98 |
+
canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
|
99 |
+
if (canvas_mask * mask).astype(bool).any():
|
100 |
+
mask_region = mask_region + (canvas_mask)
|
101 |
+
|
102 |
+
mask_region = np.tile(mask_region * 255, (1, 1, 3))
|
103 |
+
|
104 |
+
seg_image = mask_region.astype(np.uint8)
|
105 |
+
|
106 |
+
return seg_image
|
inpalib/samlib.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from typing import Any, Dict, List, Union
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
inpa_basedir = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
|
13 |
+
if inpa_basedir not in sys.path:
|
14 |
+
sys.path.append(inpa_basedir)
|
15 |
+
|
16 |
+
from ia_file_manager import ia_file_manager # noqa: E402
|
17 |
+
from ia_get_dataset_colormap import create_pascal_label_colormap # noqa: E402
|
18 |
+
from ia_logging import ia_logging # noqa: E402
|
19 |
+
from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator # noqa: E402
|
20 |
+
from ia_ui_items import get_sam_model_ids # noqa: E402
|
21 |
+
|
22 |
+
|
23 |
+
def get_all_sam_ids() -> List[str]:
|
24 |
+
"""Get all SAM IDs.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
List[str]: SAM IDs
|
28 |
+
"""
|
29 |
+
return get_sam_model_ids()
|
30 |
+
|
31 |
+
|
32 |
+
def sam_file_path(sam_id: str) -> str:
|
33 |
+
"""Get SAM file path.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
sam_id (str): SAM ID
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
str: SAM file path
|
40 |
+
"""
|
41 |
+
return os.path.join(ia_file_manager.models_dir, sam_id)
|
42 |
+
|
43 |
+
|
44 |
+
def sam_file_exists(sam_id: str) -> bool:
|
45 |
+
"""Check if SAM file exists.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
sam_id (str): SAM ID
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
bool: True if SAM file exists else False
|
52 |
+
"""
|
53 |
+
sam_checkpoint = sam_file_path(sam_id)
|
54 |
+
|
55 |
+
return os.path.isfile(sam_checkpoint)
|
56 |
+
|
57 |
+
|
58 |
+
def get_available_sam_ids() -> List[str]:
|
59 |
+
"""Get available SAM IDs.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
List[str]: available SAM IDs
|
63 |
+
"""
|
64 |
+
all_sam_ids = get_all_sam_ids()
|
65 |
+
for sam_id in all_sam_ids.copy():
|
66 |
+
if not sam_file_exists(sam_id):
|
67 |
+
all_sam_ids.remove(sam_id)
|
68 |
+
|
69 |
+
return all_sam_ids
|
70 |
+
|
71 |
+
|
72 |
+
def check_inputs_generate_sam_masks(
|
73 |
+
input_image: Union[np.ndarray, Image.Image],
|
74 |
+
sam_id: str,
|
75 |
+
anime_style_chk: bool = False,
|
76 |
+
) -> None:
|
77 |
+
"""Check generate SAM masks inputs.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
input_image (Union[np.ndarray, Image.Image]): input image
|
81 |
+
sam_id (str): SAM ID
|
82 |
+
anime_style_chk (bool): anime style check
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
None
|
86 |
+
"""
|
87 |
+
if input_image is None or not isinstance(input_image, (np.ndarray, Image.Image)):
|
88 |
+
raise ValueError("Invalid input image")
|
89 |
+
|
90 |
+
if sam_id is None or not isinstance(sam_id, str):
|
91 |
+
raise ValueError("Invalid SAM ID")
|
92 |
+
|
93 |
+
if anime_style_chk is None or not isinstance(anime_style_chk, bool):
|
94 |
+
raise ValueError("Invalid anime style check")
|
95 |
+
|
96 |
+
|
97 |
+
def convert_input_image(input_image: Union[np.ndarray, Image.Image]) -> np.ndarray:
|
98 |
+
"""Convert input image.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
input_image (Union[np.ndarray, Image.Image]): input image
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
np.ndarray: converted input image
|
105 |
+
"""
|
106 |
+
if isinstance(input_image, Image.Image):
|
107 |
+
input_image = np.array(input_image)
|
108 |
+
|
109 |
+
if input_image.ndim == 2:
|
110 |
+
input_image = input_image[:, :, np.newaxis]
|
111 |
+
|
112 |
+
if input_image.shape[2] == 1:
|
113 |
+
input_image = np.concatenate([input_image] * 3, axis=-1)
|
114 |
+
|
115 |
+
return input_image
|
116 |
+
|
117 |
+
|
118 |
+
def generate_sam_masks(
|
119 |
+
input_image: Union[np.ndarray, Image.Image],
|
120 |
+
sam_id: str,
|
121 |
+
anime_style_chk: bool = False,
|
122 |
+
) -> List[Dict[str, Any]]:
|
123 |
+
"""Generate SAM masks.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
input_image (Union[np.ndarray, Image.Image]): input image
|
127 |
+
sam_id (str): SAM ID
|
128 |
+
anime_style_chk (bool): anime style check
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
List[Dict[str, Any]]: SAM masks
|
132 |
+
"""
|
133 |
+
check_inputs_generate_sam_masks(input_image, sam_id, anime_style_chk)
|
134 |
+
input_image = convert_input_image(input_image)
|
135 |
+
|
136 |
+
sam_checkpoint = sam_file_path(sam_id)
|
137 |
+
sam_mask_generator = get_sam_mask_generator(sam_checkpoint, anime_style_chk)
|
138 |
+
ia_logging.info(f"{sam_mask_generator.__class__.__name__} {sam_id}")
|
139 |
+
|
140 |
+
if "sam2_" in sam_id:
|
141 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
142 |
+
torch_dtype = torch.bfloat16 if check_bfloat16_support() else torch.float16
|
143 |
+
with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype):
|
144 |
+
sam_masks = sam_mask_generator.generate(input_image)
|
145 |
+
else:
|
146 |
+
sam_masks = sam_mask_generator.generate(input_image)
|
147 |
+
|
148 |
+
if anime_style_chk:
|
149 |
+
for sam_mask in sam_masks:
|
150 |
+
sam_mask_seg = sam_mask["segmentation"]
|
151 |
+
sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
|
152 |
+
sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
|
153 |
+
sam_mask["segmentation"] = sam_mask_seg.astype(bool)
|
154 |
+
|
155 |
+
ia_logging.info("sam_masks: {}".format(len(sam_masks)))
|
156 |
+
|
157 |
+
sam_masks = copy.deepcopy(sam_masks)
|
158 |
+
return sam_masks
|
159 |
+
|
160 |
+
|
161 |
+
def sort_masks_by_area(
|
162 |
+
sam_masks: List[Dict[str, Any]],
|
163 |
+
) -> List[Dict[str, Any]]:
|
164 |
+
"""Sort mask by area.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
sam_masks (List[Dict[str, Any]]): SAM masks
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
List[Dict[str, Any]]: sorted SAM masks
|
171 |
+
"""
|
172 |
+
return sorted(sam_masks, key=lambda x: np.sum(x.get("segmentation").astype(np.uint32)))
|
173 |
+
|
174 |
+
|
175 |
+
def get_seg_colormap() -> np.ndarray:
|
176 |
+
"""Get segmentation colormap.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
np.ndarray: segmentation colormap
|
180 |
+
"""
|
181 |
+
cm_pascal = create_pascal_label_colormap()
|
182 |
+
seg_colormap = cm_pascal
|
183 |
+
seg_colormap = np.array([c for c in seg_colormap if max(c) >= 64], dtype=np.uint8)
|
184 |
+
|
185 |
+
return seg_colormap
|
186 |
+
|
187 |
+
|
188 |
+
def insert_mask_to_sam_masks(
|
189 |
+
sam_masks: List[Dict[str, Any]],
|
190 |
+
insert_mask: Dict[str, Any],
|
191 |
+
) -> List[Dict[str, Any]]:
|
192 |
+
"""Insert mask to SAM masks.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
sam_masks (List[Dict[str, Any]]): SAM masks
|
196 |
+
insert_mask (Dict[str, Any]): insert mask
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
List[Dict[str, Any]]: SAM masks
|
200 |
+
"""
|
201 |
+
if insert_mask is not None and isinstance(insert_mask, dict) and "segmentation" in insert_mask:
|
202 |
+
if (len(sam_masks) > 0 and
|
203 |
+
sam_masks[0]["segmentation"].shape == insert_mask["segmentation"].shape and
|
204 |
+
np.any(insert_mask["segmentation"])):
|
205 |
+
sam_masks.insert(0, insert_mask)
|
206 |
+
ia_logging.info("insert mask to sam_masks")
|
207 |
+
|
208 |
+
return sam_masks
|
209 |
+
|
210 |
+
|
211 |
+
def create_seg_color_image(
|
212 |
+
input_image: Union[np.ndarray, Image.Image],
|
213 |
+
sam_masks: List[Dict[str, Any]],
|
214 |
+
) -> np.ndarray:
|
215 |
+
"""Create segmentation color image.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
input_image (Union[np.ndarray, Image.Image]): input image
|
219 |
+
sam_masks (List[Dict[str, Any]]): SAM masks
|
220 |
+
|
221 |
+
Returns:
|
222 |
+
np.ndarray: segmentation color image
|
223 |
+
"""
|
224 |
+
input_image = convert_input_image(input_image)
|
225 |
+
|
226 |
+
seg_colormap = get_seg_colormap()
|
227 |
+
sam_masks = sam_masks[:len(seg_colormap)]
|
228 |
+
|
229 |
+
with tqdm(total=len(sam_masks), desc="Processing segments") as progress_bar:
|
230 |
+
canvas_image = np.zeros((*input_image.shape[:2], 1), dtype=np.uint8)
|
231 |
+
for idx, seg_dict in enumerate(sam_masks[0:min(255, len(sam_masks))]):
|
232 |
+
seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
|
233 |
+
canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
|
234 |
+
seg_color = np.array([idx+1], dtype=np.uint8) * seg_mask * canvas_mask
|
235 |
+
canvas_image = canvas_image + seg_color
|
236 |
+
progress_bar.update(1)
|
237 |
+
seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
|
238 |
+
temp_canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image)
|
239 |
+
if len(sam_masks) > 255:
|
240 |
+
canvas_image = canvas_image.astype(bool).astype(np.uint8)
|
241 |
+
for idx, seg_dict in enumerate(sam_masks[255:min(509, len(sam_masks))]):
|
242 |
+
seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
|
243 |
+
canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
|
244 |
+
seg_color = np.array([idx+2], dtype=np.uint8) * seg_mask * canvas_mask
|
245 |
+
canvas_image = canvas_image + seg_color
|
246 |
+
progress_bar.update(1)
|
247 |
+
seg_colormap = seg_colormap[256:]
|
248 |
+
seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
|
249 |
+
seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
|
250 |
+
canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image)
|
251 |
+
canvas_image = temp_canvas_image + canvas_image
|
252 |
+
else:
|
253 |
+
canvas_image = temp_canvas_image
|
254 |
+
ret_seg_image = canvas_image.astype(np.uint8)
|
255 |
+
|
256 |
+
return ret_seg_image
|
javascript/inpaint-anything.js
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
const inpaintAnything_waitForElement = async (parent, selector, exist) => {
|
2 |
+
return new Promise((resolve) => {
|
3 |
+
const observer = new MutationObserver(() => {
|
4 |
+
if (!!parent.querySelector(selector) != exist) {
|
5 |
+
return;
|
6 |
+
}
|
7 |
+
observer.disconnect();
|
8 |
+
resolve(undefined);
|
9 |
+
});
|
10 |
+
|
11 |
+
observer.observe(parent, {
|
12 |
+
childList: true,
|
13 |
+
subtree: true,
|
14 |
+
});
|
15 |
+
|
16 |
+
if (!!parent.querySelector(selector) == exist) {
|
17 |
+
resolve(undefined);
|
18 |
+
}
|
19 |
+
});
|
20 |
+
};
|
21 |
+
|
22 |
+
const inpaintAnything_waitForStyle = async (parent, selector, style) => {
|
23 |
+
return new Promise((resolve) => {
|
24 |
+
const observer = new MutationObserver(() => {
|
25 |
+
if (!parent.querySelector(selector) || !parent.querySelector(selector).style[style]) {
|
26 |
+
return;
|
27 |
+
}
|
28 |
+
observer.disconnect();
|
29 |
+
resolve(undefined);
|
30 |
+
});
|
31 |
+
|
32 |
+
observer.observe(parent, {
|
33 |
+
childList: true,
|
34 |
+
subtree: true,
|
35 |
+
attributes: true,
|
36 |
+
attributeFilter: ["style"],
|
37 |
+
});
|
38 |
+
|
39 |
+
if (!!parent.querySelector(selector) && !!parent.querySelector(selector).style[style]) {
|
40 |
+
resolve(undefined);
|
41 |
+
}
|
42 |
+
});
|
43 |
+
};
|
44 |
+
|
45 |
+
const inpaintAnything_timeout = (ms) => {
|
46 |
+
return new Promise(function (resolve, reject) {
|
47 |
+
setTimeout(() => reject("Timeout"), ms);
|
48 |
+
});
|
49 |
+
};
|
50 |
+
|
51 |
+
async function inpaintAnything_clearSamMask() {
|
52 |
+
const waitForElementToBeInDocument = (parent, selector) =>
|
53 |
+
Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]);
|
54 |
+
|
55 |
+
const elemId = "#ia_sam_image";
|
56 |
+
|
57 |
+
const targetElement = document.querySelector(elemId);
|
58 |
+
if (!targetElement) {
|
59 |
+
return;
|
60 |
+
}
|
61 |
+
await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']");
|
62 |
+
|
63 |
+
targetElement.style.transform = null;
|
64 |
+
targetElement.style.zIndex = null;
|
65 |
+
targetElement.style.overflow = "auto";
|
66 |
+
|
67 |
+
const samMaskClear = targetElement.querySelector("button[aria-label='Clear']");
|
68 |
+
if (!samMaskClear) {
|
69 |
+
return;
|
70 |
+
}
|
71 |
+
const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']");
|
72 |
+
if (!removeImageButton) {
|
73 |
+
return;
|
74 |
+
}
|
75 |
+
samMaskClear?.click();
|
76 |
+
|
77 |
+
if (typeof inpaintAnything_clearSamMask.clickRemoveImage === "undefined") {
|
78 |
+
inpaintAnything_clearSamMask.clickRemoveImage = () => {
|
79 |
+
targetElement.style.transform = null;
|
80 |
+
targetElement.style.zIndex = null;
|
81 |
+
};
|
82 |
+
} else {
|
83 |
+
removeImageButton.removeEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage);
|
84 |
+
}
|
85 |
+
removeImageButton.addEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage);
|
86 |
+
}
|
87 |
+
|
88 |
+
async function inpaintAnything_clearSelMask() {
|
89 |
+
const waitForElementToBeInDocument = (parent, selector) =>
|
90 |
+
Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]);
|
91 |
+
|
92 |
+
const elemId = "#ia_sel_mask";
|
93 |
+
|
94 |
+
const targetElement = document.querySelector(elemId);
|
95 |
+
if (!targetElement) {
|
96 |
+
return;
|
97 |
+
}
|
98 |
+
await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']");
|
99 |
+
|
100 |
+
targetElement.style.transform = null;
|
101 |
+
targetElement.style.zIndex = null;
|
102 |
+
targetElement.style.overflow = "auto";
|
103 |
+
|
104 |
+
const selMaskClear = targetElement.querySelector("button[aria-label='Clear']");
|
105 |
+
if (!selMaskClear) {
|
106 |
+
return;
|
107 |
+
}
|
108 |
+
const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']");
|
109 |
+
if (!removeImageButton) {
|
110 |
+
return;
|
111 |
+
}
|
112 |
+
selMaskClear?.click();
|
113 |
+
|
114 |
+
if (typeof inpaintAnything_clearSelMask.clickRemoveImage === "undefined") {
|
115 |
+
inpaintAnything_clearSelMask.clickRemoveImage = () => {
|
116 |
+
targetElement.style.transform = null;
|
117 |
+
targetElement.style.zIndex = null;
|
118 |
+
};
|
119 |
+
} else {
|
120 |
+
removeImageButton.removeEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage);
|
121 |
+
}
|
122 |
+
removeImageButton.addEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage);
|
123 |
+
}
|
124 |
+
|
125 |
+
async function inpaintAnything_initSamSelMask() {
|
126 |
+
inpaintAnything_clearSamMask();
|
127 |
+
inpaintAnything_clearSelMask();
|
128 |
+
}
|
129 |
+
|
130 |
+
var uiLoadedCallbacks = [];
|
131 |
+
|
132 |
+
function gradioApp() {
|
133 |
+
const elems = document.getElementsByTagName("gradio-app");
|
134 |
+
const elem = elems.length == 0 ? document : elems[0];
|
135 |
+
|
136 |
+
if (elem !== document) {
|
137 |
+
elem.getElementById = function (id) {
|
138 |
+
return document.getElementById(id);
|
139 |
+
};
|
140 |
+
}
|
141 |
+
return elem.shadowRoot ? elem.shadowRoot : elem;
|
142 |
+
}
|
143 |
+
|
144 |
+
function onUiLoaded(callback) {
|
145 |
+
uiLoadedCallbacks.push(callback);
|
146 |
+
}
|
147 |
+
|
148 |
+
function executeCallbacks(queue) {
|
149 |
+
for (const callback of queue) {
|
150 |
+
try {
|
151 |
+
callback();
|
152 |
+
} catch (e) {
|
153 |
+
console.error("error running callback", callback, ":", e);
|
154 |
+
}
|
155 |
+
}
|
156 |
+
}
|
157 |
+
|
158 |
+
onUiLoaded(async () => {
|
159 |
+
const elementIDs = {
|
160 |
+
ia_sam_image: "#ia_sam_image",
|
161 |
+
ia_sel_mask: "#ia_sel_mask",
|
162 |
+
ia_out_image: "#ia_out_image",
|
163 |
+
ia_cleaner_out_image: "#ia_cleaner_out_image",
|
164 |
+
};
|
165 |
+
|
166 |
+
function setStyleHeight(elemId, height) {
|
167 |
+
const elem = gradioApp().querySelector(elemId);
|
168 |
+
if (elem) {
|
169 |
+
if (!elem.style.height) {
|
170 |
+
elem.style.height = height;
|
171 |
+
const observer = new MutationObserver(() => {
|
172 |
+
const divPreview = elem.querySelector(".preview");
|
173 |
+
if (divPreview) {
|
174 |
+
divPreview.classList.remove("fixed-height");
|
175 |
+
}
|
176 |
+
});
|
177 |
+
observer.observe(elem, {
|
178 |
+
childList: true,
|
179 |
+
attributes: true,
|
180 |
+
attributeFilter: ["class"],
|
181 |
+
});
|
182 |
+
}
|
183 |
+
}
|
184 |
+
}
|
185 |
+
|
186 |
+
setStyleHeight(elementIDs.ia_out_image, "520px");
|
187 |
+
setStyleHeight(elementIDs.ia_cleaner_out_image, "520px");
|
188 |
+
|
189 |
+
// Default config
|
190 |
+
const defaultHotkeysConfig = {
|
191 |
+
canvas_hotkey_reset: "KeyR",
|
192 |
+
canvas_hotkey_fullscreen: "KeyS",
|
193 |
+
};
|
194 |
+
|
195 |
+
const elemData = {};
|
196 |
+
let activeElement;
|
197 |
+
|
198 |
+
function applyZoomAndPan(elemId) {
|
199 |
+
const targetElement = gradioApp().querySelector(elemId);
|
200 |
+
|
201 |
+
if (!targetElement) {
|
202 |
+
console.log("Element not found");
|
203 |
+
return;
|
204 |
+
}
|
205 |
+
|
206 |
+
targetElement.style.transformOrigin = "0 0";
|
207 |
+
|
208 |
+
elemData[elemId] = {
|
209 |
+
zoomLevel: 1,
|
210 |
+
panX: 0,
|
211 |
+
panY: 0,
|
212 |
+
};
|
213 |
+
let fullScreenMode = false;
|
214 |
+
|
215 |
+
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
216 |
+
function toggleOverlap(forced = "") {
|
217 |
+
// const zIndex1 = "0";
|
218 |
+
const zIndex1 = null;
|
219 |
+
const zIndex2 = "998";
|
220 |
+
|
221 |
+
targetElement.style.zIndex = targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
|
222 |
+
|
223 |
+
if (forced === "off") {
|
224 |
+
targetElement.style.zIndex = zIndex1;
|
225 |
+
} else if (forced === "on") {
|
226 |
+
targetElement.style.zIndex = zIndex2;
|
227 |
+
}
|
228 |
+
}
|
229 |
+
|
230 |
+
/**
|
231 |
+
* This function fits the target element to the screen by calculating
|
232 |
+
* the required scale and offsets. It also updates the global variables
|
233 |
+
* zoomLevel, panX, and panY to reflect the new state.
|
234 |
+
*/
|
235 |
+
|
236 |
+
function fitToElement() {
|
237 |
+
//Reset Zoom
|
238 |
+
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
239 |
+
|
240 |
+
// Get element and screen dimensions
|
241 |
+
const elementWidth = targetElement.offsetWidth;
|
242 |
+
const elementHeight = targetElement.offsetHeight;
|
243 |
+
const parentElement = targetElement.parentElement;
|
244 |
+
const screenWidth = parentElement.clientWidth;
|
245 |
+
const screenHeight = parentElement.clientHeight;
|
246 |
+
|
247 |
+
// Get element's coordinates relative to the parent element
|
248 |
+
const elementRect = targetElement.getBoundingClientRect();
|
249 |
+
const parentRect = parentElement.getBoundingClientRect();
|
250 |
+
const elementX = elementRect.x - parentRect.x;
|
251 |
+
|
252 |
+
// Calculate scale and offsets
|
253 |
+
const scaleX = screenWidth / elementWidth;
|
254 |
+
const scaleY = screenHeight / elementHeight;
|
255 |
+
const scale = Math.min(scaleX, scaleY);
|
256 |
+
|
257 |
+
const transformOrigin = window.getComputedStyle(targetElement).transformOrigin;
|
258 |
+
const [originX, originY] = transformOrigin.split(" ");
|
259 |
+
const originXValue = parseFloat(originX);
|
260 |
+
const originYValue = parseFloat(originY);
|
261 |
+
|
262 |
+
const offsetX = (screenWidth - elementWidth * scale) / 2 - originXValue * (1 - scale);
|
263 |
+
const offsetY = (screenHeight - elementHeight * scale) / 2.5 - originYValue * (1 - scale);
|
264 |
+
|
265 |
+
// Apply scale and offsets to the element
|
266 |
+
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
267 |
+
|
268 |
+
// Update global variables
|
269 |
+
elemData[elemId].zoomLevel = scale;
|
270 |
+
elemData[elemId].panX = offsetX;
|
271 |
+
elemData[elemId].panY = offsetY;
|
272 |
+
|
273 |
+
fullScreenMode = false;
|
274 |
+
toggleOverlap("off");
|
275 |
+
}
|
276 |
+
|
277 |
+
// Reset the zoom level and pan position of the target element to their initial values
|
278 |
+
function resetZoom() {
|
279 |
+
elemData[elemId] = {
|
280 |
+
zoomLevel: 1,
|
281 |
+
panX: 0,
|
282 |
+
panY: 0,
|
283 |
+
};
|
284 |
+
|
285 |
+
// fixCanvas();
|
286 |
+
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
287 |
+
|
288 |
+
// const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`);
|
289 |
+
|
290 |
+
toggleOverlap("off");
|
291 |
+
fullScreenMode = false;
|
292 |
+
|
293 |
+
// if (
|
294 |
+
// canvas &&
|
295 |
+
// parseFloat(canvas.style.width) > 865 &&
|
296 |
+
// parseFloat(targetElement.style.width) > 865
|
297 |
+
// ) {
|
298 |
+
// fitToElement();
|
299 |
+
// return;
|
300 |
+
// }
|
301 |
+
|
302 |
+
// targetElement.style.width = "";
|
303 |
+
// if (canvas) {
|
304 |
+
// targetElement.style.height = canvas.style.height;
|
305 |
+
// }
|
306 |
+
targetElement.style.width = null;
|
307 |
+
targetElement.style.height = 480;
|
308 |
+
}
|
309 |
+
|
310 |
+
/**
|
311 |
+
* This function fits the target element to the screen by calculating
|
312 |
+
* the required scale and offsets. It also updates the global variables
|
313 |
+
* zoomLevel, panX, and panY to reflect the new state.
|
314 |
+
*/
|
315 |
+
|
316 |
+
// Fullscreen mode
|
317 |
+
function fitToScreen() {
|
318 |
+
const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`);
|
319 |
+
const img = gradioApp().querySelector(`${elemId} img`);
|
320 |
+
|
321 |
+
if (!canvas && !img) return;
|
322 |
+
|
323 |
+
// if (canvas.offsetWidth > 862) {
|
324 |
+
// targetElement.style.width = canvas.offsetWidth + "px";
|
325 |
+
// }
|
326 |
+
|
327 |
+
if (fullScreenMode) {
|
328 |
+
resetZoom();
|
329 |
+
fullScreenMode = false;
|
330 |
+
return;
|
331 |
+
}
|
332 |
+
|
333 |
+
//Reset Zoom
|
334 |
+
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
335 |
+
|
336 |
+
// Get scrollbar width to right-align the image
|
337 |
+
const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth;
|
338 |
+
|
339 |
+
// Get element and screen dimensions
|
340 |
+
const elementWidth = targetElement.offsetWidth;
|
341 |
+
const elementHeight = targetElement.offsetHeight;
|
342 |
+
const screenWidth = window.innerWidth - scrollbarWidth;
|
343 |
+
const screenHeight = window.innerHeight;
|
344 |
+
|
345 |
+
// Get element's coordinates relative to the page
|
346 |
+
const elementRect = targetElement.getBoundingClientRect();
|
347 |
+
const elementY = elementRect.y;
|
348 |
+
const elementX = elementRect.x;
|
349 |
+
|
350 |
+
// Calculate scale and offsets
|
351 |
+
const scaleX = screenWidth / elementWidth;
|
352 |
+
const scaleY = screenHeight / elementHeight;
|
353 |
+
const scale = Math.min(scaleX, scaleY);
|
354 |
+
|
355 |
+
// Get the current transformOrigin
|
356 |
+
const computedStyle = window.getComputedStyle(targetElement);
|
357 |
+
const transformOrigin = computedStyle.transformOrigin;
|
358 |
+
const [originX, originY] = transformOrigin.split(" ");
|
359 |
+
const originXValue = parseFloat(originX);
|
360 |
+
const originYValue = parseFloat(originY);
|
361 |
+
|
362 |
+
// Calculate offsets with respect to the transformOrigin
|
363 |
+
const offsetX = (screenWidth - elementWidth * scale) / 2 - elementX - originXValue * (1 - scale);
|
364 |
+
const offsetY = (screenHeight - elementHeight * scale) / 2 - elementY - originYValue * (1 - scale);
|
365 |
+
|
366 |
+
// Apply scale and offsets to the element
|
367 |
+
targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
|
368 |
+
|
369 |
+
// Update global variables
|
370 |
+
elemData[elemId].zoomLevel = scale;
|
371 |
+
elemData[elemId].panX = offsetX;
|
372 |
+
elemData[elemId].panY = offsetY;
|
373 |
+
|
374 |
+
fullScreenMode = true;
|
375 |
+
toggleOverlap("on");
|
376 |
+
}
|
377 |
+
|
378 |
+
// Reset zoom when uploading a new image
|
379 |
+
const fileInput = gradioApp().querySelector(`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`);
|
380 |
+
if (fileInput) {
|
381 |
+
fileInput.addEventListener("click", resetZoom);
|
382 |
+
}
|
383 |
+
|
384 |
+
// Handle keydown events
|
385 |
+
function handleKeyDown(event) {
|
386 |
+
// Disable key locks to make pasting from the buffer work correctly
|
387 |
+
if (
|
388 |
+
(event.ctrlKey && event.code === "KeyV") ||
|
389 |
+
(event.ctrlKey && event.code === "KeyC") ||
|
390 |
+
event.code === "F5"
|
391 |
+
) {
|
392 |
+
return;
|
393 |
+
}
|
394 |
+
|
395 |
+
// before activating shortcut, ensure user is not actively typing in an input field
|
396 |
+
if (event.target.nodeName === "TEXTAREA" || event.target.nodeName === "INPUT") {
|
397 |
+
return;
|
398 |
+
}
|
399 |
+
|
400 |
+
const hotkeyActions = {
|
401 |
+
[defaultHotkeysConfig.canvas_hotkey_reset]: resetZoom,
|
402 |
+
[defaultHotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
|
403 |
+
};
|
404 |
+
|
405 |
+
const action = hotkeyActions[event.code];
|
406 |
+
if (action) {
|
407 |
+
event.preventDefault();
|
408 |
+
action(event);
|
409 |
+
}
|
410 |
+
}
|
411 |
+
|
412 |
+
// Handle events only inside the targetElement
|
413 |
+
let isKeyDownHandlerAttached = false;
|
414 |
+
|
415 |
+
function handleMouseMove() {
|
416 |
+
if (!isKeyDownHandlerAttached) {
|
417 |
+
document.addEventListener("keydown", handleKeyDown);
|
418 |
+
isKeyDownHandlerAttached = true;
|
419 |
+
|
420 |
+
activeElement = elemId;
|
421 |
+
}
|
422 |
+
}
|
423 |
+
|
424 |
+
function handleMouseLeave() {
|
425 |
+
if (isKeyDownHandlerAttached) {
|
426 |
+
document.removeEventListener("keydown", handleKeyDown);
|
427 |
+
isKeyDownHandlerAttached = false;
|
428 |
+
|
429 |
+
activeElement = null;
|
430 |
+
}
|
431 |
+
}
|
432 |
+
|
433 |
+
// Add mouse event handlers
|
434 |
+
targetElement.addEventListener("mousemove", handleMouseMove);
|
435 |
+
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
436 |
+
}
|
437 |
+
|
438 |
+
applyZoomAndPan(elementIDs.ia_sam_image);
|
439 |
+
applyZoomAndPan(elementIDs.ia_sel_mask);
|
440 |
+
// applyZoomAndPan(elementIDs.ia_out_image);
|
441 |
+
// applyZoomAndPan(elementIDs.ia_cleaner_out_image);
|
442 |
+
});
|
443 |
+
|
444 |
+
var executedOnLoaded = false;
|
445 |
+
|
446 |
+
document.addEventListener("DOMContentLoaded", function () {
|
447 |
+
var mutationObserver = new MutationObserver(function () {
|
448 |
+
if (
|
449 |
+
!executedOnLoaded &&
|
450 |
+
gradioApp().querySelector("#ia_sam_image") &&
|
451 |
+
gradioApp().querySelector("#ia_sel_mask")
|
452 |
+
) {
|
453 |
+
executedOnLoaded = true;
|
454 |
+
executeCallbacks(uiLoadedCallbacks);
|
455 |
+
}
|
456 |
+
});
|
457 |
+
mutationObserver.observe(gradioApp(), { childList: true, subtree: true });
|
458 |
+
});
|
lama_cleaner/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
4 |
+
|
5 |
+
import warnings # noqa: E402
|
6 |
+
|
7 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
8 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner")
|
9 |
+
|
10 |
+
from lama_cleaner.parse_args import parse_args # noqa: E402
|
11 |
+
|
12 |
+
|
13 |
+
def entry_point():
|
14 |
+
args = parse_args()
|
15 |
+
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
|
16 |
+
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
|
17 |
+
from lama_cleaner.server import main
|
18 |
+
|
19 |
+
main(args)
|
lama_cleaner/benchmark.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import nvidia_smi
|
9 |
+
import psutil
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from lama_cleaner.model_manager import ModelManager
|
13 |
+
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
14 |
+
|
15 |
+
try:
|
16 |
+
torch._C._jit_override_can_fuse_on_cpu(False)
|
17 |
+
torch._C._jit_override_can_fuse_on_gpu(False)
|
18 |
+
torch._C._jit_set_texpr_fuser_enabled(False)
|
19 |
+
torch._C._jit_set_nvfuser_enabled(False)
|
20 |
+
except:
|
21 |
+
pass
|
22 |
+
|
23 |
+
NUM_THREADS = str(4)
|
24 |
+
|
25 |
+
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
26 |
+
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
27 |
+
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
28 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
29 |
+
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
30 |
+
if os.environ.get("CACHE_DIR"):
|
31 |
+
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
32 |
+
|
33 |
+
|
34 |
+
def run_model(model, size):
|
35 |
+
# RGB
|
36 |
+
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
37 |
+
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
38 |
+
|
39 |
+
config = Config(
|
40 |
+
ldm_steps=2,
|
41 |
+
hd_strategy=HDStrategy.ORIGINAL,
|
42 |
+
hd_strategy_crop_margin=128,
|
43 |
+
hd_strategy_crop_trigger_size=128,
|
44 |
+
hd_strategy_resize_limit=128,
|
45 |
+
prompt="a fox is sitting on a bench",
|
46 |
+
sd_steps=5,
|
47 |
+
sd_sampler=SDSampler.ddim
|
48 |
+
)
|
49 |
+
model(image, mask, config)
|
50 |
+
|
51 |
+
|
52 |
+
def benchmark(model, times: int, empty_cache: bool):
|
53 |
+
sizes = [(512, 512)]
|
54 |
+
|
55 |
+
nvidia_smi.nvmlInit()
|
56 |
+
device_id = 0
|
57 |
+
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
|
58 |
+
|
59 |
+
def format(metrics):
|
60 |
+
return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
|
61 |
+
|
62 |
+
process = psutil.Process(os.getpid())
|
63 |
+
# 每个 size 给出显存和内存占用的指标
|
64 |
+
for size in sizes:
|
65 |
+
torch.cuda.empty_cache()
|
66 |
+
time_metrics = []
|
67 |
+
cpu_metrics = []
|
68 |
+
memory_metrics = []
|
69 |
+
gpu_memory_metrics = []
|
70 |
+
for _ in range(times):
|
71 |
+
start = time.time()
|
72 |
+
run_model(model, size)
|
73 |
+
torch.cuda.synchronize()
|
74 |
+
|
75 |
+
# cpu_metrics.append(process.cpu_percent())
|
76 |
+
time_metrics.append((time.time() - start) * 1000)
|
77 |
+
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
78 |
+
gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
|
79 |
+
|
80 |
+
print(f"size: {size}".center(80, "-"))
|
81 |
+
# print(f"cpu: {format(cpu_metrics)}")
|
82 |
+
print(f"latency: {format(time_metrics)}ms")
|
83 |
+
print(f"memory: {format(memory_metrics)} MB")
|
84 |
+
print(f"gpu memory: {format(gpu_memory_metrics)} MB")
|
85 |
+
|
86 |
+
nvidia_smi.nvmlShutdown()
|
87 |
+
|
88 |
+
|
89 |
+
def get_args_parser():
|
90 |
+
parser = argparse.ArgumentParser()
|
91 |
+
parser.add_argument("--name")
|
92 |
+
parser.add_argument("--device", default="cuda", type=str)
|
93 |
+
parser.add_argument("--times", default=10, type=int)
|
94 |
+
parser.add_argument("--empty-cache", action="store_true")
|
95 |
+
return parser.parse_args()
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
args = get_args_parser()
|
100 |
+
device = torch.device(args.device)
|
101 |
+
model = ModelManager(
|
102 |
+
name=args.name,
|
103 |
+
device=device,
|
104 |
+
sd_run_local=True,
|
105 |
+
disable_nsfw=True,
|
106 |
+
sd_cpu_textencoder=True,
|
107 |
+
hf_access_token="123"
|
108 |
+
)
|
109 |
+
benchmark(model, args.times, args.empty_cache)
|
lama_cleaner/const.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from enum import Enum
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
|
7 |
+
MPS_SUPPORT_MODELS = [
|
8 |
+
"instruct_pix2pix",
|
9 |
+
"sd1.5",
|
10 |
+
"anything4",
|
11 |
+
"realisticVision1.4",
|
12 |
+
"sd2",
|
13 |
+
"paint_by_example",
|
14 |
+
"controlnet",
|
15 |
+
]
|
16 |
+
|
17 |
+
DEFAULT_MODEL = "lama"
|
18 |
+
AVAILABLE_MODELS = [
|
19 |
+
"lama",
|
20 |
+
"ldm",
|
21 |
+
"zits",
|
22 |
+
"mat",
|
23 |
+
"fcf",
|
24 |
+
"sd1.5",
|
25 |
+
"anything4",
|
26 |
+
"realisticVision1.4",
|
27 |
+
"cv2",
|
28 |
+
"manga",
|
29 |
+
"sd2",
|
30 |
+
"paint_by_example",
|
31 |
+
"instruct_pix2pix",
|
32 |
+
]
|
33 |
+
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
34 |
+
|
35 |
+
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
36 |
+
DEFAULT_DEVICE = "cuda"
|
37 |
+
|
38 |
+
NO_HALF_HELP = """
|
39 |
+
Using full precision model.
|
40 |
+
If your generate result is always black or green, use this argument. (sd/paint_by_exmaple)
|
41 |
+
"""
|
42 |
+
|
43 |
+
CPU_OFFLOAD_HELP = """
|
44 |
+
Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
|
45 |
+
"""
|
46 |
+
|
47 |
+
DISABLE_NSFW_HELP = """
|
48 |
+
Disable NSFW checker. (sd/paint_by_example)
|
49 |
+
"""
|
50 |
+
|
51 |
+
SD_CPU_TEXTENCODER_HELP = """
|
52 |
+
Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
53 |
+
"""
|
54 |
+
|
55 |
+
SD_CONTROLNET_HELP = """
|
56 |
+
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
|
57 |
+
"""
|
58 |
+
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
|
59 |
+
SD_CONTROLNET_CHOICES = [
|
60 |
+
"control_v11p_sd15_canny",
|
61 |
+
"control_v11p_sd15_openpose",
|
62 |
+
"control_v11p_sd15_inpaint",
|
63 |
+
"control_v11f1p_sd15_depth"
|
64 |
+
]
|
65 |
+
|
66 |
+
SD_LOCAL_MODEL_HELP = """
|
67 |
+
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
|
68 |
+
"""
|
69 |
+
|
70 |
+
LOCAL_FILES_ONLY_HELP = """
|
71 |
+
Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
|
72 |
+
"""
|
73 |
+
|
74 |
+
ENABLE_XFORMERS_HELP = """
|
75 |
+
Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
|
76 |
+
"""
|
77 |
+
|
78 |
+
DEFAULT_MODEL_DIR = os.getenv(
|
79 |
+
"XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
|
80 |
+
)
|
81 |
+
MODEL_DIR_HELP = """
|
82 |
+
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
|
83 |
+
"""
|
84 |
+
|
85 |
+
OUTPUT_DIR_HELP = """
|
86 |
+
Result images will be saved to output directory automatically without confirmation.
|
87 |
+
"""
|
88 |
+
|
89 |
+
INPUT_HELP = """
|
90 |
+
If input is image, it will be loaded by default.
|
91 |
+
If input is directory, you can browse and select image in file manager.
|
92 |
+
"""
|
93 |
+
|
94 |
+
GUI_HELP = """
|
95 |
+
Launch Lama Cleaner as desktop app
|
96 |
+
"""
|
97 |
+
|
98 |
+
NO_GUI_AUTO_CLOSE_HELP = """
|
99 |
+
Prevent backend auto close after the GUI window closed.
|
100 |
+
"""
|
101 |
+
|
102 |
+
QUALITY_HELP = """
|
103 |
+
Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
|
104 |
+
"""
|
105 |
+
|
106 |
+
|
107 |
+
class RealESRGANModelName(str, Enum):
|
108 |
+
realesr_general_x4v3 = "realesr-general-x4v3"
|
109 |
+
RealESRGAN_x4plus = "RealESRGAN_x4plus"
|
110 |
+
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
|
111 |
+
|
112 |
+
|
113 |
+
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
114 |
+
|
115 |
+
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
|
116 |
+
INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
|
117 |
+
AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h"]
|
118 |
+
AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
|
119 |
+
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
|
120 |
+
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
|
121 |
+
REALESRGAN_HELP = "Enable realesrgan super resolution"
|
122 |
+
REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
123 |
+
GFPGAN_HELP = (
|
124 |
+
"Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
|
125 |
+
)
|
126 |
+
GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
127 |
+
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
|
128 |
+
RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
129 |
+
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
130 |
+
|
131 |
+
|
132 |
+
class Config(BaseModel):
|
133 |
+
host: str = "127.0.0.1"
|
134 |
+
port: int = 8080
|
135 |
+
model: str = DEFAULT_MODEL
|
136 |
+
sd_local_model_path: str = None
|
137 |
+
sd_controlnet: bool = False
|
138 |
+
sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
|
139 |
+
device: str = DEFAULT_DEVICE
|
140 |
+
gui: bool = False
|
141 |
+
no_gui_auto_close: bool = False
|
142 |
+
no_half: bool = False
|
143 |
+
cpu_offload: bool = False
|
144 |
+
disable_nsfw: bool = False
|
145 |
+
sd_cpu_textencoder: bool = False
|
146 |
+
enable_xformers: bool = False
|
147 |
+
local_files_only: bool = False
|
148 |
+
model_dir: str = DEFAULT_MODEL_DIR
|
149 |
+
input: str = None
|
150 |
+
output_dir: str = None
|
151 |
+
# plugins
|
152 |
+
enable_interactive_seg: bool = False
|
153 |
+
interactive_seg_model: str = "vit_l"
|
154 |
+
interactive_seg_device: str = "cpu"
|
155 |
+
enable_remove_bg: bool = False
|
156 |
+
enable_anime_seg: bool = False
|
157 |
+
enable_realesrgan: bool = False
|
158 |
+
realesrgan_device: str = "cpu"
|
159 |
+
realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
|
160 |
+
realesrgan_no_half: bool = False
|
161 |
+
enable_gfpgan: bool = False
|
162 |
+
gfpgan_device: str = "cpu"
|
163 |
+
enable_restoreformer: bool = False
|
164 |
+
restoreformer_device: str = "cpu"
|
165 |
+
enable_gif: bool = False
|
166 |
+
|
167 |
+
|
168 |
+
def load_config(installer_config: str):
|
169 |
+
if os.path.exists(installer_config):
|
170 |
+
with open(installer_config, "r", encoding="utf-8") as f:
|
171 |
+
return Config(**json.load(f))
|
172 |
+
else:
|
173 |
+
return Config()
|
lama_cleaner/file_manager/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .file_manager import FileManager
|
lama_cleaner/file_manager/file_manager.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
2 |
+
import os
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import time
|
7 |
+
from io import BytesIO
|
8 |
+
from pathlib import Path
|
9 |
+
import numpy as np
|
10 |
+
# from watchdog.events import FileSystemEventHandler
|
11 |
+
# from watchdog.observers import Observer
|
12 |
+
|
13 |
+
from PIL import Image, ImageOps, PngImagePlugin
|
14 |
+
from loguru import logger
|
15 |
+
|
16 |
+
LARGE_ENOUGH_NUMBER = 100
|
17 |
+
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
18 |
+
from .storage_backends import FilesystemStorageBackend
|
19 |
+
from .utils import aspect_to_string, generate_filename, glob_img
|
20 |
+
|
21 |
+
|
22 |
+
class FileManager:
|
23 |
+
def __init__(self, app=None):
|
24 |
+
self.app = app
|
25 |
+
self._default_root_directory = "media"
|
26 |
+
self._default_thumbnail_directory = "media"
|
27 |
+
self._default_root_url = "/"
|
28 |
+
self._default_thumbnail_root_url = "/"
|
29 |
+
self._default_format = "JPEG"
|
30 |
+
self.output_dir: Path = None
|
31 |
+
|
32 |
+
if app is not None:
|
33 |
+
self.init_app(app)
|
34 |
+
|
35 |
+
self.image_dir_filenames = []
|
36 |
+
self.output_dir_filenames = []
|
37 |
+
|
38 |
+
self.image_dir_observer = None
|
39 |
+
self.output_dir_observer = None
|
40 |
+
|
41 |
+
self.modified_time = {
|
42 |
+
"image": datetime.utcnow(),
|
43 |
+
"output": datetime.utcnow(),
|
44 |
+
}
|
45 |
+
|
46 |
+
# def start(self):
|
47 |
+
# self.image_dir_filenames = self._media_names(self.root_directory)
|
48 |
+
# self.output_dir_filenames = self._media_names(self.output_dir)
|
49 |
+
#
|
50 |
+
# logger.info(f"Start watching image directory: {self.root_directory}")
|
51 |
+
# self.image_dir_observer = Observer()
|
52 |
+
# self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
53 |
+
# self.image_dir_observer.start()
|
54 |
+
#
|
55 |
+
# logger.info(f"Start watching output directory: {self.output_dir}")
|
56 |
+
# self.output_dir_observer = Observer()
|
57 |
+
# self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
58 |
+
# self.output_dir_observer.start()
|
59 |
+
|
60 |
+
def on_modified(self, event):
|
61 |
+
if not os.path.isdir(event.src_path):
|
62 |
+
return
|
63 |
+
if event.src_path == str(self.root_directory):
|
64 |
+
logger.info(f"Image directory {event.src_path} modified")
|
65 |
+
self.image_dir_filenames = self._media_names(self.root_directory)
|
66 |
+
self.modified_time["image"] = datetime.utcnow()
|
67 |
+
elif event.src_path == str(self.output_dir):
|
68 |
+
logger.info(f"Output directory {event.src_path} modified")
|
69 |
+
self.output_dir_filenames = self._media_names(self.output_dir)
|
70 |
+
self.modified_time["output"] = datetime.utcnow()
|
71 |
+
|
72 |
+
def init_app(self, app):
|
73 |
+
if self.app is None:
|
74 |
+
self.app = app
|
75 |
+
app.thumbnail_instance = self
|
76 |
+
|
77 |
+
if not hasattr(app, "extensions"):
|
78 |
+
app.extensions = {}
|
79 |
+
|
80 |
+
if "thumbnail" in app.extensions:
|
81 |
+
raise RuntimeError("Flask-thumbnail extension already initialized")
|
82 |
+
|
83 |
+
app.extensions["thumbnail"] = self
|
84 |
+
|
85 |
+
app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
|
86 |
+
app.config.setdefault(
|
87 |
+
"THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory
|
88 |
+
)
|
89 |
+
app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
|
90 |
+
app.config.setdefault(
|
91 |
+
"THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url
|
92 |
+
)
|
93 |
+
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
94 |
+
|
95 |
+
@property
|
96 |
+
def root_directory(self):
|
97 |
+
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
98 |
+
|
99 |
+
if os.path.isabs(path):
|
100 |
+
return path
|
101 |
+
else:
|
102 |
+
return os.path.join(self.app.root_path, path)
|
103 |
+
|
104 |
+
@property
|
105 |
+
def thumbnail_directory(self):
|
106 |
+
path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]
|
107 |
+
|
108 |
+
if os.path.isabs(path):
|
109 |
+
return path
|
110 |
+
else:
|
111 |
+
return os.path.join(self.app.root_path, path)
|
112 |
+
|
113 |
+
@property
|
114 |
+
def root_url(self):
|
115 |
+
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
116 |
+
|
117 |
+
@property
|
118 |
+
def media_names(self):
|
119 |
+
# return self.image_dir_filenames
|
120 |
+
return self._media_names(self.root_directory)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def output_media_names(self):
|
124 |
+
return self._media_names(self.output_dir)
|
125 |
+
# return self.output_dir_filenames
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def _media_names(directory: Path):
|
129 |
+
names = sorted([it.name for it in glob_img(directory)])
|
130 |
+
res = []
|
131 |
+
for name in names:
|
132 |
+
path = os.path.join(directory, name)
|
133 |
+
img = Image.open(path)
|
134 |
+
res.append(
|
135 |
+
{
|
136 |
+
"name": name,
|
137 |
+
"height": img.height,
|
138 |
+
"width": img.width,
|
139 |
+
"ctime": os.path.getctime(path),
|
140 |
+
"mtime": os.path.getmtime(path),
|
141 |
+
}
|
142 |
+
)
|
143 |
+
return res
|
144 |
+
|
145 |
+
@property
|
146 |
+
def thumbnail_url(self):
|
147 |
+
return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]
|
148 |
+
|
149 |
+
def get_thumbnail(
|
150 |
+
self, directory: Path, original_filename: str, width, height, **options
|
151 |
+
):
|
152 |
+
storage = FilesystemStorageBackend(self.app)
|
153 |
+
crop = options.get("crop", "fit")
|
154 |
+
background = options.get("background")
|
155 |
+
quality = options.get("quality", 90)
|
156 |
+
|
157 |
+
original_path, original_filename = os.path.split(original_filename)
|
158 |
+
original_filepath = os.path.join(directory, original_path, original_filename)
|
159 |
+
image = Image.open(BytesIO(storage.read(original_filepath)))
|
160 |
+
|
161 |
+
# keep ratio resize
|
162 |
+
if width is not None:
|
163 |
+
height = int(image.height * width / image.width)
|
164 |
+
else:
|
165 |
+
width = int(image.width * height / image.height)
|
166 |
+
|
167 |
+
thumbnail_size = (width, height)
|
168 |
+
|
169 |
+
thumbnail_filename = generate_filename(
|
170 |
+
original_filename,
|
171 |
+
aspect_to_string(thumbnail_size),
|
172 |
+
crop,
|
173 |
+
background,
|
174 |
+
quality,
|
175 |
+
)
|
176 |
+
|
177 |
+
thumbnail_filepath = os.path.join(
|
178 |
+
self.thumbnail_directory, original_path, thumbnail_filename
|
179 |
+
)
|
180 |
+
thumbnail_url = os.path.join(
|
181 |
+
self.thumbnail_url, original_path, thumbnail_filename
|
182 |
+
)
|
183 |
+
|
184 |
+
if storage.exists(thumbnail_filepath):
|
185 |
+
return thumbnail_url, (width, height)
|
186 |
+
|
187 |
+
try:
|
188 |
+
image.load()
|
189 |
+
except (IOError, OSError):
|
190 |
+
self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
|
191 |
+
return thumbnail_url, (width, height)
|
192 |
+
|
193 |
+
# get original image format
|
194 |
+
options["format"] = options.get("format", image.format)
|
195 |
+
|
196 |
+
image = self._create_thumbnail(
|
197 |
+
image, thumbnail_size, crop, background=background
|
198 |
+
)
|
199 |
+
|
200 |
+
raw_data = self.get_raw_data(image, **options)
|
201 |
+
storage.save(thumbnail_filepath, raw_data)
|
202 |
+
|
203 |
+
return thumbnail_url, (width, height)
|
204 |
+
|
205 |
+
def get_raw_data(self, image, **options):
|
206 |
+
data = {
|
207 |
+
"format": self._get_format(image, **options),
|
208 |
+
"quality": options.get("quality", 90),
|
209 |
+
}
|
210 |
+
|
211 |
+
_file = BytesIO()
|
212 |
+
image.save(_file, **data)
|
213 |
+
return _file.getvalue()
|
214 |
+
|
215 |
+
@staticmethod
|
216 |
+
def colormode(image, colormode="RGB"):
|
217 |
+
if colormode == "RGB" or colormode == "RGBA":
|
218 |
+
if image.mode == "RGBA":
|
219 |
+
return image
|
220 |
+
if image.mode == "LA":
|
221 |
+
return image.convert("RGBA")
|
222 |
+
return image.convert(colormode)
|
223 |
+
|
224 |
+
if colormode == "GRAY":
|
225 |
+
return image.convert("L")
|
226 |
+
|
227 |
+
return image.convert(colormode)
|
228 |
+
|
229 |
+
@staticmethod
|
230 |
+
def background(original_image, color=0xFF):
|
231 |
+
size = (max(original_image.size),) * 2
|
232 |
+
image = Image.new("L", size, color)
|
233 |
+
image.paste(
|
234 |
+
original_image,
|
235 |
+
tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
|
236 |
+
)
|
237 |
+
|
238 |
+
return image
|
239 |
+
|
240 |
+
def _get_format(self, image, **options):
|
241 |
+
if options.get("format"):
|
242 |
+
return options.get("format")
|
243 |
+
if image.format:
|
244 |
+
return image.format
|
245 |
+
|
246 |
+
return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]
|
247 |
+
|
248 |
+
def _create_thumbnail(self, image, size, crop="fit", background=None):
|
249 |
+
try:
|
250 |
+
resample = Image.Resampling.LANCZOS
|
251 |
+
except AttributeError: # pylint: disable=raise-missing-from
|
252 |
+
resample = Image.ANTIALIAS
|
253 |
+
|
254 |
+
if crop == "fit":
|
255 |
+
image = ImageOps.fit(image, size, resample)
|
256 |
+
else:
|
257 |
+
image = image.copy()
|
258 |
+
image.thumbnail(size, resample=resample)
|
259 |
+
|
260 |
+
if background is not None:
|
261 |
+
image = self.background(image)
|
262 |
+
|
263 |
+
image = self.colormode(image)
|
264 |
+
|
265 |
+
return image
|
lama_cleaner/file_manager/storage_backends.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
|
2 |
+
import errno
|
3 |
+
import os
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
|
6 |
+
|
7 |
+
class BaseStorageBackend(ABC):
|
8 |
+
def __init__(self, app=None):
|
9 |
+
self.app = app
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def read(self, filepath, mode="rb", **kwargs):
|
13 |
+
raise NotImplementedError
|
14 |
+
|
15 |
+
@abstractmethod
|
16 |
+
def exists(self, filepath):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def save(self, filepath, data):
|
21 |
+
raise NotImplementedError
|
22 |
+
|
23 |
+
|
24 |
+
class FilesystemStorageBackend(BaseStorageBackend):
|
25 |
+
def read(self, filepath, mode="rb", **kwargs):
|
26 |
+
with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
|
27 |
+
return f.read()
|
28 |
+
|
29 |
+
def exists(self, filepath):
|
30 |
+
return os.path.exists(filepath)
|
31 |
+
|
32 |
+
def save(self, filepath, data):
|
33 |
+
directory = os.path.dirname(filepath)
|
34 |
+
|
35 |
+
if not os.path.exists(directory):
|
36 |
+
try:
|
37 |
+
os.makedirs(directory)
|
38 |
+
except OSError as e:
|
39 |
+
if e.errno != errno.EEXIST:
|
40 |
+
raise
|
41 |
+
|
42 |
+
if not os.path.isdir(directory):
|
43 |
+
raise IOError("{} is not a directory".format(directory))
|
44 |
+
|
45 |
+
with open(filepath, "wb") as f:
|
46 |
+
f.write(data)
|
lama_cleaner/file_manager/utils.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
|
2 |
+
import importlib
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
|
9 |
+
def generate_filename(original_filename, *options):
|
10 |
+
name, ext = os.path.splitext(original_filename)
|
11 |
+
for v in options:
|
12 |
+
if v:
|
13 |
+
name += "_%s" % v
|
14 |
+
name += ext
|
15 |
+
|
16 |
+
return name
|
17 |
+
|
18 |
+
|
19 |
+
def parse_size(size):
|
20 |
+
if isinstance(size, int):
|
21 |
+
# If the size parameter is a single number, assume square aspect.
|
22 |
+
return [size, size]
|
23 |
+
|
24 |
+
if isinstance(size, (tuple, list)):
|
25 |
+
if len(size) == 1:
|
26 |
+
# If single value tuple/list is provided, exand it to two elements
|
27 |
+
return size + type(size)(size)
|
28 |
+
return size
|
29 |
+
|
30 |
+
try:
|
31 |
+
thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
|
32 |
+
except ValueError:
|
33 |
+
raise ValueError( # pylint: disable=raise-missing-from
|
34 |
+
"Bad thumbnail size format. Valid format is INTxINT."
|
35 |
+
)
|
36 |
+
|
37 |
+
if len(thumbnail_size) == 1:
|
38 |
+
# If the size parameter only contains a single integer, assume square aspect.
|
39 |
+
thumbnail_size.append(thumbnail_size[0])
|
40 |
+
|
41 |
+
return thumbnail_size
|
42 |
+
|
43 |
+
|
44 |
+
def aspect_to_string(size):
|
45 |
+
if isinstance(size, str):
|
46 |
+
return size
|
47 |
+
|
48 |
+
return "x".join(map(str, size))
|
49 |
+
|
50 |
+
|
51 |
+
IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}
|
52 |
+
|
53 |
+
|
54 |
+
def glob_img(p: Union[Path, str], recursive: bool = False):
|
55 |
+
p = Path(p)
|
56 |
+
if p.is_file() and p.suffix in IMG_SUFFIX:
|
57 |
+
yield p
|
58 |
+
else:
|
59 |
+
if recursive:
|
60 |
+
files = Path(p).glob("**/*.*")
|
61 |
+
else:
|
62 |
+
files = Path(p).glob("*.*")
|
63 |
+
|
64 |
+
for it in files:
|
65 |
+
if it.suffix not in IMG_SUFFIX:
|
66 |
+
continue
|
67 |
+
yield it
|
lama_cleaner/helper.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
import cv2
|
8 |
+
from PIL import Image, ImageOps, PngImagePlugin
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from lama_cleaner.const import MPS_SUPPORT_MODELS
|
12 |
+
from loguru import logger
|
13 |
+
from torch.hub import download_url_to_file, get_dir
|
14 |
+
import hashlib
|
15 |
+
|
16 |
+
|
17 |
+
def md5sum(filename):
|
18 |
+
md5 = hashlib.md5()
|
19 |
+
with open(filename, "rb") as f:
|
20 |
+
for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
|
21 |
+
md5.update(chunk)
|
22 |
+
return md5.hexdigest()
|
23 |
+
|
24 |
+
|
25 |
+
def switch_mps_device(model_name, device):
|
26 |
+
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
|
27 |
+
logger.info(f"{model_name} not support mps, switch to cpu")
|
28 |
+
return torch.device("cpu")
|
29 |
+
return device
|
30 |
+
|
31 |
+
|
32 |
+
def get_cache_path_by_url(url):
|
33 |
+
parts = urlparse(url)
|
34 |
+
hub_dir = get_dir()
|
35 |
+
model_dir = os.path.join(hub_dir, "checkpoints")
|
36 |
+
if not os.path.isdir(model_dir):
|
37 |
+
os.makedirs(model_dir)
|
38 |
+
filename = os.path.basename(parts.path)
|
39 |
+
cached_file = os.path.join(model_dir, filename)
|
40 |
+
return cached_file
|
41 |
+
|
42 |
+
|
43 |
+
def download_model(url, model_md5: str = None):
|
44 |
+
cached_file = get_cache_path_by_url(url)
|
45 |
+
if not os.path.exists(cached_file):
|
46 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
47 |
+
hash_prefix = None
|
48 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
49 |
+
if model_md5:
|
50 |
+
_md5 = md5sum(cached_file)
|
51 |
+
if model_md5 == _md5:
|
52 |
+
logger.info(f"Download model success, md5: {_md5}")
|
53 |
+
else:
|
54 |
+
try:
|
55 |
+
os.remove(cached_file)
|
56 |
+
logger.error(
|
57 |
+
f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner."
|
58 |
+
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
59 |
+
)
|
60 |
+
except:
|
61 |
+
logger.error(
|
62 |
+
f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart lama-cleaner."
|
63 |
+
)
|
64 |
+
exit(-1)
|
65 |
+
|
66 |
+
return cached_file
|
67 |
+
|
68 |
+
|
69 |
+
def ceil_modulo(x, mod):
|
70 |
+
if x % mod == 0:
|
71 |
+
return x
|
72 |
+
return (x // mod + 1) * mod
|
73 |
+
|
74 |
+
|
75 |
+
def handle_error(model_path, model_md5, e):
|
76 |
+
_md5 = md5sum(model_path)
|
77 |
+
if _md5 != model_md5:
|
78 |
+
try:
|
79 |
+
os.remove(model_path)
|
80 |
+
logger.error(
|
81 |
+
f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner."
|
82 |
+
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
83 |
+
)
|
84 |
+
except:
|
85 |
+
logger.error(
|
86 |
+
f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart lama-cleaner."
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
logger.error(
|
90 |
+
f"Failed to load model {model_path},"
|
91 |
+
f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
|
92 |
+
)
|
93 |
+
exit(-1)
|
94 |
+
|
95 |
+
|
96 |
+
def load_jit_model(url_or_path, device, model_md5: str):
|
97 |
+
if os.path.exists(url_or_path):
|
98 |
+
model_path = url_or_path
|
99 |
+
else:
|
100 |
+
model_path = download_model(url_or_path, model_md5)
|
101 |
+
|
102 |
+
logger.info(f"Loading model from: {model_path}")
|
103 |
+
try:
|
104 |
+
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
105 |
+
except Exception as e:
|
106 |
+
handle_error(model_path, model_md5, e)
|
107 |
+
model.eval()
|
108 |
+
return model
|
109 |
+
|
110 |
+
|
111 |
+
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
|
112 |
+
if os.path.exists(url_or_path):
|
113 |
+
model_path = url_or_path
|
114 |
+
else:
|
115 |
+
model_path = download_model(url_or_path, model_md5)
|
116 |
+
|
117 |
+
try:
|
118 |
+
logger.info(f"Loading model from: {model_path}")
|
119 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
120 |
+
model.load_state_dict(state_dict, strict=True)
|
121 |
+
model.to(device)
|
122 |
+
except Exception as e:
|
123 |
+
handle_error(model_path, model_md5, e)
|
124 |
+
model.eval()
|
125 |
+
return model
|
126 |
+
|
127 |
+
|
128 |
+
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
129 |
+
data = cv2.imencode(
|
130 |
+
f".{ext}",
|
131 |
+
image_numpy,
|
132 |
+
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
133 |
+
)[1]
|
134 |
+
image_bytes = data.tobytes()
|
135 |
+
return image_bytes
|
136 |
+
|
137 |
+
|
138 |
+
def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif_infos={}) -> bytes:
|
139 |
+
with io.BytesIO() as output:
|
140 |
+
kwargs = {k: v for k, v in exif_infos.items() if v is not None}
|
141 |
+
if ext == "png" and "parameters" in kwargs:
|
142 |
+
pnginfo_data = PngImagePlugin.PngInfo()
|
143 |
+
pnginfo_data.add_text("parameters", kwargs["parameters"])
|
144 |
+
kwargs["pnginfo"] = pnginfo_data
|
145 |
+
|
146 |
+
pil_img.save(
|
147 |
+
output,
|
148 |
+
format=ext,
|
149 |
+
quality=quality,
|
150 |
+
**kwargs,
|
151 |
+
)
|
152 |
+
image_bytes = output.getvalue()
|
153 |
+
return image_bytes
|
154 |
+
|
155 |
+
|
156 |
+
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
|
157 |
+
alpha_channel = None
|
158 |
+
image = Image.open(io.BytesIO(img_bytes))
|
159 |
+
|
160 |
+
if return_exif:
|
161 |
+
info = image.info or {}
|
162 |
+
exif_infos = {"exif": image.getexif(), "parameters": info.get("parameters")}
|
163 |
+
|
164 |
+
try:
|
165 |
+
image = ImageOps.exif_transpose(image)
|
166 |
+
except:
|
167 |
+
pass
|
168 |
+
|
169 |
+
if gray:
|
170 |
+
image = image.convert("L")
|
171 |
+
np_img = np.array(image)
|
172 |
+
else:
|
173 |
+
if image.mode == "RGBA":
|
174 |
+
np_img = np.array(image)
|
175 |
+
alpha_channel = np_img[:, :, -1]
|
176 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
177 |
+
else:
|
178 |
+
image = image.convert("RGB")
|
179 |
+
np_img = np.array(image)
|
180 |
+
|
181 |
+
if return_exif:
|
182 |
+
return np_img, alpha_channel, exif_infos
|
183 |
+
return np_img, alpha_channel
|
184 |
+
|
185 |
+
|
186 |
+
def norm_img(np_img):
|
187 |
+
if len(np_img.shape) == 2:
|
188 |
+
np_img = np_img[:, :, np.newaxis]
|
189 |
+
np_img = np.transpose(np_img, (2, 0, 1))
|
190 |
+
np_img = np_img.astype("float32") / 255
|
191 |
+
return np_img
|
192 |
+
|
193 |
+
|
194 |
+
def resize_max_size(
|
195 |
+
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
196 |
+
) -> np.ndarray:
|
197 |
+
# Resize image's longer size to size_limit if longer size larger than size_limit
|
198 |
+
h, w = np_img.shape[:2]
|
199 |
+
if max(h, w) > size_limit:
|
200 |
+
ratio = size_limit / max(h, w)
|
201 |
+
new_w = int(w * ratio + 0.5)
|
202 |
+
new_h = int(h * ratio + 0.5)
|
203 |
+
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
204 |
+
else:
|
205 |
+
return np_img
|
206 |
+
|
207 |
+
|
208 |
+
def pad_img_to_modulo(
|
209 |
+
img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
|
210 |
+
):
|
211 |
+
"""
|
212 |
+
|
213 |
+
Args:
|
214 |
+
img: [H, W, C]
|
215 |
+
mod:
|
216 |
+
square: 是否为正方形
|
217 |
+
min_size:
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
|
221 |
+
"""
|
222 |
+
if len(img.shape) == 2:
|
223 |
+
img = img[:, :, np.newaxis]
|
224 |
+
height, width = img.shape[:2]
|
225 |
+
out_height = ceil_modulo(height, mod)
|
226 |
+
out_width = ceil_modulo(width, mod)
|
227 |
+
|
228 |
+
if min_size is not None:
|
229 |
+
assert min_size % mod == 0
|
230 |
+
out_width = max(min_size, out_width)
|
231 |
+
out_height = max(min_size, out_height)
|
232 |
+
|
233 |
+
if square:
|
234 |
+
max_size = max(out_height, out_width)
|
235 |
+
out_height = max_size
|
236 |
+
out_width = max_size
|
237 |
+
|
238 |
+
return np.pad(
|
239 |
+
img,
|
240 |
+
((0, out_height - height), (0, out_width - width), (0, 0)),
|
241 |
+
mode="symmetric",
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
mask: (h, w, 1) 0~255
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
|
252 |
+
"""
|
253 |
+
height, width = mask.shape[:2]
|
254 |
+
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
255 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
256 |
+
|
257 |
+
boxes = []
|
258 |
+
for cnt in contours:
|
259 |
+
x, y, w, h = cv2.boundingRect(cnt)
|
260 |
+
box = np.array([x, y, x + w, y + h]).astype(int)
|
261 |
+
|
262 |
+
box[::2] = np.clip(box[::2], 0, width)
|
263 |
+
box[1::2] = np.clip(box[1::2], 0, height)
|
264 |
+
boxes.append(box)
|
265 |
+
|
266 |
+
return boxes
|
267 |
+
|
268 |
+
|
269 |
+
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
270 |
+
"""
|
271 |
+
Args:
|
272 |
+
mask: (h, w) 0~255
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
|
276 |
+
"""
|
277 |
+
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
278 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
279 |
+
|
280 |
+
max_area = 0
|
281 |
+
max_index = -1
|
282 |
+
for i, cnt in enumerate(contours):
|
283 |
+
area = cv2.contourArea(cnt)
|
284 |
+
if area > max_area:
|
285 |
+
max_area = area
|
286 |
+
max_index = i
|
287 |
+
|
288 |
+
if max_index != -1:
|
289 |
+
new_mask = np.zeros_like(mask)
|
290 |
+
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|
291 |
+
else:
|
292 |
+
return mask
|
lama_cleaner/installer.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
def install(package):
|
6 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
7 |
+
|
8 |
+
|
9 |
+
def install_plugins_package():
|
10 |
+
install("rembg")
|
11 |
+
install("realesrgan")
|
12 |
+
install("gfpgan")
|
lama_cleaner/model/__init__.py
ADDED
File without changes
|
lama_cleaner/model/base.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from lama_cleaner.helper import (
|
10 |
+
boxes_from_mask,
|
11 |
+
resize_max_size,
|
12 |
+
pad_img_to_modulo,
|
13 |
+
switch_mps_device,
|
14 |
+
)
|
15 |
+
from lama_cleaner.schema import Config, HDStrategy
|
16 |
+
|
17 |
+
|
18 |
+
class InpaintModel:
|
19 |
+
name = "base"
|
20 |
+
min_size: Optional[int] = None
|
21 |
+
pad_mod = 8
|
22 |
+
pad_to_square = False
|
23 |
+
|
24 |
+
def __init__(self, device, **kwargs):
|
25 |
+
"""
|
26 |
+
|
27 |
+
Args:
|
28 |
+
device:
|
29 |
+
"""
|
30 |
+
device = switch_mps_device(self.name, device)
|
31 |
+
self.device = device
|
32 |
+
self.init_model(device, **kwargs)
|
33 |
+
|
34 |
+
@abc.abstractmethod
|
35 |
+
def init_model(self, device, **kwargs):
|
36 |
+
...
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
@abc.abstractmethod
|
40 |
+
def is_downloaded() -> bool:
|
41 |
+
...
|
42 |
+
|
43 |
+
@abc.abstractmethod
|
44 |
+
def forward(self, image, mask, config: Config):
|
45 |
+
"""Input images and output images have same size
|
46 |
+
images: [H, W, C] RGB
|
47 |
+
masks: [H, W, 1] 255 为 masks 区域
|
48 |
+
return: BGR IMAGE
|
49 |
+
"""
|
50 |
+
...
|
51 |
+
|
52 |
+
def _pad_forward(self, image, mask, config: Config):
|
53 |
+
origin_height, origin_width = image.shape[:2]
|
54 |
+
pad_image = pad_img_to_modulo(
|
55 |
+
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
56 |
+
)
|
57 |
+
pad_mask = pad_img_to_modulo(
|
58 |
+
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
59 |
+
)
|
60 |
+
|
61 |
+
logger.info(f"final forward pad size: {pad_image.shape}")
|
62 |
+
|
63 |
+
result = self.forward(pad_image, pad_mask, config)
|
64 |
+
result = result[0:origin_height, 0:origin_width, :]
|
65 |
+
|
66 |
+
result, image, mask = self.forward_post_process(result, image, mask, config)
|
67 |
+
|
68 |
+
mask = mask[:, :, np.newaxis]
|
69 |
+
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
70 |
+
return result
|
71 |
+
|
72 |
+
def forward_post_process(self, result, image, mask, config):
|
73 |
+
return result, image, mask
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def __call__(self, image, mask, config: Config):
|
77 |
+
"""
|
78 |
+
images: [H, W, C] RGB, not normalized
|
79 |
+
masks: [H, W]
|
80 |
+
return: BGR IMAGE
|
81 |
+
"""
|
82 |
+
inpaint_result = None
|
83 |
+
logger.info(f"hd_strategy: {config.hd_strategy}")
|
84 |
+
if config.hd_strategy == HDStrategy.CROP:
|
85 |
+
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
86 |
+
logger.info(f"Run crop strategy")
|
87 |
+
boxes = boxes_from_mask(mask)
|
88 |
+
crop_result = []
|
89 |
+
for box in boxes:
|
90 |
+
crop_image, crop_box = self._run_box(image, mask, box, config)
|
91 |
+
crop_result.append((crop_image, crop_box))
|
92 |
+
|
93 |
+
inpaint_result = image[:, :, ::-1]
|
94 |
+
for crop_image, crop_box in crop_result:
|
95 |
+
x1, y1, x2, y2 = crop_box
|
96 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
97 |
+
|
98 |
+
elif config.hd_strategy == HDStrategy.RESIZE:
|
99 |
+
if max(image.shape) > config.hd_strategy_resize_limit:
|
100 |
+
origin_size = image.shape[:2]
|
101 |
+
downsize_image = resize_max_size(
|
102 |
+
image, size_limit=config.hd_strategy_resize_limit
|
103 |
+
)
|
104 |
+
downsize_mask = resize_max_size(
|
105 |
+
mask, size_limit=config.hd_strategy_resize_limit
|
106 |
+
)
|
107 |
+
|
108 |
+
logger.info(
|
109 |
+
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
110 |
+
)
|
111 |
+
inpaint_result = self._pad_forward(
|
112 |
+
downsize_image, downsize_mask, config
|
113 |
+
)
|
114 |
+
|
115 |
+
# only paste masked area result
|
116 |
+
inpaint_result = cv2.resize(
|
117 |
+
inpaint_result,
|
118 |
+
(origin_size[1], origin_size[0]),
|
119 |
+
interpolation=cv2.INTER_CUBIC,
|
120 |
+
)
|
121 |
+
original_pixel_indices = mask < 127
|
122 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
123 |
+
original_pixel_indices
|
124 |
+
]
|
125 |
+
|
126 |
+
if inpaint_result is None:
|
127 |
+
inpaint_result = self._pad_forward(image, mask, config)
|
128 |
+
|
129 |
+
return inpaint_result
|
130 |
+
|
131 |
+
def _crop_box(self, image, mask, box, config: Config):
|
132 |
+
"""
|
133 |
+
|
134 |
+
Args:
|
135 |
+
image: [H, W, C] RGB
|
136 |
+
mask: [H, W, 1]
|
137 |
+
box: [left,top,right,bottom]
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
BGR IMAGE, (l, r, r, b)
|
141 |
+
"""
|
142 |
+
box_h = box[3] - box[1]
|
143 |
+
box_w = box[2] - box[0]
|
144 |
+
cx = (box[0] + box[2]) // 2
|
145 |
+
cy = (box[1] + box[3]) // 2
|
146 |
+
img_h, img_w = image.shape[:2]
|
147 |
+
|
148 |
+
w = box_w + config.hd_strategy_crop_margin * 2
|
149 |
+
h = box_h + config.hd_strategy_crop_margin * 2
|
150 |
+
|
151 |
+
_l = cx - w // 2
|
152 |
+
_r = cx + w // 2
|
153 |
+
_t = cy - h // 2
|
154 |
+
_b = cy + h // 2
|
155 |
+
|
156 |
+
l = max(_l, 0)
|
157 |
+
r = min(_r, img_w)
|
158 |
+
t = max(_t, 0)
|
159 |
+
b = min(_b, img_h)
|
160 |
+
|
161 |
+
# try to get more context when crop around image edge
|
162 |
+
if _l < 0:
|
163 |
+
r += abs(_l)
|
164 |
+
if _r > img_w:
|
165 |
+
l -= _r - img_w
|
166 |
+
if _t < 0:
|
167 |
+
b += abs(_t)
|
168 |
+
if _b > img_h:
|
169 |
+
t -= _b - img_h
|
170 |
+
|
171 |
+
l = max(l, 0)
|
172 |
+
r = min(r, img_w)
|
173 |
+
t = max(t, 0)
|
174 |
+
b = min(b, img_h)
|
175 |
+
|
176 |
+
crop_img = image[t:b, l:r, :]
|
177 |
+
crop_mask = mask[t:b, l:r]
|
178 |
+
|
179 |
+
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
180 |
+
|
181 |
+
return crop_img, crop_mask, [l, t, r, b]
|
182 |
+
|
183 |
+
def _calculate_cdf(self, histogram):
|
184 |
+
cdf = histogram.cumsum()
|
185 |
+
normalized_cdf = cdf / float(cdf.max())
|
186 |
+
return normalized_cdf
|
187 |
+
|
188 |
+
def _calculate_lookup(self, source_cdf, reference_cdf):
|
189 |
+
lookup_table = np.zeros(256)
|
190 |
+
lookup_val = 0
|
191 |
+
for source_index, source_val in enumerate(source_cdf):
|
192 |
+
for reference_index, reference_val in enumerate(reference_cdf):
|
193 |
+
if reference_val >= source_val:
|
194 |
+
lookup_val = reference_index
|
195 |
+
break
|
196 |
+
lookup_table[source_index] = lookup_val
|
197 |
+
return lookup_table
|
198 |
+
|
199 |
+
def _match_histograms(self, source, reference, mask):
|
200 |
+
transformed_channels = []
|
201 |
+
for channel in range(source.shape[-1]):
|
202 |
+
source_channel = source[:, :, channel]
|
203 |
+
reference_channel = reference[:, :, channel]
|
204 |
+
|
205 |
+
# only calculate histograms for non-masked parts
|
206 |
+
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
207 |
+
reference_histogram, _ = np.histogram(
|
208 |
+
reference_channel[mask == 0], 256, [0, 256]
|
209 |
+
)
|
210 |
+
|
211 |
+
source_cdf = self._calculate_cdf(source_histogram)
|
212 |
+
reference_cdf = self._calculate_cdf(reference_histogram)
|
213 |
+
|
214 |
+
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
215 |
+
|
216 |
+
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
217 |
+
|
218 |
+
result = cv2.merge(transformed_channels)
|
219 |
+
result = cv2.convertScaleAbs(result)
|
220 |
+
|
221 |
+
return result
|
222 |
+
|
223 |
+
def _apply_cropper(self, image, mask, config: Config):
|
224 |
+
img_h, img_w = image.shape[:2]
|
225 |
+
l, t, w, h = (
|
226 |
+
config.croper_x,
|
227 |
+
config.croper_y,
|
228 |
+
config.croper_width,
|
229 |
+
config.croper_height,
|
230 |
+
)
|
231 |
+
r = l + w
|
232 |
+
b = t + h
|
233 |
+
|
234 |
+
l = max(l, 0)
|
235 |
+
r = min(r, img_w)
|
236 |
+
t = max(t, 0)
|
237 |
+
b = min(b, img_h)
|
238 |
+
|
239 |
+
crop_img = image[t:b, l:r, :]
|
240 |
+
crop_mask = mask[t:b, l:r]
|
241 |
+
return crop_img, crop_mask, (l, t, r, b)
|
242 |
+
|
243 |
+
def _run_box(self, image, mask, box, config: Config):
|
244 |
+
"""
|
245 |
+
|
246 |
+
Args:
|
247 |
+
image: [H, W, C] RGB
|
248 |
+
mask: [H, W, 1]
|
249 |
+
box: [left,top,right,bottom]
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
BGR IMAGE
|
253 |
+
"""
|
254 |
+
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
255 |
+
|
256 |
+
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
257 |
+
|
258 |
+
|
259 |
+
class DiffusionInpaintModel(InpaintModel):
|
260 |
+
@torch.no_grad()
|
261 |
+
def __call__(self, image, mask, config: Config):
|
262 |
+
"""
|
263 |
+
images: [H, W, C] RGB, not normalized
|
264 |
+
masks: [H, W]
|
265 |
+
return: BGR IMAGE
|
266 |
+
"""
|
267 |
+
# boxes = boxes_from_mask(mask)
|
268 |
+
if config.use_croper:
|
269 |
+
crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
|
270 |
+
crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
|
271 |
+
inpaint_result = image[:, :, ::-1]
|
272 |
+
inpaint_result[t:b, l:r, :] = crop_image
|
273 |
+
else:
|
274 |
+
inpaint_result = self._scaled_pad_forward(image, mask, config)
|
275 |
+
|
276 |
+
return inpaint_result
|
277 |
+
|
278 |
+
def _scaled_pad_forward(self, image, mask, config: Config):
|
279 |
+
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
|
280 |
+
origin_size = image.shape[:2]
|
281 |
+
downsize_image = resize_max_size(image, size_limit=longer_side_length)
|
282 |
+
downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
|
283 |
+
if config.sd_scale != 1:
|
284 |
+
logger.info(
|
285 |
+
f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
|
286 |
+
)
|
287 |
+
inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
|
288 |
+
# only paste masked area result
|
289 |
+
inpaint_result = cv2.resize(
|
290 |
+
inpaint_result,
|
291 |
+
(origin_size[1], origin_size[0]),
|
292 |
+
interpolation=cv2.INTER_CUBIC,
|
293 |
+
)
|
294 |
+
original_pixel_indices = mask < 127
|
295 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
296 |
+
original_pixel_indices
|
297 |
+
]
|
298 |
+
return inpaint_result
|
lama_cleaner/model/controlnet.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
|
3 |
+
import PIL.Image
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from diffusers import ControlNetModel
|
8 |
+
from loguru import logger
|
9 |
+
|
10 |
+
from lama_cleaner.model.base import DiffusionInpaintModel
|
11 |
+
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
12 |
+
from lama_cleaner.schema import Config
|
13 |
+
|
14 |
+
|
15 |
+
class CPUTextEncoderWrapper:
|
16 |
+
def __init__(self, text_encoder, torch_dtype):
|
17 |
+
self.config = text_encoder.config
|
18 |
+
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
19 |
+
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
20 |
+
self.torch_dtype = torch_dtype
|
21 |
+
del text_encoder
|
22 |
+
torch_gc()
|
23 |
+
|
24 |
+
def __call__(self, x, **kwargs):
|
25 |
+
input_device = x.device
|
26 |
+
return [
|
27 |
+
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
|
28 |
+
.to(input_device)
|
29 |
+
.to(self.torch_dtype)
|
30 |
+
]
|
31 |
+
|
32 |
+
@property
|
33 |
+
def dtype(self):
|
34 |
+
return self.torch_dtype
|
35 |
+
|
36 |
+
|
37 |
+
NAMES_MAP = {
|
38 |
+
"sd1.5": "runwayml/stable-diffusion-inpainting",
|
39 |
+
"anything4": "Sanster/anything-4.0-inpainting",
|
40 |
+
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
|
41 |
+
}
|
42 |
+
|
43 |
+
NATIVE_NAMES_MAP = {
|
44 |
+
"sd1.5": "runwayml/stable-diffusion-v1-5",
|
45 |
+
"anything4": "andite/anything-v4.0",
|
46 |
+
"realisticVision1.4": "SG161222/Realistic_Vision_V1.4",
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
def make_inpaint_condition(image, image_mask):
|
51 |
+
"""
|
52 |
+
image: [H, W, C] RGB
|
53 |
+
mask: [H, W, 1] 255 means area to repaint
|
54 |
+
"""
|
55 |
+
image = image.astype(np.float32) / 255.0
|
56 |
+
image[image_mask[:, :, -1] > 128] = -1.0 # set as masked pixel
|
57 |
+
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
|
58 |
+
image = torch.from_numpy(image)
|
59 |
+
return image
|
60 |
+
|
61 |
+
|
62 |
+
def load_from_local_model(
|
63 |
+
local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint
|
64 |
+
):
|
65 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
66 |
+
download_from_original_stable_diffusion_ckpt,
|
67 |
+
)
|
68 |
+
|
69 |
+
logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
|
70 |
+
|
71 |
+
try:
|
72 |
+
pipe = download_from_original_stable_diffusion_ckpt(
|
73 |
+
local_model_path,
|
74 |
+
num_in_channels=4 if is_native_control_inpaint else 9,
|
75 |
+
from_safetensors=local_model_path.endswith("safetensors"),
|
76 |
+
device="cpu",
|
77 |
+
load_safety_checker=False,
|
78 |
+
)
|
79 |
+
except Exception as e:
|
80 |
+
err_msg = str(e)
|
81 |
+
logger.exception(e)
|
82 |
+
if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg:
|
83 |
+
logger.error(
|
84 |
+
"control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model"
|
85 |
+
)
|
86 |
+
if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg:
|
87 |
+
logger.error(
|
88 |
+
f"{controlnet.config['_name_or_path']} method requires inpainting SD model, "
|
89 |
+
f"you can convert any SD model to inpainting model in AUTO1111: \n"
|
90 |
+
f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/"
|
91 |
+
)
|
92 |
+
exit(-1)
|
93 |
+
|
94 |
+
inpaint_pipe = pipe_class(
|
95 |
+
vae=pipe.vae,
|
96 |
+
text_encoder=pipe.text_encoder,
|
97 |
+
tokenizer=pipe.tokenizer,
|
98 |
+
unet=pipe.unet,
|
99 |
+
controlnet=controlnet,
|
100 |
+
scheduler=pipe.scheduler,
|
101 |
+
safety_checker=None,
|
102 |
+
feature_extractor=None,
|
103 |
+
requires_safety_checker=False,
|
104 |
+
)
|
105 |
+
|
106 |
+
del pipe
|
107 |
+
gc.collect()
|
108 |
+
return inpaint_pipe.to(torch_dtype=torch_dtype)
|
109 |
+
|
110 |
+
|
111 |
+
class ControlNet(DiffusionInpaintModel):
|
112 |
+
name = "controlnet"
|
113 |
+
pad_mod = 8
|
114 |
+
min_size = 512
|
115 |
+
|
116 |
+
def init_model(self, device: torch.device, **kwargs):
|
117 |
+
fp16 = not kwargs.get("no_half", False)
|
118 |
+
|
119 |
+
model_kwargs = {
|
120 |
+
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
|
121 |
+
}
|
122 |
+
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
123 |
+
logger.info("Disable Stable Diffusion Model NSFW checker")
|
124 |
+
model_kwargs.update(
|
125 |
+
dict(
|
126 |
+
safety_checker=None,
|
127 |
+
feature_extractor=None,
|
128 |
+
requires_safety_checker=False,
|
129 |
+
)
|
130 |
+
)
|
131 |
+
|
132 |
+
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
133 |
+
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
134 |
+
|
135 |
+
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
136 |
+
self.sd_controlnet_method = sd_controlnet_method
|
137 |
+
|
138 |
+
if sd_controlnet_method == "control_v11p_sd15_inpaint":
|
139 |
+
from diffusers import StableDiffusionControlNetPipeline as PipeClass
|
140 |
+
|
141 |
+
self.is_native_control_inpaint = True
|
142 |
+
else:
|
143 |
+
from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass
|
144 |
+
|
145 |
+
self.is_native_control_inpaint = False
|
146 |
+
|
147 |
+
if self.is_native_control_inpaint:
|
148 |
+
model_id = NATIVE_NAMES_MAP[kwargs["name"]]
|
149 |
+
else:
|
150 |
+
model_id = NAMES_MAP[kwargs["name"]]
|
151 |
+
|
152 |
+
controlnet = ControlNetModel.from_pretrained(
|
153 |
+
f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype
|
154 |
+
)
|
155 |
+
self.is_local_sd_model = False
|
156 |
+
if kwargs.get("sd_local_model_path", None):
|
157 |
+
self.is_local_sd_model = True
|
158 |
+
self.model = load_from_local_model(
|
159 |
+
kwargs["sd_local_model_path"],
|
160 |
+
torch_dtype=torch_dtype,
|
161 |
+
controlnet=controlnet,
|
162 |
+
pipe_class=PipeClass,
|
163 |
+
is_native_control_inpaint=self.is_native_control_inpaint,
|
164 |
+
)
|
165 |
+
else:
|
166 |
+
self.model = PipeClass.from_pretrained(
|
167 |
+
model_id,
|
168 |
+
controlnet=controlnet,
|
169 |
+
revision="fp16" if use_gpu and fp16 else "main",
|
170 |
+
torch_dtype=torch_dtype,
|
171 |
+
**model_kwargs,
|
172 |
+
)
|
173 |
+
|
174 |
+
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
175 |
+
self.model.enable_attention_slicing()
|
176 |
+
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
177 |
+
if kwargs.get("enable_xformers", False):
|
178 |
+
self.model.enable_xformers_memory_efficient_attention()
|
179 |
+
|
180 |
+
if kwargs.get("cpu_offload", False) and use_gpu:
|
181 |
+
logger.info("Enable sequential cpu offload")
|
182 |
+
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
183 |
+
else:
|
184 |
+
self.model = self.model.to(device)
|
185 |
+
if kwargs["sd_cpu_textencoder"]:
|
186 |
+
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
187 |
+
self.model.text_encoder = CPUTextEncoderWrapper(
|
188 |
+
self.model.text_encoder, torch_dtype
|
189 |
+
)
|
190 |
+
|
191 |
+
self.callback = kwargs.pop("callback", None)
|
192 |
+
|
193 |
+
def forward(self, image, mask, config: Config):
|
194 |
+
"""Input image and output image have same size
|
195 |
+
image: [H, W, C] RGB
|
196 |
+
mask: [H, W, 1] 255 means area to repaint
|
197 |
+
return: BGR IMAGE
|
198 |
+
"""
|
199 |
+
scheduler_config = self.model.scheduler.config
|
200 |
+
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
201 |
+
self.model.scheduler = scheduler
|
202 |
+
|
203 |
+
if config.sd_mask_blur != 0:
|
204 |
+
k = 2 * config.sd_mask_blur + 1
|
205 |
+
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
206 |
+
|
207 |
+
img_h, img_w = image.shape[:2]
|
208 |
+
|
209 |
+
if self.is_native_control_inpaint:
|
210 |
+
control_image = make_inpaint_condition(image, mask)
|
211 |
+
output = self.model(
|
212 |
+
prompt=config.prompt,
|
213 |
+
image=control_image,
|
214 |
+
height=img_h,
|
215 |
+
width=img_w,
|
216 |
+
num_inference_steps=config.sd_steps,
|
217 |
+
guidance_scale=config.sd_guidance_scale,
|
218 |
+
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
|
219 |
+
negative_prompt=config.negative_prompt,
|
220 |
+
generator=torch.manual_seed(config.sd_seed),
|
221 |
+
output_type="np.array",
|
222 |
+
callback=self.callback,
|
223 |
+
).images[0]
|
224 |
+
else:
|
225 |
+
if "canny" in self.sd_controlnet_method:
|
226 |
+
canny_image = cv2.Canny(image, 100, 200)
|
227 |
+
canny_image = canny_image[:, :, None]
|
228 |
+
canny_image = np.concatenate(
|
229 |
+
[canny_image, canny_image, canny_image], axis=2
|
230 |
+
)
|
231 |
+
canny_image = PIL.Image.fromarray(canny_image)
|
232 |
+
control_image = canny_image
|
233 |
+
elif "openpose" in self.sd_controlnet_method:
|
234 |
+
from controlnet_aux import OpenposeDetector
|
235 |
+
|
236 |
+
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
237 |
+
control_image = processor(image, hand_and_face=True)
|
238 |
+
elif "depth" in self.sd_controlnet_method:
|
239 |
+
from transformers import pipeline
|
240 |
+
|
241 |
+
depth_estimator = pipeline("depth-estimation")
|
242 |
+
depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
|
243 |
+
depth_image = np.array(depth_image)
|
244 |
+
depth_image = depth_image[:, :, None]
|
245 |
+
depth_image = np.concatenate(
|
246 |
+
[depth_image, depth_image, depth_image], axis=2
|
247 |
+
)
|
248 |
+
control_image = PIL.Image.fromarray(depth_image)
|
249 |
+
else:
|
250 |
+
raise NotImplementedError(
|
251 |
+
f"{self.sd_controlnet_method} not implemented"
|
252 |
+
)
|
253 |
+
|
254 |
+
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
255 |
+
image = PIL.Image.fromarray(image)
|
256 |
+
|
257 |
+
output = self.model(
|
258 |
+
image=image,
|
259 |
+
control_image=control_image,
|
260 |
+
prompt=config.prompt,
|
261 |
+
negative_prompt=config.negative_prompt,
|
262 |
+
mask_image=mask_image,
|
263 |
+
num_inference_steps=config.sd_steps,
|
264 |
+
guidance_scale=config.sd_guidance_scale,
|
265 |
+
output_type="np.array",
|
266 |
+
callback=self.callback,
|
267 |
+
height=img_h,
|
268 |
+
width=img_w,
|
269 |
+
generator=torch.manual_seed(config.sd_seed),
|
270 |
+
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
|
271 |
+
).images[0]
|
272 |
+
|
273 |
+
output = (output * 255).round().astype("uint8")
|
274 |
+
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
275 |
+
return output
|
276 |
+
|
277 |
+
def forward_post_process(self, result, image, mask, config):
|
278 |
+
if config.sd_match_histograms:
|
279 |
+
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
280 |
+
|
281 |
+
if config.sd_mask_blur != 0:
|
282 |
+
k = 2 * config.sd_mask_blur + 1
|
283 |
+
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
284 |
+
return result, image, mask
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def is_downloaded() -> bool:
|
288 |
+
# model will be downloaded when app start, and can't switch in frontend settings
|
289 |
+
return True
|
lama_cleaner/model/ddim_sampler.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
|
10 |
+
class DDIMSampler(object):
|
11 |
+
def __init__(self, model, schedule="linear"):
|
12 |
+
super().__init__()
|
13 |
+
self.model = model
|
14 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
15 |
+
self.schedule = schedule
|
16 |
+
|
17 |
+
def register_buffer(self, name, attr):
|
18 |
+
setattr(self, name, attr)
|
19 |
+
|
20 |
+
def make_schedule(
|
21 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
22 |
+
):
|
23 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
24 |
+
ddim_discr_method=ddim_discretize,
|
25 |
+
num_ddim_timesteps=ddim_num_steps,
|
26 |
+
# array([1])
|
27 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
28 |
+
verbose=verbose,
|
29 |
+
)
|
30 |
+
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
31 |
+
assert (
|
32 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
33 |
+
), "alphas have to be defined for each timestep"
|
34 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
35 |
+
|
36 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
37 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
38 |
+
self.register_buffer(
|
39 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
40 |
+
)
|
41 |
+
|
42 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
43 |
+
self.register_buffer(
|
44 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
45 |
+
)
|
46 |
+
self.register_buffer(
|
47 |
+
"sqrt_one_minus_alphas_cumprod",
|
48 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
49 |
+
)
|
50 |
+
self.register_buffer(
|
51 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
52 |
+
)
|
53 |
+
self.register_buffer(
|
54 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
55 |
+
)
|
56 |
+
self.register_buffer(
|
57 |
+
"sqrt_recipm1_alphas_cumprod",
|
58 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
59 |
+
)
|
60 |
+
|
61 |
+
# ddim sampling parameters
|
62 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
63 |
+
alphacums=alphas_cumprod.cpu(),
|
64 |
+
ddim_timesteps=self.ddim_timesteps,
|
65 |
+
eta=ddim_eta,
|
66 |
+
verbose=verbose,
|
67 |
+
)
|
68 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
69 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
70 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
71 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
72 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
73 |
+
(1 - self.alphas_cumprod_prev)
|
74 |
+
/ (1 - self.alphas_cumprod)
|
75 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
76 |
+
)
|
77 |
+
self.register_buffer(
|
78 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
79 |
+
)
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def sample(self, steps, conditioning, batch_size, shape):
|
83 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
84 |
+
# sampling
|
85 |
+
C, H, W = shape
|
86 |
+
size = (batch_size, C, H, W)
|
87 |
+
|
88 |
+
# samples: 1,3,128,128
|
89 |
+
return self.ddim_sampling(
|
90 |
+
conditioning,
|
91 |
+
size,
|
92 |
+
quantize_denoised=False,
|
93 |
+
ddim_use_original_steps=False,
|
94 |
+
noise_dropout=0,
|
95 |
+
temperature=1.0,
|
96 |
+
)
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def ddim_sampling(
|
100 |
+
self,
|
101 |
+
cond,
|
102 |
+
shape,
|
103 |
+
ddim_use_original_steps=False,
|
104 |
+
quantize_denoised=False,
|
105 |
+
temperature=1.0,
|
106 |
+
noise_dropout=0.0,
|
107 |
+
):
|
108 |
+
device = self.model.betas.device
|
109 |
+
b = shape[0]
|
110 |
+
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
111 |
+
timesteps = (
|
112 |
+
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
113 |
+
)
|
114 |
+
|
115 |
+
time_range = (
|
116 |
+
reversed(range(0, timesteps))
|
117 |
+
if ddim_use_original_steps
|
118 |
+
else np.flip(timesteps)
|
119 |
+
)
|
120 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
121 |
+
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
122 |
+
|
123 |
+
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
124 |
+
|
125 |
+
for i, step in enumerate(iterator):
|
126 |
+
index = total_steps - i - 1
|
127 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
128 |
+
|
129 |
+
outs = self.p_sample_ddim(
|
130 |
+
img,
|
131 |
+
cond,
|
132 |
+
ts,
|
133 |
+
index=index,
|
134 |
+
use_original_steps=ddim_use_original_steps,
|
135 |
+
quantize_denoised=quantize_denoised,
|
136 |
+
temperature=temperature,
|
137 |
+
noise_dropout=noise_dropout,
|
138 |
+
)
|
139 |
+
img, _ = outs
|
140 |
+
|
141 |
+
return img
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def p_sample_ddim(
|
145 |
+
self,
|
146 |
+
x,
|
147 |
+
c,
|
148 |
+
t,
|
149 |
+
index,
|
150 |
+
repeat_noise=False,
|
151 |
+
use_original_steps=False,
|
152 |
+
quantize_denoised=False,
|
153 |
+
temperature=1.0,
|
154 |
+
noise_dropout=0.0,
|
155 |
+
):
|
156 |
+
b, *_, device = *x.shape, x.device
|
157 |
+
e_t = self.model.apply_model(x, t, c)
|
158 |
+
|
159 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
160 |
+
alphas_prev = (
|
161 |
+
self.model.alphas_cumprod_prev
|
162 |
+
if use_original_steps
|
163 |
+
else self.ddim_alphas_prev
|
164 |
+
)
|
165 |
+
sqrt_one_minus_alphas = (
|
166 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
167 |
+
if use_original_steps
|
168 |
+
else self.ddim_sqrt_one_minus_alphas
|
169 |
+
)
|
170 |
+
sigmas = (
|
171 |
+
self.model.ddim_sigmas_for_original_num_steps
|
172 |
+
if use_original_steps
|
173 |
+
else self.ddim_sigmas
|
174 |
+
)
|
175 |
+
# select parameters corresponding to the currently considered timestep
|
176 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
177 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
178 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
179 |
+
sqrt_one_minus_at = torch.full(
|
180 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
181 |
+
)
|
182 |
+
|
183 |
+
# current prediction for x_0
|
184 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
185 |
+
if quantize_denoised: # 没用
|
186 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
187 |
+
# direction pointing to x_t
|
188 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
189 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
190 |
+
if noise_dropout > 0.0: # 没用
|
191 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
192 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
193 |
+
return x_prev, pred_x0
|
lama_cleaner/model/fcf.py
ADDED
@@ -0,0 +1,1733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import torch.fft as fft
|
8 |
+
|
9 |
+
from lama_cleaner.schema import Config
|
10 |
+
|
11 |
+
from lama_cleaner.helper import (
|
12 |
+
load_model,
|
13 |
+
get_cache_path_by_url,
|
14 |
+
norm_img,
|
15 |
+
boxes_from_mask,
|
16 |
+
resize_max_size,
|
17 |
+
)
|
18 |
+
from lama_cleaner.model.base import InpaintModel
|
19 |
+
from torch import conv2d, nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from lama_cleaner.model.utils import (
|
23 |
+
setup_filter,
|
24 |
+
_parse_scaling,
|
25 |
+
_parse_padding,
|
26 |
+
Conv2dLayer,
|
27 |
+
FullyConnectedLayer,
|
28 |
+
MinibatchStdLayer,
|
29 |
+
activation_funcs,
|
30 |
+
conv2d_resample,
|
31 |
+
bias_act,
|
32 |
+
upsample2d,
|
33 |
+
normalize_2nd_moment,
|
34 |
+
downsample2d,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
39 |
+
assert isinstance(x, torch.Tensor)
|
40 |
+
return _upfirdn2d_ref(
|
41 |
+
x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
46 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
|
47 |
+
# Validate arguments.
|
48 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
49 |
+
if f is None:
|
50 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
51 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
52 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
53 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
54 |
+
upx, upy = _parse_scaling(up)
|
55 |
+
downx, downy = _parse_scaling(down)
|
56 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
57 |
+
|
58 |
+
# Upsample by inserting zeros.
|
59 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
60 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
61 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
62 |
+
|
63 |
+
# Pad or crop.
|
64 |
+
x = torch.nn.functional.pad(
|
65 |
+
x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
|
66 |
+
)
|
67 |
+
x = x[
|
68 |
+
:,
|
69 |
+
:,
|
70 |
+
max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
|
71 |
+
max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
|
72 |
+
]
|
73 |
+
|
74 |
+
# Setup filter.
|
75 |
+
f = f * (gain ** (f.ndim / 2))
|
76 |
+
f = f.to(x.dtype)
|
77 |
+
if not flip_filter:
|
78 |
+
f = f.flip(list(range(f.ndim)))
|
79 |
+
|
80 |
+
# Convolve with the filter.
|
81 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
82 |
+
if f.ndim == 4:
|
83 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
84 |
+
else:
|
85 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
86 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
87 |
+
|
88 |
+
# Downsample by throwing away pixels.
|
89 |
+
x = x[:, :, ::downy, ::downx]
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
class EncoderEpilogue(torch.nn.Module):
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
in_channels, # Number of input channels.
|
97 |
+
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
98 |
+
z_dim, # Output Latent (Z) dimensionality.
|
99 |
+
resolution, # Resolution of this block.
|
100 |
+
img_channels, # Number of input color channels.
|
101 |
+
architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'.
|
102 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
103 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
104 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
105 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
106 |
+
):
|
107 |
+
assert architecture in ["orig", "skip", "resnet"]
|
108 |
+
super().__init__()
|
109 |
+
self.in_channels = in_channels
|
110 |
+
self.cmap_dim = cmap_dim
|
111 |
+
self.resolution = resolution
|
112 |
+
self.img_channels = img_channels
|
113 |
+
self.architecture = architecture
|
114 |
+
|
115 |
+
if architecture == "skip":
|
116 |
+
self.fromrgb = Conv2dLayer(
|
117 |
+
self.img_channels, in_channels, kernel_size=1, activation=activation
|
118 |
+
)
|
119 |
+
self.mbstd = (
|
120 |
+
MinibatchStdLayer(
|
121 |
+
group_size=mbstd_group_size, num_channels=mbstd_num_channels
|
122 |
+
)
|
123 |
+
if mbstd_num_channels > 0
|
124 |
+
else None
|
125 |
+
)
|
126 |
+
self.conv = Conv2dLayer(
|
127 |
+
in_channels + mbstd_num_channels,
|
128 |
+
in_channels,
|
129 |
+
kernel_size=3,
|
130 |
+
activation=activation,
|
131 |
+
conv_clamp=conv_clamp,
|
132 |
+
)
|
133 |
+
self.fc = FullyConnectedLayer(
|
134 |
+
in_channels * (resolution**2), z_dim, activation=activation
|
135 |
+
)
|
136 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
137 |
+
|
138 |
+
def forward(self, x, cmap, force_fp32=False):
|
139 |
+
_ = force_fp32 # unused
|
140 |
+
dtype = torch.float32
|
141 |
+
memory_format = torch.contiguous_format
|
142 |
+
|
143 |
+
# FromRGB.
|
144 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
145 |
+
|
146 |
+
# Main layers.
|
147 |
+
if self.mbstd is not None:
|
148 |
+
x = self.mbstd(x)
|
149 |
+
const_e = self.conv(x)
|
150 |
+
x = self.fc(const_e.flatten(1))
|
151 |
+
x = self.dropout(x)
|
152 |
+
|
153 |
+
# Conditioning.
|
154 |
+
if self.cmap_dim > 0:
|
155 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
156 |
+
|
157 |
+
assert x.dtype == dtype
|
158 |
+
return x, const_e
|
159 |
+
|
160 |
+
|
161 |
+
class EncoderBlock(torch.nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
in_channels, # Number of input channels, 0 = first block.
|
165 |
+
tmp_channels, # Number of intermediate channels.
|
166 |
+
out_channels, # Number of output channels.
|
167 |
+
resolution, # Resolution of this block.
|
168 |
+
img_channels, # Number of input color channels.
|
169 |
+
first_layer_idx, # Index of the first layer.
|
170 |
+
architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
|
171 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
172 |
+
resample_filter=[
|
173 |
+
1,
|
174 |
+
3,
|
175 |
+
3,
|
176 |
+
1,
|
177 |
+
], # Low-pass filter to apply when resampling activations.
|
178 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
179 |
+
use_fp16=False, # Use FP16 for this block?
|
180 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
181 |
+
freeze_layers=0, # Freeze-D: Number of layers to freeze.
|
182 |
+
):
|
183 |
+
assert in_channels in [0, tmp_channels]
|
184 |
+
assert architecture in ["orig", "skip", "resnet"]
|
185 |
+
super().__init__()
|
186 |
+
self.in_channels = in_channels
|
187 |
+
self.resolution = resolution
|
188 |
+
self.img_channels = img_channels + 1
|
189 |
+
self.first_layer_idx = first_layer_idx
|
190 |
+
self.architecture = architecture
|
191 |
+
self.use_fp16 = use_fp16
|
192 |
+
self.channels_last = use_fp16 and fp16_channels_last
|
193 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
194 |
+
|
195 |
+
self.num_layers = 0
|
196 |
+
|
197 |
+
def trainable_gen():
|
198 |
+
while True:
|
199 |
+
layer_idx = self.first_layer_idx + self.num_layers
|
200 |
+
trainable = layer_idx >= freeze_layers
|
201 |
+
self.num_layers += 1
|
202 |
+
yield trainable
|
203 |
+
|
204 |
+
trainable_iter = trainable_gen()
|
205 |
+
|
206 |
+
if in_channels == 0:
|
207 |
+
self.fromrgb = Conv2dLayer(
|
208 |
+
self.img_channels,
|
209 |
+
tmp_channels,
|
210 |
+
kernel_size=1,
|
211 |
+
activation=activation,
|
212 |
+
trainable=next(trainable_iter),
|
213 |
+
conv_clamp=conv_clamp,
|
214 |
+
channels_last=self.channels_last,
|
215 |
+
)
|
216 |
+
|
217 |
+
self.conv0 = Conv2dLayer(
|
218 |
+
tmp_channels,
|
219 |
+
tmp_channels,
|
220 |
+
kernel_size=3,
|
221 |
+
activation=activation,
|
222 |
+
trainable=next(trainable_iter),
|
223 |
+
conv_clamp=conv_clamp,
|
224 |
+
channels_last=self.channels_last,
|
225 |
+
)
|
226 |
+
|
227 |
+
self.conv1 = Conv2dLayer(
|
228 |
+
tmp_channels,
|
229 |
+
out_channels,
|
230 |
+
kernel_size=3,
|
231 |
+
activation=activation,
|
232 |
+
down=2,
|
233 |
+
trainable=next(trainable_iter),
|
234 |
+
resample_filter=resample_filter,
|
235 |
+
conv_clamp=conv_clamp,
|
236 |
+
channels_last=self.channels_last,
|
237 |
+
)
|
238 |
+
|
239 |
+
if architecture == "resnet":
|
240 |
+
self.skip = Conv2dLayer(
|
241 |
+
tmp_channels,
|
242 |
+
out_channels,
|
243 |
+
kernel_size=1,
|
244 |
+
bias=False,
|
245 |
+
down=2,
|
246 |
+
trainable=next(trainable_iter),
|
247 |
+
resample_filter=resample_filter,
|
248 |
+
channels_last=self.channels_last,
|
249 |
+
)
|
250 |
+
|
251 |
+
def forward(self, x, img, force_fp32=False):
|
252 |
+
# dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
253 |
+
dtype = torch.float32
|
254 |
+
memory_format = (
|
255 |
+
torch.channels_last
|
256 |
+
if self.channels_last and not force_fp32
|
257 |
+
else torch.contiguous_format
|
258 |
+
)
|
259 |
+
|
260 |
+
# Input.
|
261 |
+
if x is not None:
|
262 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
263 |
+
|
264 |
+
# FromRGB.
|
265 |
+
if self.in_channels == 0:
|
266 |
+
img = img.to(dtype=dtype, memory_format=memory_format)
|
267 |
+
y = self.fromrgb(img)
|
268 |
+
x = x + y if x is not None else y
|
269 |
+
img = (
|
270 |
+
downsample2d(img, self.resample_filter)
|
271 |
+
if self.architecture == "skip"
|
272 |
+
else None
|
273 |
+
)
|
274 |
+
|
275 |
+
# Main layers.
|
276 |
+
if self.architecture == "resnet":
|
277 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
278 |
+
x = self.conv0(x)
|
279 |
+
feat = x.clone()
|
280 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
281 |
+
x = y.add_(x)
|
282 |
+
else:
|
283 |
+
x = self.conv0(x)
|
284 |
+
feat = x.clone()
|
285 |
+
x = self.conv1(x)
|
286 |
+
|
287 |
+
assert x.dtype == dtype
|
288 |
+
return x, img, feat
|
289 |
+
|
290 |
+
|
291 |
+
class EncoderNetwork(torch.nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
c_dim, # Conditioning label (C) dimensionality.
|
295 |
+
z_dim, # Input latent (Z) dimensionality.
|
296 |
+
img_resolution, # Input resolution.
|
297 |
+
img_channels, # Number of input color channels.
|
298 |
+
architecture="orig", # Architecture: 'orig', 'skip', 'resnet'.
|
299 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
300 |
+
channel_max=512, # Maximum number of channels in any layer.
|
301 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
302 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
303 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
304 |
+
block_kwargs={}, # Arguments for DiscriminatorBlock.
|
305 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
306 |
+
epilogue_kwargs={}, # Arguments for EncoderEpilogue.
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
self.c_dim = c_dim
|
310 |
+
self.z_dim = z_dim
|
311 |
+
self.img_resolution = img_resolution
|
312 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
313 |
+
self.img_channels = img_channels
|
314 |
+
self.block_resolutions = [
|
315 |
+
2**i for i in range(self.img_resolution_log2, 2, -1)
|
316 |
+
]
|
317 |
+
channels_dict = {
|
318 |
+
res: min(channel_base // res, channel_max)
|
319 |
+
for res in self.block_resolutions + [4]
|
320 |
+
}
|
321 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
322 |
+
|
323 |
+
if cmap_dim is None:
|
324 |
+
cmap_dim = channels_dict[4]
|
325 |
+
if c_dim == 0:
|
326 |
+
cmap_dim = 0
|
327 |
+
|
328 |
+
common_kwargs = dict(
|
329 |
+
img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp
|
330 |
+
)
|
331 |
+
cur_layer_idx = 0
|
332 |
+
for res in self.block_resolutions:
|
333 |
+
in_channels = channels_dict[res] if res < img_resolution else 0
|
334 |
+
tmp_channels = channels_dict[res]
|
335 |
+
out_channels = channels_dict[res // 2]
|
336 |
+
use_fp16 = res >= fp16_resolution
|
337 |
+
use_fp16 = False
|
338 |
+
block = EncoderBlock(
|
339 |
+
in_channels,
|
340 |
+
tmp_channels,
|
341 |
+
out_channels,
|
342 |
+
resolution=res,
|
343 |
+
first_layer_idx=cur_layer_idx,
|
344 |
+
use_fp16=use_fp16,
|
345 |
+
**block_kwargs,
|
346 |
+
**common_kwargs,
|
347 |
+
)
|
348 |
+
setattr(self, f"b{res}", block)
|
349 |
+
cur_layer_idx += block.num_layers
|
350 |
+
if c_dim > 0:
|
351 |
+
self.mapping = MappingNetwork(
|
352 |
+
z_dim=0,
|
353 |
+
c_dim=c_dim,
|
354 |
+
w_dim=cmap_dim,
|
355 |
+
num_ws=None,
|
356 |
+
w_avg_beta=None,
|
357 |
+
**mapping_kwargs,
|
358 |
+
)
|
359 |
+
self.b4 = EncoderEpilogue(
|
360 |
+
channels_dict[4],
|
361 |
+
cmap_dim=cmap_dim,
|
362 |
+
z_dim=z_dim * 2,
|
363 |
+
resolution=4,
|
364 |
+
**epilogue_kwargs,
|
365 |
+
**common_kwargs,
|
366 |
+
)
|
367 |
+
|
368 |
+
def forward(self, img, c, **block_kwargs):
|
369 |
+
x = None
|
370 |
+
feats = {}
|
371 |
+
for res in self.block_resolutions:
|
372 |
+
block = getattr(self, f"b{res}")
|
373 |
+
x, img, feat = block(x, img, **block_kwargs)
|
374 |
+
feats[res] = feat
|
375 |
+
|
376 |
+
cmap = None
|
377 |
+
if self.c_dim > 0:
|
378 |
+
cmap = self.mapping(None, c)
|
379 |
+
x, const_e = self.b4(x, cmap)
|
380 |
+
feats[4] = const_e
|
381 |
+
|
382 |
+
B, _ = x.shape
|
383 |
+
z = torch.zeros(
|
384 |
+
(B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device
|
385 |
+
) ## Noise for Co-Modulation
|
386 |
+
return x, z, feats
|
387 |
+
|
388 |
+
|
389 |
+
def fma(a, b, c): # => a * b + c
|
390 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
391 |
+
|
392 |
+
|
393 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
394 |
+
@staticmethod
|
395 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
396 |
+
out = torch.addcmul(c, a, b)
|
397 |
+
ctx.save_for_backward(a, b)
|
398 |
+
ctx.c_shape = c.shape
|
399 |
+
return out
|
400 |
+
|
401 |
+
@staticmethod
|
402 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
403 |
+
a, b = ctx.saved_tensors
|
404 |
+
c_shape = ctx.c_shape
|
405 |
+
da = None
|
406 |
+
db = None
|
407 |
+
dc = None
|
408 |
+
|
409 |
+
if ctx.needs_input_grad[0]:
|
410 |
+
da = _unbroadcast(dout * b, a.shape)
|
411 |
+
|
412 |
+
if ctx.needs_input_grad[1]:
|
413 |
+
db = _unbroadcast(dout * a, b.shape)
|
414 |
+
|
415 |
+
if ctx.needs_input_grad[2]:
|
416 |
+
dc = _unbroadcast(dout, c_shape)
|
417 |
+
|
418 |
+
return da, db, dc
|
419 |
+
|
420 |
+
|
421 |
+
def _unbroadcast(x, shape):
|
422 |
+
extra_dims = x.ndim - len(shape)
|
423 |
+
assert extra_dims >= 0
|
424 |
+
dim = [
|
425 |
+
i
|
426 |
+
for i in range(x.ndim)
|
427 |
+
if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
|
428 |
+
]
|
429 |
+
if len(dim):
|
430 |
+
x = x.sum(dim=dim, keepdim=True)
|
431 |
+
if extra_dims:
|
432 |
+
x = x.reshape(-1, *x.shape[extra_dims + 1 :])
|
433 |
+
assert x.shape == shape
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
def modulated_conv2d(
|
438 |
+
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
|
439 |
+
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
|
440 |
+
styles, # Modulation coefficients of shape [batch_size, in_channels].
|
441 |
+
noise=None, # Optional noise tensor to add to the output activations.
|
442 |
+
up=1, # Integer upsampling factor.
|
443 |
+
down=1, # Integer downsampling factor.
|
444 |
+
padding=0, # Padding with respect to the upsampled image.
|
445 |
+
resample_filter=None,
|
446 |
+
# Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
|
447 |
+
demodulate=True, # Apply weight demodulation?
|
448 |
+
flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
|
449 |
+
fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
|
450 |
+
):
|
451 |
+
batch_size = x.shape[0]
|
452 |
+
out_channels, in_channels, kh, kw = weight.shape
|
453 |
+
|
454 |
+
# Pre-normalize inputs to avoid FP16 overflow.
|
455 |
+
if x.dtype == torch.float16 and demodulate:
|
456 |
+
weight = weight * (
|
457 |
+
1
|
458 |
+
/ np.sqrt(in_channels * kh * kw)
|
459 |
+
/ weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
|
460 |
+
) # max_Ikk
|
461 |
+
styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I
|
462 |
+
|
463 |
+
# Calculate per-sample weights and demodulation coefficients.
|
464 |
+
w = None
|
465 |
+
dcoefs = None
|
466 |
+
if demodulate or fused_modconv:
|
467 |
+
w = weight.unsqueeze(0) # [NOIkk]
|
468 |
+
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
|
469 |
+
if demodulate:
|
470 |
+
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
|
471 |
+
if demodulate and fused_modconv:
|
472 |
+
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
|
473 |
+
# Execute by scaling the activations before and after the convolution.
|
474 |
+
if not fused_modconv:
|
475 |
+
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
476 |
+
x = conv2d_resample.conv2d_resample(
|
477 |
+
x=x,
|
478 |
+
w=weight.to(x.dtype),
|
479 |
+
f=resample_filter,
|
480 |
+
up=up,
|
481 |
+
down=down,
|
482 |
+
padding=padding,
|
483 |
+
flip_weight=flip_weight,
|
484 |
+
)
|
485 |
+
if demodulate and noise is not None:
|
486 |
+
x = fma(
|
487 |
+
x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)
|
488 |
+
)
|
489 |
+
elif demodulate:
|
490 |
+
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
491 |
+
elif noise is not None:
|
492 |
+
x = x.add_(noise.to(x.dtype))
|
493 |
+
return x
|
494 |
+
|
495 |
+
# Execute as one fused op using grouped convolution.
|
496 |
+
batch_size = int(batch_size)
|
497 |
+
x = x.reshape(1, -1, *x.shape[2:])
|
498 |
+
w = w.reshape(-1, in_channels, kh, kw)
|
499 |
+
x = conv2d_resample(
|
500 |
+
x=x,
|
501 |
+
w=w.to(x.dtype),
|
502 |
+
f=resample_filter,
|
503 |
+
up=up,
|
504 |
+
down=down,
|
505 |
+
padding=padding,
|
506 |
+
groups=batch_size,
|
507 |
+
flip_weight=flip_weight,
|
508 |
+
)
|
509 |
+
x = x.reshape(batch_size, -1, *x.shape[2:])
|
510 |
+
if noise is not None:
|
511 |
+
x = x.add_(noise)
|
512 |
+
return x
|
513 |
+
|
514 |
+
|
515 |
+
class SynthesisLayer(torch.nn.Module):
|
516 |
+
def __init__(
|
517 |
+
self,
|
518 |
+
in_channels, # Number of input channels.
|
519 |
+
out_channels, # Number of output channels.
|
520 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
521 |
+
resolution, # Resolution of this layer.
|
522 |
+
kernel_size=3, # Convolution kernel size.
|
523 |
+
up=1, # Integer upsampling factor.
|
524 |
+
use_noise=True, # Enable noise input?
|
525 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
526 |
+
resample_filter=[
|
527 |
+
1,
|
528 |
+
3,
|
529 |
+
3,
|
530 |
+
1,
|
531 |
+
], # Low-pass filter to apply when resampling activations.
|
532 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
533 |
+
channels_last=False, # Use channels_last format for the weights?
|
534 |
+
):
|
535 |
+
super().__init__()
|
536 |
+
self.resolution = resolution
|
537 |
+
self.up = up
|
538 |
+
self.use_noise = use_noise
|
539 |
+
self.activation = activation
|
540 |
+
self.conv_clamp = conv_clamp
|
541 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
542 |
+
self.padding = kernel_size // 2
|
543 |
+
self.act_gain = activation_funcs[activation].def_gain
|
544 |
+
|
545 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
546 |
+
memory_format = (
|
547 |
+
torch.channels_last if channels_last else torch.contiguous_format
|
548 |
+
)
|
549 |
+
self.weight = torch.nn.Parameter(
|
550 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
|
551 |
+
memory_format=memory_format
|
552 |
+
)
|
553 |
+
)
|
554 |
+
if use_noise:
|
555 |
+
self.register_buffer("noise_const", torch.randn([resolution, resolution]))
|
556 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
557 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
558 |
+
|
559 |
+
def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1):
|
560 |
+
assert noise_mode in ["random", "const", "none"]
|
561 |
+
in_resolution = self.resolution // self.up
|
562 |
+
styles = self.affine(w)
|
563 |
+
|
564 |
+
noise = None
|
565 |
+
if self.use_noise and noise_mode == "random":
|
566 |
+
noise = (
|
567 |
+
torch.randn(
|
568 |
+
[x.shape[0], 1, self.resolution, self.resolution], device=x.device
|
569 |
+
)
|
570 |
+
* self.noise_strength
|
571 |
+
)
|
572 |
+
if self.use_noise and noise_mode == "const":
|
573 |
+
noise = self.noise_const * self.noise_strength
|
574 |
+
|
575 |
+
flip_weight = self.up == 1 # slightly faster
|
576 |
+
x = modulated_conv2d(
|
577 |
+
x=x,
|
578 |
+
weight=self.weight,
|
579 |
+
styles=styles,
|
580 |
+
noise=noise,
|
581 |
+
up=self.up,
|
582 |
+
padding=self.padding,
|
583 |
+
resample_filter=self.resample_filter,
|
584 |
+
flip_weight=flip_weight,
|
585 |
+
fused_modconv=fused_modconv,
|
586 |
+
)
|
587 |
+
|
588 |
+
act_gain = self.act_gain * gain
|
589 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
590 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
|
591 |
+
if act_gain != 1:
|
592 |
+
x = x * act_gain
|
593 |
+
if act_clamp is not None:
|
594 |
+
x = x.clamp(-act_clamp, act_clamp)
|
595 |
+
return x
|
596 |
+
|
597 |
+
|
598 |
+
class ToRGBLayer(torch.nn.Module):
|
599 |
+
def __init__(
|
600 |
+
self,
|
601 |
+
in_channels,
|
602 |
+
out_channels,
|
603 |
+
w_dim,
|
604 |
+
kernel_size=1,
|
605 |
+
conv_clamp=None,
|
606 |
+
channels_last=False,
|
607 |
+
):
|
608 |
+
super().__init__()
|
609 |
+
self.conv_clamp = conv_clamp
|
610 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
611 |
+
memory_format = (
|
612 |
+
torch.channels_last if channels_last else torch.contiguous_format
|
613 |
+
)
|
614 |
+
self.weight = torch.nn.Parameter(
|
615 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
|
616 |
+
memory_format=memory_format
|
617 |
+
)
|
618 |
+
)
|
619 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
620 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
|
621 |
+
|
622 |
+
def forward(self, x, w, fused_modconv=True):
|
623 |
+
styles = self.affine(w) * self.weight_gain
|
624 |
+
x = modulated_conv2d(
|
625 |
+
x=x,
|
626 |
+
weight=self.weight,
|
627 |
+
styles=styles,
|
628 |
+
demodulate=False,
|
629 |
+
fused_modconv=fused_modconv,
|
630 |
+
)
|
631 |
+
x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
|
632 |
+
return x
|
633 |
+
|
634 |
+
|
635 |
+
class SynthesisForeword(torch.nn.Module):
|
636 |
+
def __init__(
|
637 |
+
self,
|
638 |
+
z_dim, # Output Latent (Z) dimensionality.
|
639 |
+
resolution, # Resolution of this block.
|
640 |
+
in_channels,
|
641 |
+
img_channels, # Number of input color channels.
|
642 |
+
architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
|
643 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
644 |
+
):
|
645 |
+
super().__init__()
|
646 |
+
self.in_channels = in_channels
|
647 |
+
self.z_dim = z_dim
|
648 |
+
self.resolution = resolution
|
649 |
+
self.img_channels = img_channels
|
650 |
+
self.architecture = architecture
|
651 |
+
|
652 |
+
self.fc = FullyConnectedLayer(
|
653 |
+
self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation
|
654 |
+
)
|
655 |
+
self.conv = SynthesisLayer(
|
656 |
+
self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4
|
657 |
+
)
|
658 |
+
|
659 |
+
if architecture == "skip":
|
660 |
+
self.torgb = ToRGBLayer(
|
661 |
+
self.in_channels,
|
662 |
+
self.img_channels,
|
663 |
+
kernel_size=1,
|
664 |
+
w_dim=(z_dim // 2) * 3,
|
665 |
+
)
|
666 |
+
|
667 |
+
def forward(self, x, ws, feats, img, force_fp32=False):
|
668 |
+
_ = force_fp32 # unused
|
669 |
+
dtype = torch.float32
|
670 |
+
memory_format = torch.contiguous_format
|
671 |
+
|
672 |
+
x_global = x.clone()
|
673 |
+
# ToRGB.
|
674 |
+
x = self.fc(x)
|
675 |
+
x = x.view(-1, self.z_dim // 2, 4, 4)
|
676 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
677 |
+
|
678 |
+
# Main layers.
|
679 |
+
x_skip = feats[4].clone()
|
680 |
+
x = x + x_skip
|
681 |
+
|
682 |
+
mod_vector = []
|
683 |
+
mod_vector.append(ws[:, 0])
|
684 |
+
mod_vector.append(x_global.clone())
|
685 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
686 |
+
|
687 |
+
x = self.conv(x, mod_vector)
|
688 |
+
|
689 |
+
mod_vector = []
|
690 |
+
mod_vector.append(ws[:, 2 * 2 - 3])
|
691 |
+
mod_vector.append(x_global.clone())
|
692 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
693 |
+
|
694 |
+
if self.architecture == "skip":
|
695 |
+
img = self.torgb(x, mod_vector)
|
696 |
+
img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
697 |
+
|
698 |
+
assert x.dtype == dtype
|
699 |
+
return x, img
|
700 |
+
|
701 |
+
|
702 |
+
class SELayer(nn.Module):
|
703 |
+
def __init__(self, channel, reduction=16):
|
704 |
+
super(SELayer, self).__init__()
|
705 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
706 |
+
self.fc = nn.Sequential(
|
707 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
708 |
+
nn.ReLU(inplace=False),
|
709 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
710 |
+
nn.Sigmoid(),
|
711 |
+
)
|
712 |
+
|
713 |
+
def forward(self, x):
|
714 |
+
b, c, _, _ = x.size()
|
715 |
+
y = self.avg_pool(x).view(b, c)
|
716 |
+
y = self.fc(y).view(b, c, 1, 1)
|
717 |
+
res = x * y.expand_as(x)
|
718 |
+
return res
|
719 |
+
|
720 |
+
|
721 |
+
class FourierUnit(nn.Module):
|
722 |
+
def __init__(
|
723 |
+
self,
|
724 |
+
in_channels,
|
725 |
+
out_channels,
|
726 |
+
groups=1,
|
727 |
+
spatial_scale_factor=None,
|
728 |
+
spatial_scale_mode="bilinear",
|
729 |
+
spectral_pos_encoding=False,
|
730 |
+
use_se=False,
|
731 |
+
se_kwargs=None,
|
732 |
+
ffc3d=False,
|
733 |
+
fft_norm="ortho",
|
734 |
+
):
|
735 |
+
# bn_layer not used
|
736 |
+
super(FourierUnit, self).__init__()
|
737 |
+
self.groups = groups
|
738 |
+
|
739 |
+
self.conv_layer = torch.nn.Conv2d(
|
740 |
+
in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
741 |
+
out_channels=out_channels * 2,
|
742 |
+
kernel_size=1,
|
743 |
+
stride=1,
|
744 |
+
padding=0,
|
745 |
+
groups=self.groups,
|
746 |
+
bias=False,
|
747 |
+
)
|
748 |
+
self.relu = torch.nn.ReLU(inplace=False)
|
749 |
+
|
750 |
+
# squeeze and excitation block
|
751 |
+
self.use_se = use_se
|
752 |
+
if use_se:
|
753 |
+
if se_kwargs is None:
|
754 |
+
se_kwargs = {}
|
755 |
+
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
756 |
+
|
757 |
+
self.spatial_scale_factor = spatial_scale_factor
|
758 |
+
self.spatial_scale_mode = spatial_scale_mode
|
759 |
+
self.spectral_pos_encoding = spectral_pos_encoding
|
760 |
+
self.ffc3d = ffc3d
|
761 |
+
self.fft_norm = fft_norm
|
762 |
+
|
763 |
+
def forward(self, x):
|
764 |
+
batch = x.shape[0]
|
765 |
+
|
766 |
+
if self.spatial_scale_factor is not None:
|
767 |
+
orig_size = x.shape[-2:]
|
768 |
+
x = F.interpolate(
|
769 |
+
x,
|
770 |
+
scale_factor=self.spatial_scale_factor,
|
771 |
+
mode=self.spatial_scale_mode,
|
772 |
+
align_corners=False,
|
773 |
+
)
|
774 |
+
|
775 |
+
r_size = x.size()
|
776 |
+
# (batch, c, h, w/2+1, 2)
|
777 |
+
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
778 |
+
ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
779 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
780 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
781 |
+
ffted = ffted.view(
|
782 |
+
(
|
783 |
+
batch,
|
784 |
+
-1,
|
785 |
+
)
|
786 |
+
+ ffted.size()[3:]
|
787 |
+
)
|
788 |
+
|
789 |
+
if self.spectral_pos_encoding:
|
790 |
+
height, width = ffted.shape[-2:]
|
791 |
+
coords_vert = (
|
792 |
+
torch.linspace(0, 1, height)[None, None, :, None]
|
793 |
+
.expand(batch, 1, height, width)
|
794 |
+
.to(ffted)
|
795 |
+
)
|
796 |
+
coords_hor = (
|
797 |
+
torch.linspace(0, 1, width)[None, None, None, :]
|
798 |
+
.expand(batch, 1, height, width)
|
799 |
+
.to(ffted)
|
800 |
+
)
|
801 |
+
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
802 |
+
|
803 |
+
if self.use_se:
|
804 |
+
ffted = self.se(ffted)
|
805 |
+
|
806 |
+
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
|
807 |
+
ffted = self.relu(ffted)
|
808 |
+
|
809 |
+
ffted = (
|
810 |
+
ffted.view(
|
811 |
+
(
|
812 |
+
batch,
|
813 |
+
-1,
|
814 |
+
2,
|
815 |
+
)
|
816 |
+
+ ffted.size()[2:]
|
817 |
+
)
|
818 |
+
.permute(0, 1, 3, 4, 2)
|
819 |
+
.contiguous()
|
820 |
+
) # (batch,c, t, h, w/2+1, 2)
|
821 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
822 |
+
|
823 |
+
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
824 |
+
output = torch.fft.irfftn(
|
825 |
+
ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
|
826 |
+
)
|
827 |
+
|
828 |
+
if self.spatial_scale_factor is not None:
|
829 |
+
output = F.interpolate(
|
830 |
+
output,
|
831 |
+
size=orig_size,
|
832 |
+
mode=self.spatial_scale_mode,
|
833 |
+
align_corners=False,
|
834 |
+
)
|
835 |
+
|
836 |
+
return output
|
837 |
+
|
838 |
+
|
839 |
+
class SpectralTransform(nn.Module):
|
840 |
+
def __init__(
|
841 |
+
self,
|
842 |
+
in_channels,
|
843 |
+
out_channels,
|
844 |
+
stride=1,
|
845 |
+
groups=1,
|
846 |
+
enable_lfu=True,
|
847 |
+
**fu_kwargs,
|
848 |
+
):
|
849 |
+
# bn_layer not used
|
850 |
+
super(SpectralTransform, self).__init__()
|
851 |
+
self.enable_lfu = enable_lfu
|
852 |
+
if stride == 2:
|
853 |
+
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
854 |
+
else:
|
855 |
+
self.downsample = nn.Identity()
|
856 |
+
|
857 |
+
self.stride = stride
|
858 |
+
self.conv1 = nn.Sequential(
|
859 |
+
nn.Conv2d(
|
860 |
+
in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
|
861 |
+
),
|
862 |
+
# nn.BatchNorm2d(out_channels // 2),
|
863 |
+
nn.ReLU(inplace=True),
|
864 |
+
)
|
865 |
+
self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
866 |
+
if self.enable_lfu:
|
867 |
+
self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups)
|
868 |
+
self.conv2 = torch.nn.Conv2d(
|
869 |
+
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
|
870 |
+
)
|
871 |
+
|
872 |
+
def forward(self, x):
|
873 |
+
|
874 |
+
x = self.downsample(x)
|
875 |
+
x = self.conv1(x)
|
876 |
+
output = self.fu(x)
|
877 |
+
|
878 |
+
if self.enable_lfu:
|
879 |
+
n, c, h, w = x.shape
|
880 |
+
split_no = 2
|
881 |
+
split_s = h // split_no
|
882 |
+
xs = torch.cat(
|
883 |
+
torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
|
884 |
+
).contiguous()
|
885 |
+
xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
|
886 |
+
xs = self.lfu(xs)
|
887 |
+
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
888 |
+
else:
|
889 |
+
xs = 0
|
890 |
+
|
891 |
+
output = self.conv2(x + output + xs)
|
892 |
+
|
893 |
+
return output
|
894 |
+
|
895 |
+
|
896 |
+
class FFC(nn.Module):
|
897 |
+
def __init__(
|
898 |
+
self,
|
899 |
+
in_channels,
|
900 |
+
out_channels,
|
901 |
+
kernel_size,
|
902 |
+
ratio_gin,
|
903 |
+
ratio_gout,
|
904 |
+
stride=1,
|
905 |
+
padding=0,
|
906 |
+
dilation=1,
|
907 |
+
groups=1,
|
908 |
+
bias=False,
|
909 |
+
enable_lfu=True,
|
910 |
+
padding_type="reflect",
|
911 |
+
gated=False,
|
912 |
+
**spectral_kwargs,
|
913 |
+
):
|
914 |
+
super(FFC, self).__init__()
|
915 |
+
|
916 |
+
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
917 |
+
self.stride = stride
|
918 |
+
|
919 |
+
in_cg = int(in_channels * ratio_gin)
|
920 |
+
in_cl = in_channels - in_cg
|
921 |
+
out_cg = int(out_channels * ratio_gout)
|
922 |
+
out_cl = out_channels - out_cg
|
923 |
+
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
924 |
+
# groups_l = 1 if groups == 1 else groups - groups_g
|
925 |
+
|
926 |
+
self.ratio_gin = ratio_gin
|
927 |
+
self.ratio_gout = ratio_gout
|
928 |
+
self.global_in_num = in_cg
|
929 |
+
|
930 |
+
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
931 |
+
self.convl2l = module(
|
932 |
+
in_cl,
|
933 |
+
out_cl,
|
934 |
+
kernel_size,
|
935 |
+
stride,
|
936 |
+
padding,
|
937 |
+
dilation,
|
938 |
+
groups,
|
939 |
+
bias,
|
940 |
+
padding_mode=padding_type,
|
941 |
+
)
|
942 |
+
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
943 |
+
self.convl2g = module(
|
944 |
+
in_cl,
|
945 |
+
out_cg,
|
946 |
+
kernel_size,
|
947 |
+
stride,
|
948 |
+
padding,
|
949 |
+
dilation,
|
950 |
+
groups,
|
951 |
+
bias,
|
952 |
+
padding_mode=padding_type,
|
953 |
+
)
|
954 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
955 |
+
self.convg2l = module(
|
956 |
+
in_cg,
|
957 |
+
out_cl,
|
958 |
+
kernel_size,
|
959 |
+
stride,
|
960 |
+
padding,
|
961 |
+
dilation,
|
962 |
+
groups,
|
963 |
+
bias,
|
964 |
+
padding_mode=padding_type,
|
965 |
+
)
|
966 |
+
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
967 |
+
self.convg2g = module(
|
968 |
+
in_cg,
|
969 |
+
out_cg,
|
970 |
+
stride,
|
971 |
+
1 if groups == 1 else groups // 2,
|
972 |
+
enable_lfu,
|
973 |
+
**spectral_kwargs,
|
974 |
+
)
|
975 |
+
|
976 |
+
self.gated = gated
|
977 |
+
module = (
|
978 |
+
nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
979 |
+
)
|
980 |
+
self.gate = module(in_channels, 2, 1)
|
981 |
+
|
982 |
+
def forward(self, x, fname=None):
|
983 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
984 |
+
out_xl, out_xg = 0, 0
|
985 |
+
|
986 |
+
if self.gated:
|
987 |
+
total_input_parts = [x_l]
|
988 |
+
if torch.is_tensor(x_g):
|
989 |
+
total_input_parts.append(x_g)
|
990 |
+
total_input = torch.cat(total_input_parts, dim=1)
|
991 |
+
|
992 |
+
gates = torch.sigmoid(self.gate(total_input))
|
993 |
+
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
994 |
+
else:
|
995 |
+
g2l_gate, l2g_gate = 1, 1
|
996 |
+
|
997 |
+
spec_x = self.convg2g(x_g)
|
998 |
+
|
999 |
+
if self.ratio_gout != 1:
|
1000 |
+
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
1001 |
+
if self.ratio_gout != 0:
|
1002 |
+
out_xg = self.convl2g(x_l) * l2g_gate + spec_x
|
1003 |
+
|
1004 |
+
return out_xl, out_xg
|
1005 |
+
|
1006 |
+
|
1007 |
+
class FFC_BN_ACT(nn.Module):
|
1008 |
+
def __init__(
|
1009 |
+
self,
|
1010 |
+
in_channels,
|
1011 |
+
out_channels,
|
1012 |
+
kernel_size,
|
1013 |
+
ratio_gin,
|
1014 |
+
ratio_gout,
|
1015 |
+
stride=1,
|
1016 |
+
padding=0,
|
1017 |
+
dilation=1,
|
1018 |
+
groups=1,
|
1019 |
+
bias=False,
|
1020 |
+
norm_layer=nn.SyncBatchNorm,
|
1021 |
+
activation_layer=nn.Identity,
|
1022 |
+
padding_type="reflect",
|
1023 |
+
enable_lfu=True,
|
1024 |
+
**kwargs,
|
1025 |
+
):
|
1026 |
+
super(FFC_BN_ACT, self).__init__()
|
1027 |
+
self.ffc = FFC(
|
1028 |
+
in_channels,
|
1029 |
+
out_channels,
|
1030 |
+
kernel_size,
|
1031 |
+
ratio_gin,
|
1032 |
+
ratio_gout,
|
1033 |
+
stride,
|
1034 |
+
padding,
|
1035 |
+
dilation,
|
1036 |
+
groups,
|
1037 |
+
bias,
|
1038 |
+
enable_lfu,
|
1039 |
+
padding_type=padding_type,
|
1040 |
+
**kwargs,
|
1041 |
+
)
|
1042 |
+
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
1043 |
+
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
1044 |
+
global_channels = int(out_channels * ratio_gout)
|
1045 |
+
# self.bn_l = lnorm(out_channels - global_channels)
|
1046 |
+
# self.bn_g = gnorm(global_channels)
|
1047 |
+
|
1048 |
+
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
1049 |
+
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
1050 |
+
self.act_l = lact(inplace=True)
|
1051 |
+
self.act_g = gact(inplace=True)
|
1052 |
+
|
1053 |
+
def forward(self, x, fname=None):
|
1054 |
+
x_l, x_g = self.ffc(
|
1055 |
+
x,
|
1056 |
+
fname=fname,
|
1057 |
+
)
|
1058 |
+
x_l = self.act_l(x_l)
|
1059 |
+
x_g = self.act_g(x_g)
|
1060 |
+
return x_l, x_g
|
1061 |
+
|
1062 |
+
|
1063 |
+
class FFCResnetBlock(nn.Module):
|
1064 |
+
def __init__(
|
1065 |
+
self,
|
1066 |
+
dim,
|
1067 |
+
padding_type,
|
1068 |
+
norm_layer,
|
1069 |
+
activation_layer=nn.ReLU,
|
1070 |
+
dilation=1,
|
1071 |
+
spatial_transform_kwargs=None,
|
1072 |
+
inline=False,
|
1073 |
+
ratio_gin=0.75,
|
1074 |
+
ratio_gout=0.75,
|
1075 |
+
):
|
1076 |
+
super().__init__()
|
1077 |
+
self.conv1 = FFC_BN_ACT(
|
1078 |
+
dim,
|
1079 |
+
dim,
|
1080 |
+
kernel_size=3,
|
1081 |
+
padding=dilation,
|
1082 |
+
dilation=dilation,
|
1083 |
+
norm_layer=norm_layer,
|
1084 |
+
activation_layer=activation_layer,
|
1085 |
+
padding_type=padding_type,
|
1086 |
+
ratio_gin=ratio_gin,
|
1087 |
+
ratio_gout=ratio_gout,
|
1088 |
+
)
|
1089 |
+
self.conv2 = FFC_BN_ACT(
|
1090 |
+
dim,
|
1091 |
+
dim,
|
1092 |
+
kernel_size=3,
|
1093 |
+
padding=dilation,
|
1094 |
+
dilation=dilation,
|
1095 |
+
norm_layer=norm_layer,
|
1096 |
+
activation_layer=activation_layer,
|
1097 |
+
padding_type=padding_type,
|
1098 |
+
ratio_gin=ratio_gin,
|
1099 |
+
ratio_gout=ratio_gout,
|
1100 |
+
)
|
1101 |
+
self.inline = inline
|
1102 |
+
|
1103 |
+
def forward(self, x, fname=None):
|
1104 |
+
if self.inline:
|
1105 |
+
x_l, x_g = (
|
1106 |
+
x[:, : -self.conv1.ffc.global_in_num],
|
1107 |
+
x[:, -self.conv1.ffc.global_in_num :],
|
1108 |
+
)
|
1109 |
+
else:
|
1110 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
1111 |
+
|
1112 |
+
id_l, id_g = x_l, x_g
|
1113 |
+
|
1114 |
+
x_l, x_g = self.conv1((x_l, x_g), fname=fname)
|
1115 |
+
x_l, x_g = self.conv2((x_l, x_g), fname=fname)
|
1116 |
+
|
1117 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
1118 |
+
out = x_l, x_g
|
1119 |
+
if self.inline:
|
1120 |
+
out = torch.cat(out, dim=1)
|
1121 |
+
return out
|
1122 |
+
|
1123 |
+
|
1124 |
+
class ConcatTupleLayer(nn.Module):
|
1125 |
+
def forward(self, x):
|
1126 |
+
assert isinstance(x, tuple)
|
1127 |
+
x_l, x_g = x
|
1128 |
+
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
1129 |
+
if not torch.is_tensor(x_g):
|
1130 |
+
return x_l
|
1131 |
+
return torch.cat(x, dim=1)
|
1132 |
+
|
1133 |
+
|
1134 |
+
class FFCBlock(torch.nn.Module):
|
1135 |
+
def __init__(
|
1136 |
+
self,
|
1137 |
+
dim, # Number of output/input channels.
|
1138 |
+
kernel_size, # Width and height of the convolution kernel.
|
1139 |
+
padding,
|
1140 |
+
ratio_gin=0.75,
|
1141 |
+
ratio_gout=0.75,
|
1142 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
1143 |
+
):
|
1144 |
+
super().__init__()
|
1145 |
+
if activation == "linear":
|
1146 |
+
self.activation = nn.Identity
|
1147 |
+
else:
|
1148 |
+
self.activation = nn.ReLU
|
1149 |
+
self.padding = padding
|
1150 |
+
self.kernel_size = kernel_size
|
1151 |
+
self.ffc_block = FFCResnetBlock(
|
1152 |
+
dim=dim,
|
1153 |
+
padding_type="reflect",
|
1154 |
+
norm_layer=nn.SyncBatchNorm,
|
1155 |
+
activation_layer=self.activation,
|
1156 |
+
dilation=1,
|
1157 |
+
ratio_gin=ratio_gin,
|
1158 |
+
ratio_gout=ratio_gout,
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
self.concat_layer = ConcatTupleLayer()
|
1162 |
+
|
1163 |
+
def forward(self, gen_ft, mask, fname=None):
|
1164 |
+
x = gen_ft.float()
|
1165 |
+
|
1166 |
+
x_l, x_g = (
|
1167 |
+
x[:, : -self.ffc_block.conv1.ffc.global_in_num],
|
1168 |
+
x[:, -self.ffc_block.conv1.ffc.global_in_num :],
|
1169 |
+
)
|
1170 |
+
id_l, id_g = x_l, x_g
|
1171 |
+
|
1172 |
+
x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
|
1173 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
1174 |
+
x = self.concat_layer((x_l, x_g))
|
1175 |
+
|
1176 |
+
return x + gen_ft.float()
|
1177 |
+
|
1178 |
+
|
1179 |
+
class FFCSkipLayer(torch.nn.Module):
|
1180 |
+
def __init__(
|
1181 |
+
self,
|
1182 |
+
dim, # Number of input/output channels.
|
1183 |
+
kernel_size=3, # Convolution kernel size.
|
1184 |
+
ratio_gin=0.75,
|
1185 |
+
ratio_gout=0.75,
|
1186 |
+
):
|
1187 |
+
super().__init__()
|
1188 |
+
self.padding = kernel_size // 2
|
1189 |
+
|
1190 |
+
self.ffc_act = FFCBlock(
|
1191 |
+
dim=dim,
|
1192 |
+
kernel_size=kernel_size,
|
1193 |
+
activation=nn.ReLU,
|
1194 |
+
padding=self.padding,
|
1195 |
+
ratio_gin=ratio_gin,
|
1196 |
+
ratio_gout=ratio_gout,
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
def forward(self, gen_ft, mask, fname=None):
|
1200 |
+
x = self.ffc_act(gen_ft, mask, fname=fname)
|
1201 |
+
return x
|
1202 |
+
|
1203 |
+
|
1204 |
+
class SynthesisBlock(torch.nn.Module):
|
1205 |
+
def __init__(
|
1206 |
+
self,
|
1207 |
+
in_channels, # Number of input channels, 0 = first block.
|
1208 |
+
out_channels, # Number of output channels.
|
1209 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1210 |
+
resolution, # Resolution of this block.
|
1211 |
+
img_channels, # Number of output color channels.
|
1212 |
+
is_last, # Is this the last block?
|
1213 |
+
architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
|
1214 |
+
resample_filter=[
|
1215 |
+
1,
|
1216 |
+
3,
|
1217 |
+
3,
|
1218 |
+
1,
|
1219 |
+
], # Low-pass filter to apply when resampling activations.
|
1220 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
1221 |
+
use_fp16=False, # Use FP16 for this block?
|
1222 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
1223 |
+
**layer_kwargs, # Arguments for SynthesisLayer.
|
1224 |
+
):
|
1225 |
+
assert architecture in ["orig", "skip", "resnet"]
|
1226 |
+
super().__init__()
|
1227 |
+
self.in_channels = in_channels
|
1228 |
+
self.w_dim = w_dim
|
1229 |
+
self.resolution = resolution
|
1230 |
+
self.img_channels = img_channels
|
1231 |
+
self.is_last = is_last
|
1232 |
+
self.architecture = architecture
|
1233 |
+
self.use_fp16 = use_fp16
|
1234 |
+
self.channels_last = use_fp16 and fp16_channels_last
|
1235 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
1236 |
+
self.num_conv = 0
|
1237 |
+
self.num_torgb = 0
|
1238 |
+
self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
|
1239 |
+
|
1240 |
+
if in_channels != 0 and resolution >= 8:
|
1241 |
+
self.ffc_skip = nn.ModuleList()
|
1242 |
+
for _ in range(self.res_ffc[resolution]):
|
1243 |
+
self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
|
1244 |
+
|
1245 |
+
if in_channels == 0:
|
1246 |
+
self.const = torch.nn.Parameter(
|
1247 |
+
torch.randn([out_channels, resolution, resolution])
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
if in_channels != 0:
|
1251 |
+
self.conv0 = SynthesisLayer(
|
1252 |
+
in_channels,
|
1253 |
+
out_channels,
|
1254 |
+
w_dim=w_dim * 3,
|
1255 |
+
resolution=resolution,
|
1256 |
+
up=2,
|
1257 |
+
resample_filter=resample_filter,
|
1258 |
+
conv_clamp=conv_clamp,
|
1259 |
+
channels_last=self.channels_last,
|
1260 |
+
**layer_kwargs,
|
1261 |
+
)
|
1262 |
+
self.num_conv += 1
|
1263 |
+
|
1264 |
+
self.conv1 = SynthesisLayer(
|
1265 |
+
out_channels,
|
1266 |
+
out_channels,
|
1267 |
+
w_dim=w_dim * 3,
|
1268 |
+
resolution=resolution,
|
1269 |
+
conv_clamp=conv_clamp,
|
1270 |
+
channels_last=self.channels_last,
|
1271 |
+
**layer_kwargs,
|
1272 |
+
)
|
1273 |
+
self.num_conv += 1
|
1274 |
+
|
1275 |
+
if is_last or architecture == "skip":
|
1276 |
+
self.torgb = ToRGBLayer(
|
1277 |
+
out_channels,
|
1278 |
+
img_channels,
|
1279 |
+
w_dim=w_dim * 3,
|
1280 |
+
conv_clamp=conv_clamp,
|
1281 |
+
channels_last=self.channels_last,
|
1282 |
+
)
|
1283 |
+
self.num_torgb += 1
|
1284 |
+
|
1285 |
+
if in_channels != 0 and architecture == "resnet":
|
1286 |
+
self.skip = Conv2dLayer(
|
1287 |
+
in_channels,
|
1288 |
+
out_channels,
|
1289 |
+
kernel_size=1,
|
1290 |
+
bias=False,
|
1291 |
+
up=2,
|
1292 |
+
resample_filter=resample_filter,
|
1293 |
+
channels_last=self.channels_last,
|
1294 |
+
)
|
1295 |
+
|
1296 |
+
def forward(
|
1297 |
+
self,
|
1298 |
+
x,
|
1299 |
+
mask,
|
1300 |
+
feats,
|
1301 |
+
img,
|
1302 |
+
ws,
|
1303 |
+
fname=None,
|
1304 |
+
force_fp32=False,
|
1305 |
+
fused_modconv=None,
|
1306 |
+
**layer_kwargs,
|
1307 |
+
):
|
1308 |
+
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
1309 |
+
dtype = torch.float32
|
1310 |
+
memory_format = (
|
1311 |
+
torch.channels_last
|
1312 |
+
if self.channels_last and not force_fp32
|
1313 |
+
else torch.contiguous_format
|
1314 |
+
)
|
1315 |
+
if fused_modconv is None:
|
1316 |
+
fused_modconv = (not self.training) and (
|
1317 |
+
dtype == torch.float32 or int(x.shape[0]) == 1
|
1318 |
+
)
|
1319 |
+
|
1320 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
1321 |
+
x_skip = (
|
1322 |
+
feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
|
1323 |
+
)
|
1324 |
+
|
1325 |
+
# Main layers.
|
1326 |
+
if self.in_channels == 0:
|
1327 |
+
x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
|
1328 |
+
elif self.architecture == "resnet":
|
1329 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
1330 |
+
x = self.conv0(
|
1331 |
+
x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
|
1332 |
+
)
|
1333 |
+
if len(self.ffc_skip) > 0:
|
1334 |
+
mask = F.interpolate(
|
1335 |
+
mask,
|
1336 |
+
size=x_skip.shape[2:],
|
1337 |
+
)
|
1338 |
+
z = x + x_skip
|
1339 |
+
for fres in self.ffc_skip:
|
1340 |
+
z = fres(z, mask)
|
1341 |
+
x = x + z
|
1342 |
+
else:
|
1343 |
+
x = x + x_skip
|
1344 |
+
x = self.conv1(
|
1345 |
+
x,
|
1346 |
+
ws[1].clone(),
|
1347 |
+
fused_modconv=fused_modconv,
|
1348 |
+
gain=np.sqrt(0.5),
|
1349 |
+
**layer_kwargs,
|
1350 |
+
)
|
1351 |
+
x = y.add_(x)
|
1352 |
+
else:
|
1353 |
+
x = self.conv0(
|
1354 |
+
x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
|
1355 |
+
)
|
1356 |
+
if len(self.ffc_skip) > 0:
|
1357 |
+
mask = F.interpolate(
|
1358 |
+
mask,
|
1359 |
+
size=x_skip.shape[2:],
|
1360 |
+
)
|
1361 |
+
z = x + x_skip
|
1362 |
+
for fres in self.ffc_skip:
|
1363 |
+
z = fres(z, mask)
|
1364 |
+
x = x + z
|
1365 |
+
else:
|
1366 |
+
x = x + x_skip
|
1367 |
+
x = self.conv1(
|
1368 |
+
x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs
|
1369 |
+
)
|
1370 |
+
# ToRGB.
|
1371 |
+
if img is not None:
|
1372 |
+
img = upsample2d(img, self.resample_filter)
|
1373 |
+
if self.is_last or self.architecture == "skip":
|
1374 |
+
y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
|
1375 |
+
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
1376 |
+
img = img.add_(y) if img is not None else y
|
1377 |
+
|
1378 |
+
x = x.to(dtype=dtype)
|
1379 |
+
assert x.dtype == dtype
|
1380 |
+
assert img is None or img.dtype == torch.float32
|
1381 |
+
return x, img
|
1382 |
+
|
1383 |
+
|
1384 |
+
class SynthesisNetwork(torch.nn.Module):
|
1385 |
+
def __init__(
|
1386 |
+
self,
|
1387 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1388 |
+
z_dim, # Output Latent (Z) dimensionality.
|
1389 |
+
img_resolution, # Output image resolution.
|
1390 |
+
img_channels, # Number of color channels.
|
1391 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
1392 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1393 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
1394 |
+
**block_kwargs, # Arguments for SynthesisBlock.
|
1395 |
+
):
|
1396 |
+
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
1397 |
+
super().__init__()
|
1398 |
+
self.w_dim = w_dim
|
1399 |
+
self.img_resolution = img_resolution
|
1400 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
1401 |
+
self.img_channels = img_channels
|
1402 |
+
self.block_resolutions = [
|
1403 |
+
2**i for i in range(3, self.img_resolution_log2 + 1)
|
1404 |
+
]
|
1405 |
+
channels_dict = {
|
1406 |
+
res: min(channel_base // res, channel_max) for res in self.block_resolutions
|
1407 |
+
}
|
1408 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
1409 |
+
|
1410 |
+
self.foreword = SynthesisForeword(
|
1411 |
+
img_channels=img_channels,
|
1412 |
+
in_channels=min(channel_base // 4, channel_max),
|
1413 |
+
z_dim=z_dim * 2,
|
1414 |
+
resolution=4,
|
1415 |
+
)
|
1416 |
+
|
1417 |
+
self.num_ws = self.img_resolution_log2 * 2 - 2
|
1418 |
+
for res in self.block_resolutions:
|
1419 |
+
if res // 2 in channels_dict.keys():
|
1420 |
+
in_channels = channels_dict[res // 2] if res > 4 else 0
|
1421 |
+
else:
|
1422 |
+
in_channels = min(channel_base // (res // 2), channel_max)
|
1423 |
+
out_channels = channels_dict[res]
|
1424 |
+
use_fp16 = res >= fp16_resolution
|
1425 |
+
use_fp16 = False
|
1426 |
+
is_last = res == self.img_resolution
|
1427 |
+
block = SynthesisBlock(
|
1428 |
+
in_channels,
|
1429 |
+
out_channels,
|
1430 |
+
w_dim=w_dim,
|
1431 |
+
resolution=res,
|
1432 |
+
img_channels=img_channels,
|
1433 |
+
is_last=is_last,
|
1434 |
+
use_fp16=use_fp16,
|
1435 |
+
**block_kwargs,
|
1436 |
+
)
|
1437 |
+
setattr(self, f"b{res}", block)
|
1438 |
+
|
1439 |
+
def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
|
1440 |
+
|
1441 |
+
img = None
|
1442 |
+
|
1443 |
+
x, img = self.foreword(x_global, ws, feats, img)
|
1444 |
+
|
1445 |
+
for res in self.block_resolutions:
|
1446 |
+
block = getattr(self, f"b{res}")
|
1447 |
+
mod_vector0 = []
|
1448 |
+
mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
|
1449 |
+
mod_vector0.append(x_global.clone())
|
1450 |
+
mod_vector0 = torch.cat(mod_vector0, dim=1)
|
1451 |
+
|
1452 |
+
mod_vector1 = []
|
1453 |
+
mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
|
1454 |
+
mod_vector1.append(x_global.clone())
|
1455 |
+
mod_vector1 = torch.cat(mod_vector1, dim=1)
|
1456 |
+
|
1457 |
+
mod_vector_rgb = []
|
1458 |
+
mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
|
1459 |
+
mod_vector_rgb.append(x_global.clone())
|
1460 |
+
mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
|
1461 |
+
x, img = block(
|
1462 |
+
x,
|
1463 |
+
mask,
|
1464 |
+
feats,
|
1465 |
+
img,
|
1466 |
+
(mod_vector0, mod_vector1, mod_vector_rgb),
|
1467 |
+
fname=fname,
|
1468 |
+
**block_kwargs,
|
1469 |
+
)
|
1470 |
+
return img
|
1471 |
+
|
1472 |
+
|
1473 |
+
class MappingNetwork(torch.nn.Module):
|
1474 |
+
def __init__(
|
1475 |
+
self,
|
1476 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1477 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1478 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1479 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
1480 |
+
num_layers=8, # Number of mapping layers.
|
1481 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
1482 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
1483 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
1484 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
1485 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
1486 |
+
):
|
1487 |
+
super().__init__()
|
1488 |
+
self.z_dim = z_dim
|
1489 |
+
self.c_dim = c_dim
|
1490 |
+
self.w_dim = w_dim
|
1491 |
+
self.num_ws = num_ws
|
1492 |
+
self.num_layers = num_layers
|
1493 |
+
self.w_avg_beta = w_avg_beta
|
1494 |
+
|
1495 |
+
if embed_features is None:
|
1496 |
+
embed_features = w_dim
|
1497 |
+
if c_dim == 0:
|
1498 |
+
embed_features = 0
|
1499 |
+
if layer_features is None:
|
1500 |
+
layer_features = w_dim
|
1501 |
+
features_list = (
|
1502 |
+
[z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
1503 |
+
)
|
1504 |
+
|
1505 |
+
if c_dim > 0:
|
1506 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
1507 |
+
for idx in range(num_layers):
|
1508 |
+
in_features = features_list[idx]
|
1509 |
+
out_features = features_list[idx + 1]
|
1510 |
+
layer = FullyConnectedLayer(
|
1511 |
+
in_features,
|
1512 |
+
out_features,
|
1513 |
+
activation=activation,
|
1514 |
+
lr_multiplier=lr_multiplier,
|
1515 |
+
)
|
1516 |
+
setattr(self, f"fc{idx}", layer)
|
1517 |
+
|
1518 |
+
if num_ws is not None and w_avg_beta is not None:
|
1519 |
+
self.register_buffer("w_avg", torch.zeros([w_dim]))
|
1520 |
+
|
1521 |
+
def forward(
|
1522 |
+
self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
|
1523 |
+
):
|
1524 |
+
# Embed, normalize, and concat inputs.
|
1525 |
+
x = None
|
1526 |
+
with torch.autograd.profiler.record_function("input"):
|
1527 |
+
if self.z_dim > 0:
|
1528 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
1529 |
+
if self.c_dim > 0:
|
1530 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
1531 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
1532 |
+
|
1533 |
+
# Main layers.
|
1534 |
+
for idx in range(self.num_layers):
|
1535 |
+
layer = getattr(self, f"fc{idx}")
|
1536 |
+
x = layer(x)
|
1537 |
+
|
1538 |
+
# Update moving average of W.
|
1539 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
1540 |
+
with torch.autograd.profiler.record_function("update_w_avg"):
|
1541 |
+
self.w_avg.copy_(
|
1542 |
+
x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
|
1543 |
+
)
|
1544 |
+
|
1545 |
+
# Broadcast.
|
1546 |
+
if self.num_ws is not None:
|
1547 |
+
with torch.autograd.profiler.record_function("broadcast"):
|
1548 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
1549 |
+
|
1550 |
+
# Apply truncation.
|
1551 |
+
if truncation_psi != 1:
|
1552 |
+
with torch.autograd.profiler.record_function("truncate"):
|
1553 |
+
assert self.w_avg_beta is not None
|
1554 |
+
if self.num_ws is None or truncation_cutoff is None:
|
1555 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
1556 |
+
else:
|
1557 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(
|
1558 |
+
x[:, :truncation_cutoff], truncation_psi
|
1559 |
+
)
|
1560 |
+
return x
|
1561 |
+
|
1562 |
+
|
1563 |
+
class Generator(torch.nn.Module):
|
1564 |
+
def __init__(
|
1565 |
+
self,
|
1566 |
+
z_dim, # Input latent (Z) dimensionality.
|
1567 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1568 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1569 |
+
img_resolution, # Output resolution.
|
1570 |
+
img_channels, # Number of output color channels.
|
1571 |
+
encoder_kwargs={}, # Arguments for EncoderNetwork.
|
1572 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1573 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1574 |
+
):
|
1575 |
+
super().__init__()
|
1576 |
+
self.z_dim = z_dim
|
1577 |
+
self.c_dim = c_dim
|
1578 |
+
self.w_dim = w_dim
|
1579 |
+
self.img_resolution = img_resolution
|
1580 |
+
self.img_channels = img_channels
|
1581 |
+
self.encoder = EncoderNetwork(
|
1582 |
+
c_dim=c_dim,
|
1583 |
+
z_dim=z_dim,
|
1584 |
+
img_resolution=img_resolution,
|
1585 |
+
img_channels=img_channels,
|
1586 |
+
**encoder_kwargs,
|
1587 |
+
)
|
1588 |
+
self.synthesis = SynthesisNetwork(
|
1589 |
+
z_dim=z_dim,
|
1590 |
+
w_dim=w_dim,
|
1591 |
+
img_resolution=img_resolution,
|
1592 |
+
img_channels=img_channels,
|
1593 |
+
**synthesis_kwargs,
|
1594 |
+
)
|
1595 |
+
self.num_ws = self.synthesis.num_ws
|
1596 |
+
self.mapping = MappingNetwork(
|
1597 |
+
z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs
|
1598 |
+
)
|
1599 |
+
|
1600 |
+
def forward(
|
1601 |
+
self,
|
1602 |
+
img,
|
1603 |
+
c,
|
1604 |
+
fname=None,
|
1605 |
+
truncation_psi=1,
|
1606 |
+
truncation_cutoff=None,
|
1607 |
+
**synthesis_kwargs,
|
1608 |
+
):
|
1609 |
+
mask = img[:, -1].unsqueeze(1)
|
1610 |
+
x_global, z, feats = self.encoder(img, c)
|
1611 |
+
ws = self.mapping(
|
1612 |
+
z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff
|
1613 |
+
)
|
1614 |
+
img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
|
1615 |
+
return img
|
1616 |
+
|
1617 |
+
|
1618 |
+
FCF_MODEL_URL = os.environ.get(
|
1619 |
+
"FCF_MODEL_URL",
|
1620 |
+
"https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
|
1621 |
+
)
|
1622 |
+
FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211")
|
1623 |
+
|
1624 |
+
|
1625 |
+
class FcF(InpaintModel):
|
1626 |
+
name = "fcf"
|
1627 |
+
min_size = 512
|
1628 |
+
pad_mod = 512
|
1629 |
+
pad_to_square = True
|
1630 |
+
|
1631 |
+
def init_model(self, device, **kwargs):
|
1632 |
+
seed = 0
|
1633 |
+
random.seed(seed)
|
1634 |
+
np.random.seed(seed)
|
1635 |
+
torch.manual_seed(seed)
|
1636 |
+
torch.cuda.manual_seed_all(seed)
|
1637 |
+
torch.backends.cudnn.deterministic = True
|
1638 |
+
torch.backends.cudnn.benchmark = False
|
1639 |
+
|
1640 |
+
kwargs = {
|
1641 |
+
"channel_base": 1 * 32768,
|
1642 |
+
"channel_max": 512,
|
1643 |
+
"num_fp16_res": 4,
|
1644 |
+
"conv_clamp": 256,
|
1645 |
+
}
|
1646 |
+
G = Generator(
|
1647 |
+
z_dim=512,
|
1648 |
+
c_dim=0,
|
1649 |
+
w_dim=512,
|
1650 |
+
img_resolution=512,
|
1651 |
+
img_channels=3,
|
1652 |
+
synthesis_kwargs=kwargs,
|
1653 |
+
encoder_kwargs=kwargs,
|
1654 |
+
mapping_kwargs={"num_layers": 2},
|
1655 |
+
)
|
1656 |
+
self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5)
|
1657 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
1658 |
+
|
1659 |
+
@staticmethod
|
1660 |
+
def is_downloaded() -> bool:
|
1661 |
+
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
|
1662 |
+
|
1663 |
+
@torch.no_grad()
|
1664 |
+
def __call__(self, image, mask, config: Config):
|
1665 |
+
"""
|
1666 |
+
images: [H, W, C] RGB, not normalized
|
1667 |
+
masks: [H, W]
|
1668 |
+
return: BGR IMAGE
|
1669 |
+
"""
|
1670 |
+
if image.shape[0] == 512 and image.shape[1] == 512:
|
1671 |
+
return self._pad_forward(image, mask, config)
|
1672 |
+
|
1673 |
+
boxes = boxes_from_mask(mask)
|
1674 |
+
crop_result = []
|
1675 |
+
config.hd_strategy_crop_margin = 128
|
1676 |
+
for box in boxes:
|
1677 |
+
crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
|
1678 |
+
origin_size = crop_image.shape[:2]
|
1679 |
+
resize_image = resize_max_size(crop_image, size_limit=512)
|
1680 |
+
resize_mask = resize_max_size(crop_mask, size_limit=512)
|
1681 |
+
inpaint_result = self._pad_forward(resize_image, resize_mask, config)
|
1682 |
+
|
1683 |
+
# only paste masked area result
|
1684 |
+
inpaint_result = cv2.resize(
|
1685 |
+
inpaint_result,
|
1686 |
+
(origin_size[1], origin_size[0]),
|
1687 |
+
interpolation=cv2.INTER_CUBIC,
|
1688 |
+
)
|
1689 |
+
|
1690 |
+
original_pixel_indices = crop_mask < 127
|
1691 |
+
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
|
1692 |
+
original_pixel_indices
|
1693 |
+
]
|
1694 |
+
|
1695 |
+
crop_result.append((inpaint_result, crop_box))
|
1696 |
+
|
1697 |
+
inpaint_result = image[:, :, ::-1]
|
1698 |
+
for crop_image, crop_box in crop_result:
|
1699 |
+
x1, y1, x2, y2 = crop_box
|
1700 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
1701 |
+
|
1702 |
+
return inpaint_result
|
1703 |
+
|
1704 |
+
def forward(self, image, mask, config: Config):
|
1705 |
+
"""Input images and output images have same size
|
1706 |
+
images: [H, W, C] RGB
|
1707 |
+
masks: [H, W] mask area == 255
|
1708 |
+
return: BGR IMAGE
|
1709 |
+
"""
|
1710 |
+
|
1711 |
+
image = norm_img(image) # [0, 1]
|
1712 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
1713 |
+
mask = (mask > 120) * 255
|
1714 |
+
mask = norm_img(mask)
|
1715 |
+
|
1716 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
1717 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
1718 |
+
|
1719 |
+
erased_img = image * (1 - mask)
|
1720 |
+
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
|
1721 |
+
|
1722 |
+
output = self.model(
|
1723 |
+
input_image, self.label, truncation_psi=0.1, noise_mode="none"
|
1724 |
+
)
|
1725 |
+
output = (
|
1726 |
+
(output.permute(0, 2, 3, 1) * 127.5 + 127.5)
|
1727 |
+
.round()
|
1728 |
+
.clamp(0, 255)
|
1729 |
+
.to(torch.uint8)
|
1730 |
+
)
|
1731 |
+
output = output[0].cpu().numpy()
|
1732 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
1733 |
+
return cur_res
|
lama_cleaner/model/instruct_pix2pix.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL.Image
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from lama_cleaner.model.base import DiffusionInpaintModel
|
7 |
+
from lama_cleaner.model.utils import set_seed
|
8 |
+
from lama_cleaner.schema import Config
|
9 |
+
|
10 |
+
|
11 |
+
class InstructPix2Pix(DiffusionInpaintModel):
|
12 |
+
name = "instruct_pix2pix"
|
13 |
+
pad_mod = 8
|
14 |
+
min_size = 512
|
15 |
+
|
16 |
+
def init_model(self, device: torch.device, **kwargs):
|
17 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline
|
18 |
+
fp16 = not kwargs.get('no_half', False)
|
19 |
+
|
20 |
+
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
21 |
+
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
22 |
+
logger.info("Disable Stable Diffusion Model NSFW checker")
|
23 |
+
model_kwargs.update(dict(
|
24 |
+
safety_checker=None,
|
25 |
+
feature_extractor=None,
|
26 |
+
requires_safety_checker=False
|
27 |
+
))
|
28 |
+
|
29 |
+
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
30 |
+
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
31 |
+
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
32 |
+
"timbrooks/instruct-pix2pix",
|
33 |
+
revision="fp16" if use_gpu and fp16 else "main",
|
34 |
+
torch_dtype=torch_dtype,
|
35 |
+
**model_kwargs
|
36 |
+
)
|
37 |
+
|
38 |
+
self.model.enable_attention_slicing()
|
39 |
+
if kwargs.get('enable_xformers', False):
|
40 |
+
self.model.enable_xformers_memory_efficient_attention()
|
41 |
+
|
42 |
+
if kwargs.get('cpu_offload', False) and use_gpu:
|
43 |
+
logger.info("Enable sequential cpu offload")
|
44 |
+
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
45 |
+
else:
|
46 |
+
self.model = self.model.to(device)
|
47 |
+
|
48 |
+
def forward(self, image, mask, config: Config):
|
49 |
+
"""Input image and output image have same size
|
50 |
+
image: [H, W, C] RGB
|
51 |
+
mask: [H, W, 1] 255 means area to repaint
|
52 |
+
return: BGR IMAGE
|
53 |
+
edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
|
54 |
+
"""
|
55 |
+
output = self.model(
|
56 |
+
image=PIL.Image.fromarray(image),
|
57 |
+
prompt=config.prompt,
|
58 |
+
negative_prompt=config.negative_prompt,
|
59 |
+
num_inference_steps=config.p2p_steps,
|
60 |
+
image_guidance_scale=config.p2p_image_guidance_scale,
|
61 |
+
guidance_scale=config.p2p_guidance_scale,
|
62 |
+
output_type="np.array",
|
63 |
+
generator=torch.manual_seed(config.sd_seed)
|
64 |
+
).images[0]
|
65 |
+
|
66 |
+
output = (output * 255).round().astype("uint8")
|
67 |
+
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
68 |
+
return output
|
69 |
+
|
70 |
+
#
|
71 |
+
# def forward_post_process(self, result, image, mask, config):
|
72 |
+
# if config.sd_match_histograms:
|
73 |
+
# result = self._match_histograms(result, image[:, :, ::-1], mask)
|
74 |
+
#
|
75 |
+
# if config.sd_mask_blur != 0:
|
76 |
+
# k = 2 * config.sd_mask_blur + 1
|
77 |
+
# mask = cv2.GaussianBlur(mask, (k, k), 0)
|
78 |
+
# return result, image, mask
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def is_downloaded() -> bool:
|
82 |
+
# model will be downloaded when app start, and can't switch in frontend settings
|
83 |
+
return True
|
lama_cleaner/model/lama.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from lama_cleaner.helper import (
|
8 |
+
norm_img,
|
9 |
+
get_cache_path_by_url,
|
10 |
+
load_jit_model,
|
11 |
+
)
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from lama_cleaner.schema import Config
|
14 |
+
|
15 |
+
LAMA_MODEL_URL = os.environ.get(
|
16 |
+
"LAMA_MODEL_URL",
|
17 |
+
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
18 |
+
)
|
19 |
+
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
|
20 |
+
|
21 |
+
|
22 |
+
class LaMa(InpaintModel):
|
23 |
+
name = "lama"
|
24 |
+
pad_mod = 8
|
25 |
+
|
26 |
+
def init_model(self, device, **kwargs):
|
27 |
+
self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def is_downloaded() -> bool:
|
31 |
+
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
32 |
+
|
33 |
+
def forward(self, image, mask, config: Config):
|
34 |
+
"""Input image and output image have same size
|
35 |
+
image: [H, W, C] RGB
|
36 |
+
mask: [H, W]
|
37 |
+
return: BGR IMAGE
|
38 |
+
"""
|
39 |
+
image = norm_img(image)
|
40 |
+
mask = norm_img(mask)
|
41 |
+
|
42 |
+
mask = (mask > 0) * 1
|
43 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
44 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
45 |
+
|
46 |
+
inpainted_image = self.model(image, mask)
|
47 |
+
|
48 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
49 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
50 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
51 |
+
return cur_res
|
lama_cleaner/model/ldm.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import wraps
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img
|
9 |
+
from lama_cleaner.model.base import InpaintModel
|
10 |
+
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
11 |
+
from lama_cleaner.model.plms_sampler import PLMSSampler
|
12 |
+
from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
|
13 |
+
from lama_cleaner.schema import Config, LDMSampler
|
14 |
+
|
15 |
+
# torch.manual_seed(42)
|
16 |
+
|
17 |
+
|
18 |
+
def conditional_autocast(func):
|
19 |
+
@wraps(func)
|
20 |
+
def wrapper(*args, **kwargs):
|
21 |
+
if torch.cuda.is_available():
|
22 |
+
with torch.cuda.amp.autocast():
|
23 |
+
return func(*args, **kwargs)
|
24 |
+
else:
|
25 |
+
return func(*args, **kwargs)
|
26 |
+
return wrapper
|
27 |
+
|
28 |
+
|
29 |
+
LDM_ENCODE_MODEL_URL = os.environ.get(
|
30 |
+
"LDM_ENCODE_MODEL_URL",
|
31 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
32 |
+
)
|
33 |
+
LDM_ENCODE_MODEL_MD5 = os.environ.get(
|
34 |
+
"LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
|
35 |
+
)
|
36 |
+
|
37 |
+
LDM_DECODE_MODEL_URL = os.environ.get(
|
38 |
+
"LDM_DECODE_MODEL_URL",
|
39 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
40 |
+
)
|
41 |
+
LDM_DECODE_MODEL_MD5 = os.environ.get(
|
42 |
+
"LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
|
43 |
+
)
|
44 |
+
|
45 |
+
LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
46 |
+
"LDM_DIFFUSION_MODEL_URL",
|
47 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
48 |
+
)
|
49 |
+
|
50 |
+
LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
|
51 |
+
"LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
class DDPM(nn.Module):
|
56 |
+
# classic DDPM with Gaussian diffusion, in image space
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
device,
|
60 |
+
timesteps=1000,
|
61 |
+
beta_schedule="linear",
|
62 |
+
linear_start=0.0015,
|
63 |
+
linear_end=0.0205,
|
64 |
+
cosine_s=0.008,
|
65 |
+
original_elbo_weight=0.0,
|
66 |
+
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
67 |
+
l_simple_weight=1.0,
|
68 |
+
parameterization="eps", # all assuming fixed variance schedules
|
69 |
+
use_positional_encodings=False,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
self.device = device
|
73 |
+
self.parameterization = parameterization
|
74 |
+
self.use_positional_encodings = use_positional_encodings
|
75 |
+
|
76 |
+
self.v_posterior = v_posterior
|
77 |
+
self.original_elbo_weight = original_elbo_weight
|
78 |
+
self.l_simple_weight = l_simple_weight
|
79 |
+
|
80 |
+
self.register_schedule(
|
81 |
+
beta_schedule=beta_schedule,
|
82 |
+
timesteps=timesteps,
|
83 |
+
linear_start=linear_start,
|
84 |
+
linear_end=linear_end,
|
85 |
+
cosine_s=cosine_s,
|
86 |
+
)
|
87 |
+
|
88 |
+
def register_schedule(
|
89 |
+
self,
|
90 |
+
given_betas=None,
|
91 |
+
beta_schedule="linear",
|
92 |
+
timesteps=1000,
|
93 |
+
linear_start=1e-4,
|
94 |
+
linear_end=2e-2,
|
95 |
+
cosine_s=8e-3,
|
96 |
+
):
|
97 |
+
betas = make_beta_schedule(
|
98 |
+
self.device,
|
99 |
+
beta_schedule,
|
100 |
+
timesteps,
|
101 |
+
linear_start=linear_start,
|
102 |
+
linear_end=linear_end,
|
103 |
+
cosine_s=cosine_s,
|
104 |
+
)
|
105 |
+
alphas = 1.0 - betas
|
106 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
107 |
+
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
108 |
+
|
109 |
+
(timesteps,) = betas.shape
|
110 |
+
self.num_timesteps = int(timesteps)
|
111 |
+
self.linear_start = linear_start
|
112 |
+
self.linear_end = linear_end
|
113 |
+
assert (
|
114 |
+
alphas_cumprod.shape[0] == self.num_timesteps
|
115 |
+
), "alphas have to be defined for each timestep"
|
116 |
+
|
117 |
+
def to_torch(x): return torch.tensor(x, dtype=torch.float32).to(self.device)
|
118 |
+
|
119 |
+
self.register_buffer("betas", to_torch(betas))
|
120 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
121 |
+
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
122 |
+
|
123 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
124 |
+
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
125 |
+
self.register_buffer(
|
126 |
+
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
127 |
+
)
|
128 |
+
self.register_buffer(
|
129 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
130 |
+
)
|
131 |
+
self.register_buffer(
|
132 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
133 |
+
)
|
134 |
+
self.register_buffer(
|
135 |
+
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
136 |
+
)
|
137 |
+
|
138 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
139 |
+
posterior_variance = (1 - self.v_posterior) * betas * (
|
140 |
+
1.0 - alphas_cumprod_prev
|
141 |
+
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
142 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
143 |
+
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
144 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
145 |
+
self.register_buffer(
|
146 |
+
"posterior_log_variance_clipped",
|
147 |
+
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
148 |
+
)
|
149 |
+
self.register_buffer(
|
150 |
+
"posterior_mean_coef1",
|
151 |
+
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
152 |
+
)
|
153 |
+
self.register_buffer(
|
154 |
+
"posterior_mean_coef2",
|
155 |
+
to_torch(
|
156 |
+
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
157 |
+
),
|
158 |
+
)
|
159 |
+
|
160 |
+
if self.parameterization == "eps":
|
161 |
+
lvlb_weights = self.betas**2 / (
|
162 |
+
2
|
163 |
+
* self.posterior_variance
|
164 |
+
* to_torch(alphas)
|
165 |
+
* (1 - self.alphas_cumprod)
|
166 |
+
)
|
167 |
+
elif self.parameterization == "x0":
|
168 |
+
lvlb_weights = (
|
169 |
+
0.5
|
170 |
+
* np.sqrt(torch.Tensor(alphas_cumprod))
|
171 |
+
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
raise NotImplementedError("mu not supported")
|
175 |
+
# TODO how to choose this term
|
176 |
+
lvlb_weights[0] = lvlb_weights[1]
|
177 |
+
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
178 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
179 |
+
|
180 |
+
|
181 |
+
class LatentDiffusion(DDPM):
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
diffusion_model,
|
185 |
+
device,
|
186 |
+
cond_stage_key="image",
|
187 |
+
cond_stage_trainable=False,
|
188 |
+
concat_mode=True,
|
189 |
+
scale_factor=1.0,
|
190 |
+
scale_by_std=False,
|
191 |
+
*args,
|
192 |
+
**kwargs,
|
193 |
+
):
|
194 |
+
self.num_timesteps_cond = 1
|
195 |
+
self.scale_by_std = scale_by_std
|
196 |
+
super().__init__(device, *args, **kwargs)
|
197 |
+
self.diffusion_model = diffusion_model
|
198 |
+
self.concat_mode = concat_mode
|
199 |
+
self.cond_stage_trainable = cond_stage_trainable
|
200 |
+
self.cond_stage_key = cond_stage_key
|
201 |
+
self.num_downs = 2
|
202 |
+
self.scale_factor = scale_factor
|
203 |
+
|
204 |
+
def make_cond_schedule(
|
205 |
+
self,
|
206 |
+
):
|
207 |
+
self.cond_ids = torch.full(
|
208 |
+
size=(self.num_timesteps,),
|
209 |
+
fill_value=self.num_timesteps - 1,
|
210 |
+
dtype=torch.long,
|
211 |
+
)
|
212 |
+
ids = torch.round(
|
213 |
+
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
214 |
+
).long()
|
215 |
+
self.cond_ids[: self.num_timesteps_cond] = ids
|
216 |
+
|
217 |
+
def register_schedule(
|
218 |
+
self,
|
219 |
+
given_betas=None,
|
220 |
+
beta_schedule="linear",
|
221 |
+
timesteps=1000,
|
222 |
+
linear_start=1e-4,
|
223 |
+
linear_end=2e-2,
|
224 |
+
cosine_s=8e-3,
|
225 |
+
):
|
226 |
+
super().register_schedule(
|
227 |
+
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
228 |
+
)
|
229 |
+
|
230 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
231 |
+
if self.shorten_cond_schedule:
|
232 |
+
self.make_cond_schedule()
|
233 |
+
|
234 |
+
def apply_model(self, x_noisy, t, cond):
|
235 |
+
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
|
236 |
+
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
|
237 |
+
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
|
238 |
+
return x_recon
|
239 |
+
|
240 |
+
|
241 |
+
class LDM(InpaintModel):
|
242 |
+
name = "ldm"
|
243 |
+
pad_mod = 32
|
244 |
+
|
245 |
+
def __init__(self, device, fp16: bool = True, **kwargs):
|
246 |
+
self.fp16 = fp16
|
247 |
+
super().__init__(device)
|
248 |
+
self.device = device
|
249 |
+
|
250 |
+
def init_model(self, device, **kwargs):
|
251 |
+
self.diffusion_model = load_jit_model(
|
252 |
+
LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
|
253 |
+
)
|
254 |
+
self.cond_stage_model_decode = load_jit_model(
|
255 |
+
LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
|
256 |
+
)
|
257 |
+
self.cond_stage_model_encode = load_jit_model(
|
258 |
+
LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
|
259 |
+
)
|
260 |
+
if self.fp16 and "cuda" in str(device):
|
261 |
+
self.diffusion_model = self.diffusion_model.half()
|
262 |
+
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
263 |
+
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
264 |
+
|
265 |
+
self.model = LatentDiffusion(self.diffusion_model, device)
|
266 |
+
|
267 |
+
@staticmethod
|
268 |
+
def is_downloaded() -> bool:
|
269 |
+
model_paths = [
|
270 |
+
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
271 |
+
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
272 |
+
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
273 |
+
]
|
274 |
+
return all([os.path.exists(it) for it in model_paths])
|
275 |
+
|
276 |
+
@conditional_autocast
|
277 |
+
def forward(self, image, mask, config: Config):
|
278 |
+
"""
|
279 |
+
image: [H, W, C] RGB
|
280 |
+
mask: [H, W, 1]
|
281 |
+
return: BGR IMAGE
|
282 |
+
"""
|
283 |
+
# image [1,3,512,512] float32
|
284 |
+
# mask: [1,1,512,512] float32
|
285 |
+
# masked_image: [1,3,512,512] float32
|
286 |
+
if config.ldm_sampler == LDMSampler.ddim:
|
287 |
+
sampler = DDIMSampler(self.model)
|
288 |
+
elif config.ldm_sampler == LDMSampler.plms:
|
289 |
+
sampler = PLMSSampler(self.model)
|
290 |
+
else:
|
291 |
+
raise ValueError()
|
292 |
+
|
293 |
+
steps = config.ldm_steps
|
294 |
+
image = norm_img(image)
|
295 |
+
mask = norm_img(mask)
|
296 |
+
|
297 |
+
mask[mask < 0.5] = 0
|
298 |
+
mask[mask >= 0.5] = 1
|
299 |
+
|
300 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
301 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
302 |
+
masked_image = (1 - mask) * image
|
303 |
+
|
304 |
+
mask = self._norm(mask)
|
305 |
+
masked_image = self._norm(masked_image)
|
306 |
+
|
307 |
+
c = self.cond_stage_model_encode(masked_image)
|
308 |
+
torch.cuda.empty_cache()
|
309 |
+
|
310 |
+
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
311 |
+
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
312 |
+
|
313 |
+
shape = (c.shape[1] - 1,) + c.shape[2:]
|
314 |
+
samples_ddim = sampler.sample(
|
315 |
+
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
316 |
+
)
|
317 |
+
torch.cuda.empty_cache()
|
318 |
+
x_samples_ddim = self.cond_stage_model_decode(
|
319 |
+
samples_ddim
|
320 |
+
) # samples_ddim: 1, 3, 128, 128 float32
|
321 |
+
torch.cuda.empty_cache()
|
322 |
+
|
323 |
+
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
324 |
+
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
325 |
+
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
326 |
+
|
327 |
+
# inpainted = (1 - mask) * image + mask * predicted_image
|
328 |
+
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
329 |
+
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
330 |
+
return inpainted_image
|
331 |
+
|
332 |
+
def _norm(self, tensor):
|
333 |
+
return tensor * 2.0 - 1.0
|
lama_cleaner/model/manga.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import time
|
8 |
+
from loguru import logger
|
9 |
+
|
10 |
+
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
11 |
+
from lama_cleaner.model.base import InpaintModel
|
12 |
+
from lama_cleaner.schema import Config
|
13 |
+
|
14 |
+
|
15 |
+
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
16 |
+
"MANGA_INPAINTOR_MODEL_URL",
|
17 |
+
"https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
|
18 |
+
)
|
19 |
+
MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
|
20 |
+
"MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
|
21 |
+
)
|
22 |
+
|
23 |
+
MANGA_LINE_MODEL_URL = os.environ.get(
|
24 |
+
"MANGA_LINE_MODEL_URL",
|
25 |
+
"https://github.com/Sanster/models/releases/download/manga/erika.jit",
|
26 |
+
)
|
27 |
+
MANGA_LINE_MODEL_MD5 = os.environ.get(
|
28 |
+
"MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
class Manga(InpaintModel):
|
33 |
+
name = "manga"
|
34 |
+
pad_mod = 16
|
35 |
+
|
36 |
+
def init_model(self, device, **kwargs):
|
37 |
+
self.inpaintor_model = load_jit_model(
|
38 |
+
MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
|
39 |
+
)
|
40 |
+
self.line_model = load_jit_model(
|
41 |
+
MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
|
42 |
+
)
|
43 |
+
self.seed = 42
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def is_downloaded() -> bool:
|
47 |
+
model_paths = [
|
48 |
+
get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
|
49 |
+
get_cache_path_by_url(MANGA_LINE_MODEL_URL),
|
50 |
+
]
|
51 |
+
return all([os.path.exists(it) for it in model_paths])
|
52 |
+
|
53 |
+
def forward(self, image, mask, config: Config):
|
54 |
+
"""
|
55 |
+
image: [H, W, C] RGB
|
56 |
+
mask: [H, W, 1]
|
57 |
+
return: BGR IMAGE
|
58 |
+
"""
|
59 |
+
seed = self.seed
|
60 |
+
random.seed(seed)
|
61 |
+
np.random.seed(seed)
|
62 |
+
torch.manual_seed(seed)
|
63 |
+
torch.cuda.manual_seed_all(seed)
|
64 |
+
|
65 |
+
gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
66 |
+
gray_img = torch.from_numpy(
|
67 |
+
gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
|
68 |
+
).to(self.device)
|
69 |
+
start = time.time()
|
70 |
+
lines = self.line_model(gray_img)
|
71 |
+
torch.cuda.empty_cache()
|
72 |
+
lines = torch.clamp(lines, 0, 255)
|
73 |
+
logger.info(f"erika_model time: {time.time() - start}")
|
74 |
+
|
75 |
+
mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
|
76 |
+
mask = mask.permute(0, 3, 1, 2)
|
77 |
+
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
78 |
+
noise = torch.randn_like(mask)
|
79 |
+
ones = torch.ones_like(mask)
|
80 |
+
|
81 |
+
gray_img = gray_img / 255 * 2 - 1.0
|
82 |
+
lines = lines / 255 * 2 - 1.0
|
83 |
+
|
84 |
+
start = time.time()
|
85 |
+
inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
|
86 |
+
logger.info(f"image_inpaintor_model time: {time.time() - start}")
|
87 |
+
|
88 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
89 |
+
cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
|
90 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
|
91 |
+
return cur_res
|
lama_cleaner/model/mat.py
ADDED
@@ -0,0 +1,1935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.utils.checkpoint as checkpoint
|
10 |
+
|
11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
|
12 |
+
from lama_cleaner.model.base import InpaintModel
|
13 |
+
from lama_cleaner.model.utils import (
|
14 |
+
setup_filter,
|
15 |
+
Conv2dLayer,
|
16 |
+
FullyConnectedLayer,
|
17 |
+
conv2d_resample,
|
18 |
+
bias_act,
|
19 |
+
upsample2d,
|
20 |
+
activation_funcs,
|
21 |
+
MinibatchStdLayer,
|
22 |
+
to_2tuple,
|
23 |
+
normalize_2nd_moment,
|
24 |
+
set_seed,
|
25 |
+
)
|
26 |
+
from lama_cleaner.schema import Config
|
27 |
+
|
28 |
+
|
29 |
+
class ModulatedConv2d(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
in_channels, # Number of input channels.
|
33 |
+
out_channels, # Number of output channels.
|
34 |
+
kernel_size, # Width and height of the convolution kernel.
|
35 |
+
style_dim, # dimension of the style code
|
36 |
+
demodulate=True, # perfrom demodulation
|
37 |
+
up=1, # Integer upsampling factor.
|
38 |
+
down=1, # Integer downsampling factor.
|
39 |
+
resample_filter=[
|
40 |
+
1,
|
41 |
+
3,
|
42 |
+
3,
|
43 |
+
1,
|
44 |
+
], # Low-pass filter to apply when resampling activations.
|
45 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.demodulate = demodulate
|
49 |
+
|
50 |
+
self.weight = torch.nn.Parameter(
|
51 |
+
torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])
|
52 |
+
)
|
53 |
+
self.out_channels = out_channels
|
54 |
+
self.kernel_size = kernel_size
|
55 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
56 |
+
self.padding = self.kernel_size // 2
|
57 |
+
self.up = up
|
58 |
+
self.down = down
|
59 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
60 |
+
self.conv_clamp = conv_clamp
|
61 |
+
|
62 |
+
self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
|
63 |
+
|
64 |
+
def forward(self, x, style):
|
65 |
+
batch, in_channels, height, width = x.shape
|
66 |
+
style = self.affine(style).view(batch, 1, in_channels, 1, 1)
|
67 |
+
weight = self.weight * self.weight_gain * style
|
68 |
+
|
69 |
+
if self.demodulate:
|
70 |
+
decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
|
71 |
+
weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
|
72 |
+
|
73 |
+
weight = weight.view(
|
74 |
+
batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size
|
75 |
+
)
|
76 |
+
x = x.view(1, batch * in_channels, height, width)
|
77 |
+
x = conv2d_resample(
|
78 |
+
x=x,
|
79 |
+
w=weight,
|
80 |
+
f=self.resample_filter,
|
81 |
+
up=self.up,
|
82 |
+
down=self.down,
|
83 |
+
padding=self.padding,
|
84 |
+
groups=batch,
|
85 |
+
)
|
86 |
+
out = x.view(batch, self.out_channels, *x.shape[2:])
|
87 |
+
|
88 |
+
return out
|
89 |
+
|
90 |
+
|
91 |
+
class StyleConv(torch.nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
in_channels, # Number of input channels.
|
95 |
+
out_channels, # Number of output channels.
|
96 |
+
style_dim, # Intermediate latent (W) dimensionality.
|
97 |
+
resolution, # Resolution of this layer.
|
98 |
+
kernel_size=3, # Convolution kernel size.
|
99 |
+
up=1, # Integer upsampling factor.
|
100 |
+
use_noise=False, # Enable noise input?
|
101 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
102 |
+
resample_filter=[
|
103 |
+
1,
|
104 |
+
3,
|
105 |
+
3,
|
106 |
+
1,
|
107 |
+
], # Low-pass filter to apply when resampling activations.
|
108 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
109 |
+
demodulate=True, # perform demodulation
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
|
113 |
+
self.conv = ModulatedConv2d(
|
114 |
+
in_channels=in_channels,
|
115 |
+
out_channels=out_channels,
|
116 |
+
kernel_size=kernel_size,
|
117 |
+
style_dim=style_dim,
|
118 |
+
demodulate=demodulate,
|
119 |
+
up=up,
|
120 |
+
resample_filter=resample_filter,
|
121 |
+
conv_clamp=conv_clamp,
|
122 |
+
)
|
123 |
+
|
124 |
+
self.use_noise = use_noise
|
125 |
+
self.resolution = resolution
|
126 |
+
if use_noise:
|
127 |
+
self.register_buffer("noise_const", torch.randn([resolution, resolution]))
|
128 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
129 |
+
|
130 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
131 |
+
self.activation = activation
|
132 |
+
self.act_gain = activation_funcs[activation].def_gain
|
133 |
+
self.conv_clamp = conv_clamp
|
134 |
+
|
135 |
+
def forward(self, x, style, noise_mode="random", gain=1):
|
136 |
+
x = self.conv(x, style)
|
137 |
+
|
138 |
+
assert noise_mode in ["random", "const", "none"]
|
139 |
+
|
140 |
+
if self.use_noise:
|
141 |
+
if noise_mode == "random":
|
142 |
+
xh, xw = x.size()[-2:]
|
143 |
+
noise = (
|
144 |
+
torch.randn([x.shape[0], 1, xh, xw], device=x.device)
|
145 |
+
* self.noise_strength
|
146 |
+
)
|
147 |
+
if noise_mode == "const":
|
148 |
+
noise = self.noise_const * self.noise_strength
|
149 |
+
x = x + noise
|
150 |
+
|
151 |
+
act_gain = self.act_gain * gain
|
152 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
153 |
+
out = bias_act(
|
154 |
+
x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
|
155 |
+
)
|
156 |
+
|
157 |
+
return out
|
158 |
+
|
159 |
+
|
160 |
+
class ToRGB(torch.nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
in_channels,
|
164 |
+
out_channels,
|
165 |
+
style_dim,
|
166 |
+
kernel_size=1,
|
167 |
+
resample_filter=[1, 3, 3, 1],
|
168 |
+
conv_clamp=None,
|
169 |
+
demodulate=False,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
self.conv = ModulatedConv2d(
|
174 |
+
in_channels=in_channels,
|
175 |
+
out_channels=out_channels,
|
176 |
+
kernel_size=kernel_size,
|
177 |
+
style_dim=style_dim,
|
178 |
+
demodulate=demodulate,
|
179 |
+
resample_filter=resample_filter,
|
180 |
+
conv_clamp=conv_clamp,
|
181 |
+
)
|
182 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
183 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
184 |
+
self.conv_clamp = conv_clamp
|
185 |
+
|
186 |
+
def forward(self, x, style, skip=None):
|
187 |
+
x = self.conv(x, style)
|
188 |
+
out = bias_act(x, self.bias, clamp=self.conv_clamp)
|
189 |
+
|
190 |
+
if skip is not None:
|
191 |
+
if skip.shape != out.shape:
|
192 |
+
skip = upsample2d(skip, self.resample_filter)
|
193 |
+
out = out + skip
|
194 |
+
|
195 |
+
return out
|
196 |
+
|
197 |
+
|
198 |
+
def get_style_code(a, b):
|
199 |
+
return torch.cat([a, b], dim=1)
|
200 |
+
|
201 |
+
|
202 |
+
class DecBlockFirst(nn.Module):
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
in_channels,
|
206 |
+
out_channels,
|
207 |
+
activation,
|
208 |
+
style_dim,
|
209 |
+
use_noise,
|
210 |
+
demodulate,
|
211 |
+
img_channels,
|
212 |
+
):
|
213 |
+
super().__init__()
|
214 |
+
self.fc = FullyConnectedLayer(
|
215 |
+
in_features=in_channels * 2,
|
216 |
+
out_features=in_channels * 4 ** 2,
|
217 |
+
activation=activation,
|
218 |
+
)
|
219 |
+
self.conv = StyleConv(
|
220 |
+
in_channels=in_channels,
|
221 |
+
out_channels=out_channels,
|
222 |
+
style_dim=style_dim,
|
223 |
+
resolution=4,
|
224 |
+
kernel_size=3,
|
225 |
+
use_noise=use_noise,
|
226 |
+
activation=activation,
|
227 |
+
demodulate=demodulate,
|
228 |
+
)
|
229 |
+
self.toRGB = ToRGB(
|
230 |
+
in_channels=out_channels,
|
231 |
+
out_channels=img_channels,
|
232 |
+
style_dim=style_dim,
|
233 |
+
kernel_size=1,
|
234 |
+
demodulate=False,
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, x, ws, gs, E_features, noise_mode="random"):
|
238 |
+
x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
239 |
+
x = x + E_features[2]
|
240 |
+
style = get_style_code(ws[:, 0], gs)
|
241 |
+
x = self.conv(x, style, noise_mode=noise_mode)
|
242 |
+
style = get_style_code(ws[:, 1], gs)
|
243 |
+
img = self.toRGB(x, style, skip=None)
|
244 |
+
|
245 |
+
return x, img
|
246 |
+
|
247 |
+
|
248 |
+
class DecBlockFirstV2(nn.Module):
|
249 |
+
def __init__(
|
250 |
+
self,
|
251 |
+
in_channels,
|
252 |
+
out_channels,
|
253 |
+
activation,
|
254 |
+
style_dim,
|
255 |
+
use_noise,
|
256 |
+
demodulate,
|
257 |
+
img_channels,
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
self.conv0 = Conv2dLayer(
|
261 |
+
in_channels=in_channels,
|
262 |
+
out_channels=in_channels,
|
263 |
+
kernel_size=3,
|
264 |
+
activation=activation,
|
265 |
+
)
|
266 |
+
self.conv1 = StyleConv(
|
267 |
+
in_channels=in_channels,
|
268 |
+
out_channels=out_channels,
|
269 |
+
style_dim=style_dim,
|
270 |
+
resolution=4,
|
271 |
+
kernel_size=3,
|
272 |
+
use_noise=use_noise,
|
273 |
+
activation=activation,
|
274 |
+
demodulate=demodulate,
|
275 |
+
)
|
276 |
+
self.toRGB = ToRGB(
|
277 |
+
in_channels=out_channels,
|
278 |
+
out_channels=img_channels,
|
279 |
+
style_dim=style_dim,
|
280 |
+
kernel_size=1,
|
281 |
+
demodulate=False,
|
282 |
+
)
|
283 |
+
|
284 |
+
def forward(self, x, ws, gs, E_features, noise_mode="random"):
|
285 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
286 |
+
x = self.conv0(x)
|
287 |
+
x = x + E_features[2]
|
288 |
+
style = get_style_code(ws[:, 0], gs)
|
289 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
290 |
+
style = get_style_code(ws[:, 1], gs)
|
291 |
+
img = self.toRGB(x, style, skip=None)
|
292 |
+
|
293 |
+
return x, img
|
294 |
+
|
295 |
+
|
296 |
+
class DecBlock(nn.Module):
|
297 |
+
def __init__(
|
298 |
+
self,
|
299 |
+
res,
|
300 |
+
in_channels,
|
301 |
+
out_channels,
|
302 |
+
activation,
|
303 |
+
style_dim,
|
304 |
+
use_noise,
|
305 |
+
demodulate,
|
306 |
+
img_channels,
|
307 |
+
): # res = 2, ..., resolution_log2
|
308 |
+
super().__init__()
|
309 |
+
self.res = res
|
310 |
+
|
311 |
+
self.conv0 = StyleConv(
|
312 |
+
in_channels=in_channels,
|
313 |
+
out_channels=out_channels,
|
314 |
+
style_dim=style_dim,
|
315 |
+
resolution=2 ** res,
|
316 |
+
kernel_size=3,
|
317 |
+
up=2,
|
318 |
+
use_noise=use_noise,
|
319 |
+
activation=activation,
|
320 |
+
demodulate=demodulate,
|
321 |
+
)
|
322 |
+
self.conv1 = StyleConv(
|
323 |
+
in_channels=out_channels,
|
324 |
+
out_channels=out_channels,
|
325 |
+
style_dim=style_dim,
|
326 |
+
resolution=2 ** res,
|
327 |
+
kernel_size=3,
|
328 |
+
use_noise=use_noise,
|
329 |
+
activation=activation,
|
330 |
+
demodulate=demodulate,
|
331 |
+
)
|
332 |
+
self.toRGB = ToRGB(
|
333 |
+
in_channels=out_channels,
|
334 |
+
out_channels=img_channels,
|
335 |
+
style_dim=style_dim,
|
336 |
+
kernel_size=1,
|
337 |
+
demodulate=False,
|
338 |
+
)
|
339 |
+
|
340 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
|
341 |
+
style = get_style_code(ws[:, self.res * 2 - 5], gs)
|
342 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
343 |
+
x = x + E_features[self.res]
|
344 |
+
style = get_style_code(ws[:, self.res * 2 - 4], gs)
|
345 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
346 |
+
style = get_style_code(ws[:, self.res * 2 - 3], gs)
|
347 |
+
img = self.toRGB(x, style, skip=img)
|
348 |
+
|
349 |
+
return x, img
|
350 |
+
|
351 |
+
|
352 |
+
class MappingNet(torch.nn.Module):
|
353 |
+
def __init__(
|
354 |
+
self,
|
355 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
356 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
357 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
358 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
359 |
+
num_layers=8, # Number of mapping layers.
|
360 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
361 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
362 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
363 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
364 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
365 |
+
torch_dtype=torch.float32,
|
366 |
+
):
|
367 |
+
super().__init__()
|
368 |
+
self.z_dim = z_dim
|
369 |
+
self.c_dim = c_dim
|
370 |
+
self.w_dim = w_dim
|
371 |
+
self.num_ws = num_ws
|
372 |
+
self.num_layers = num_layers
|
373 |
+
self.w_avg_beta = w_avg_beta
|
374 |
+
self.torch_dtype = torch_dtype
|
375 |
+
|
376 |
+
if embed_features is None:
|
377 |
+
embed_features = w_dim
|
378 |
+
if c_dim == 0:
|
379 |
+
embed_features = 0
|
380 |
+
if layer_features is None:
|
381 |
+
layer_features = w_dim
|
382 |
+
features_list = (
|
383 |
+
[z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
384 |
+
)
|
385 |
+
|
386 |
+
if c_dim > 0:
|
387 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
388 |
+
for idx in range(num_layers):
|
389 |
+
in_features = features_list[idx]
|
390 |
+
out_features = features_list[idx + 1]
|
391 |
+
layer = FullyConnectedLayer(
|
392 |
+
in_features,
|
393 |
+
out_features,
|
394 |
+
activation=activation,
|
395 |
+
lr_multiplier=lr_multiplier,
|
396 |
+
)
|
397 |
+
setattr(self, f"fc{idx}", layer)
|
398 |
+
|
399 |
+
if num_ws is not None and w_avg_beta is not None:
|
400 |
+
self.register_buffer("w_avg", torch.zeros([w_dim]))
|
401 |
+
|
402 |
+
def forward(
|
403 |
+
self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
|
404 |
+
):
|
405 |
+
# Embed, normalize, and concat inputs.
|
406 |
+
x = None
|
407 |
+
if self.z_dim > 0:
|
408 |
+
x = normalize_2nd_moment(z)
|
409 |
+
if self.c_dim > 0:
|
410 |
+
y = normalize_2nd_moment(self.embed(c))
|
411 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
412 |
+
|
413 |
+
# Main layers.
|
414 |
+
for idx in range(self.num_layers):
|
415 |
+
layer = getattr(self, f"fc{idx}")
|
416 |
+
x = layer(x)
|
417 |
+
|
418 |
+
# Update moving average of W.
|
419 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
420 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
421 |
+
|
422 |
+
# Broadcast.
|
423 |
+
if self.num_ws is not None:
|
424 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
425 |
+
|
426 |
+
# Apply truncation.
|
427 |
+
if truncation_psi != 1:
|
428 |
+
assert self.w_avg_beta is not None
|
429 |
+
if self.num_ws is None or truncation_cutoff is None:
|
430 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
431 |
+
else:
|
432 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(
|
433 |
+
x[:, :truncation_cutoff], truncation_psi
|
434 |
+
)
|
435 |
+
|
436 |
+
return x
|
437 |
+
|
438 |
+
|
439 |
+
class DisFromRGB(nn.Module):
|
440 |
+
def __init__(
|
441 |
+
self, in_channels, out_channels, activation
|
442 |
+
): # res = 2, ..., resolution_log2
|
443 |
+
super().__init__()
|
444 |
+
self.conv = Conv2dLayer(
|
445 |
+
in_channels=in_channels,
|
446 |
+
out_channels=out_channels,
|
447 |
+
kernel_size=1,
|
448 |
+
activation=activation,
|
449 |
+
)
|
450 |
+
|
451 |
+
def forward(self, x):
|
452 |
+
return self.conv(x)
|
453 |
+
|
454 |
+
|
455 |
+
class DisBlock(nn.Module):
|
456 |
+
def __init__(
|
457 |
+
self, in_channels, out_channels, activation
|
458 |
+
): # res = 2, ..., resolution_log2
|
459 |
+
super().__init__()
|
460 |
+
self.conv0 = Conv2dLayer(
|
461 |
+
in_channels=in_channels,
|
462 |
+
out_channels=in_channels,
|
463 |
+
kernel_size=3,
|
464 |
+
activation=activation,
|
465 |
+
)
|
466 |
+
self.conv1 = Conv2dLayer(
|
467 |
+
in_channels=in_channels,
|
468 |
+
out_channels=out_channels,
|
469 |
+
kernel_size=3,
|
470 |
+
down=2,
|
471 |
+
activation=activation,
|
472 |
+
)
|
473 |
+
self.skip = Conv2dLayer(
|
474 |
+
in_channels=in_channels,
|
475 |
+
out_channels=out_channels,
|
476 |
+
kernel_size=1,
|
477 |
+
down=2,
|
478 |
+
bias=False,
|
479 |
+
)
|
480 |
+
|
481 |
+
def forward(self, x):
|
482 |
+
skip = self.skip(x, gain=np.sqrt(0.5))
|
483 |
+
x = self.conv0(x)
|
484 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
485 |
+
out = skip + x
|
486 |
+
|
487 |
+
return out
|
488 |
+
|
489 |
+
|
490 |
+
class Discriminator(torch.nn.Module):
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
c_dim, # Conditioning label (C) dimensionality.
|
494 |
+
img_resolution, # Input resolution.
|
495 |
+
img_channels, # Number of input color channels.
|
496 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
497 |
+
channel_max=512, # Maximum number of channels in any layer.
|
498 |
+
channel_decay=1,
|
499 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
500 |
+
activation="lrelu",
|
501 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
502 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
503 |
+
):
|
504 |
+
super().__init__()
|
505 |
+
self.c_dim = c_dim
|
506 |
+
self.img_resolution = img_resolution
|
507 |
+
self.img_channels = img_channels
|
508 |
+
|
509 |
+
resolution_log2 = int(np.log2(img_resolution))
|
510 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
511 |
+
self.resolution_log2 = resolution_log2
|
512 |
+
|
513 |
+
def nf(stage):
|
514 |
+
return np.clip(
|
515 |
+
int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max
|
516 |
+
)
|
517 |
+
|
518 |
+
if cmap_dim == None:
|
519 |
+
cmap_dim = nf(2)
|
520 |
+
if c_dim == 0:
|
521 |
+
cmap_dim = 0
|
522 |
+
self.cmap_dim = cmap_dim
|
523 |
+
|
524 |
+
if c_dim > 0:
|
525 |
+
self.mapping = MappingNet(
|
526 |
+
z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
|
527 |
+
)
|
528 |
+
|
529 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
530 |
+
for res in range(resolution_log2, 2, -1):
|
531 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
532 |
+
|
533 |
+
if mbstd_num_channels > 0:
|
534 |
+
Dis.append(
|
535 |
+
MinibatchStdLayer(
|
536 |
+
group_size=mbstd_group_size, num_channels=mbstd_num_channels
|
537 |
+
)
|
538 |
+
)
|
539 |
+
Dis.append(
|
540 |
+
Conv2dLayer(
|
541 |
+
nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
|
542 |
+
)
|
543 |
+
)
|
544 |
+
self.Dis = nn.Sequential(*Dis)
|
545 |
+
|
546 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
547 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
548 |
+
|
549 |
+
def forward(self, images_in, masks_in, c):
|
550 |
+
x = torch.cat([masks_in - 0.5, images_in], dim=1)
|
551 |
+
x = self.Dis(x)
|
552 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
553 |
+
|
554 |
+
if self.c_dim > 0:
|
555 |
+
cmap = self.mapping(None, c)
|
556 |
+
|
557 |
+
if self.cmap_dim > 0:
|
558 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
559 |
+
|
560 |
+
return x
|
561 |
+
|
562 |
+
|
563 |
+
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
|
564 |
+
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
|
565 |
+
return NF[2 ** stage]
|
566 |
+
|
567 |
+
|
568 |
+
class Mlp(nn.Module):
|
569 |
+
def __init__(
|
570 |
+
self,
|
571 |
+
in_features,
|
572 |
+
hidden_features=None,
|
573 |
+
out_features=None,
|
574 |
+
act_layer=nn.GELU,
|
575 |
+
drop=0.0,
|
576 |
+
):
|
577 |
+
super().__init__()
|
578 |
+
out_features = out_features or in_features
|
579 |
+
hidden_features = hidden_features or in_features
|
580 |
+
self.fc1 = FullyConnectedLayer(
|
581 |
+
in_features=in_features, out_features=hidden_features, activation="lrelu"
|
582 |
+
)
|
583 |
+
self.fc2 = FullyConnectedLayer(
|
584 |
+
in_features=hidden_features, out_features=out_features
|
585 |
+
)
|
586 |
+
|
587 |
+
def forward(self, x):
|
588 |
+
x = self.fc1(x)
|
589 |
+
x = self.fc2(x)
|
590 |
+
return x
|
591 |
+
|
592 |
+
|
593 |
+
def window_partition(x, window_size):
|
594 |
+
"""
|
595 |
+
Args:
|
596 |
+
x: (B, H, W, C)
|
597 |
+
window_size (int): window size
|
598 |
+
Returns:
|
599 |
+
windows: (num_windows*B, window_size, window_size, C)
|
600 |
+
"""
|
601 |
+
B, H, W, C = x.shape
|
602 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
603 |
+
windows = (
|
604 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
605 |
+
)
|
606 |
+
return windows
|
607 |
+
|
608 |
+
|
609 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
610 |
+
"""
|
611 |
+
Args:
|
612 |
+
windows: (num_windows*B, window_size, window_size, C)
|
613 |
+
window_size (int): Window size
|
614 |
+
H (int): Height of image
|
615 |
+
W (int): Width of image
|
616 |
+
Returns:
|
617 |
+
x: (B, H, W, C)
|
618 |
+
"""
|
619 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
620 |
+
# B = windows.shape[0] / (H * W / window_size / window_size)
|
621 |
+
x = windows.view(
|
622 |
+
B, H // window_size, W // window_size, window_size, window_size, -1
|
623 |
+
)
|
624 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
625 |
+
return x
|
626 |
+
|
627 |
+
|
628 |
+
class Conv2dLayerPartial(nn.Module):
|
629 |
+
def __init__(
|
630 |
+
self,
|
631 |
+
in_channels, # Number of input channels.
|
632 |
+
out_channels, # Number of output channels.
|
633 |
+
kernel_size, # Width and height of the convolution kernel.
|
634 |
+
bias=True, # Apply additive bias before the activation function?
|
635 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
636 |
+
up=1, # Integer upsampling factor.
|
637 |
+
down=1, # Integer downsampling factor.
|
638 |
+
resample_filter=[
|
639 |
+
1,
|
640 |
+
3,
|
641 |
+
3,
|
642 |
+
1,
|
643 |
+
], # Low-pass filter to apply when resampling activations.
|
644 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
645 |
+
trainable=True, # Update the weights of this layer during training?
|
646 |
+
):
|
647 |
+
super().__init__()
|
648 |
+
self.conv = Conv2dLayer(
|
649 |
+
in_channels,
|
650 |
+
out_channels,
|
651 |
+
kernel_size,
|
652 |
+
bias,
|
653 |
+
activation,
|
654 |
+
up,
|
655 |
+
down,
|
656 |
+
resample_filter,
|
657 |
+
conv_clamp,
|
658 |
+
trainable,
|
659 |
+
)
|
660 |
+
|
661 |
+
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
|
662 |
+
self.slide_winsize = kernel_size ** 2
|
663 |
+
self.stride = down
|
664 |
+
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
|
665 |
+
|
666 |
+
def forward(self, x, mask=None):
|
667 |
+
if mask is not None:
|
668 |
+
with torch.no_grad():
|
669 |
+
if self.weight_maskUpdater.type() != x.type():
|
670 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to(x)
|
671 |
+
update_mask = F.conv2d(
|
672 |
+
mask,
|
673 |
+
self.weight_maskUpdater,
|
674 |
+
bias=None,
|
675 |
+
stride=self.stride,
|
676 |
+
padding=self.padding,
|
677 |
+
)
|
678 |
+
mask_ratio = self.slide_winsize / (update_mask.to(torch.float32) + 1e-8)
|
679 |
+
update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
|
680 |
+
mask_ratio = torch.mul(mask_ratio, update_mask).to(x.dtype)
|
681 |
+
x = self.conv(x)
|
682 |
+
x = torch.mul(x, mask_ratio)
|
683 |
+
return x, update_mask
|
684 |
+
else:
|
685 |
+
x = self.conv(x)
|
686 |
+
return x, None
|
687 |
+
|
688 |
+
|
689 |
+
class WindowAttention(nn.Module):
|
690 |
+
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
691 |
+
It supports both of shifted and non-shifted window.
|
692 |
+
Args:
|
693 |
+
dim (int): Number of input channels.
|
694 |
+
window_size (tuple[int]): The height and width of the window.
|
695 |
+
num_heads (int): Number of attention heads.
|
696 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
697 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
698 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
699 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
700 |
+
"""
|
701 |
+
|
702 |
+
def __init__(
|
703 |
+
self,
|
704 |
+
dim,
|
705 |
+
window_size,
|
706 |
+
num_heads,
|
707 |
+
down_ratio=1,
|
708 |
+
qkv_bias=True,
|
709 |
+
qk_scale=None,
|
710 |
+
attn_drop=0.0,
|
711 |
+
proj_drop=0.0,
|
712 |
+
):
|
713 |
+
super().__init__()
|
714 |
+
self.dim = dim
|
715 |
+
self.window_size = window_size # Wh, Ww
|
716 |
+
self.num_heads = num_heads
|
717 |
+
head_dim = dim // num_heads
|
718 |
+
self.scale = qk_scale or head_dim ** -0.5
|
719 |
+
|
720 |
+
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
|
721 |
+
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
|
722 |
+
self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
|
723 |
+
self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
|
724 |
+
|
725 |
+
self.softmax = nn.Softmax(dim=-1)
|
726 |
+
|
727 |
+
def forward(self, x, mask_windows=None, mask=None):
|
728 |
+
"""
|
729 |
+
Args:
|
730 |
+
x: input features with shape of (num_windows*B, N, C)
|
731 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
732 |
+
"""
|
733 |
+
B_, N, C = x.shape
|
734 |
+
norm_x = F.normalize(x, p=2.0, dim=-1, eps=torch.finfo(x.dtype).eps)
|
735 |
+
q = (
|
736 |
+
self.q(norm_x)
|
737 |
+
.reshape(B_, N, self.num_heads, C // self.num_heads)
|
738 |
+
.permute(0, 2, 1, 3)
|
739 |
+
)
|
740 |
+
k = (
|
741 |
+
self.k(norm_x)
|
742 |
+
.view(B_, -1, self.num_heads, C // self.num_heads)
|
743 |
+
.permute(0, 2, 3, 1)
|
744 |
+
)
|
745 |
+
v = (
|
746 |
+
self.v(x)
|
747 |
+
.view(B_, -1, self.num_heads, C // self.num_heads)
|
748 |
+
.permute(0, 2, 1, 3)
|
749 |
+
)
|
750 |
+
|
751 |
+
attn = (q @ k) * self.scale
|
752 |
+
|
753 |
+
if mask is not None:
|
754 |
+
nW = mask.shape[0]
|
755 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
|
756 |
+
1
|
757 |
+
).unsqueeze(0)
|
758 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
759 |
+
|
760 |
+
if mask_windows is not None:
|
761 |
+
attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
|
762 |
+
attn = attn + attn_mask_windows.masked_fill(
|
763 |
+
attn_mask_windows == 0, float(-100.0)
|
764 |
+
).masked_fill(attn_mask_windows == 1, float(0.0))
|
765 |
+
with torch.no_grad():
|
766 |
+
mask_windows = torch.clamp(
|
767 |
+
torch.sum(mask_windows, dim=1, keepdim=True), 0, 1
|
768 |
+
).repeat(1, N, 1)
|
769 |
+
|
770 |
+
attn = self.softmax(attn)
|
771 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
772 |
+
x = self.proj(x)
|
773 |
+
return x, mask_windows
|
774 |
+
|
775 |
+
|
776 |
+
class SwinTransformerBlock(nn.Module):
|
777 |
+
r"""Swin Transformer Block.
|
778 |
+
Args:
|
779 |
+
dim (int): Number of input channels.
|
780 |
+
input_resolution (tuple[int]): Input resulotion.
|
781 |
+
num_heads (int): Number of attention heads.
|
782 |
+
window_size (int): Window size.
|
783 |
+
shift_size (int): Shift size for SW-MSA.
|
784 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
785 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
786 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
787 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
788 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
789 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
790 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
791 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
792 |
+
"""
|
793 |
+
|
794 |
+
def __init__(
|
795 |
+
self,
|
796 |
+
dim,
|
797 |
+
input_resolution,
|
798 |
+
num_heads,
|
799 |
+
down_ratio=1,
|
800 |
+
window_size=7,
|
801 |
+
shift_size=0,
|
802 |
+
mlp_ratio=4.0,
|
803 |
+
qkv_bias=True,
|
804 |
+
qk_scale=None,
|
805 |
+
drop=0.0,
|
806 |
+
attn_drop=0.0,
|
807 |
+
drop_path=0.0,
|
808 |
+
act_layer=nn.GELU,
|
809 |
+
norm_layer=nn.LayerNorm,
|
810 |
+
):
|
811 |
+
super().__init__()
|
812 |
+
self.dim = dim
|
813 |
+
self.input_resolution = input_resolution
|
814 |
+
self.num_heads = num_heads
|
815 |
+
self.window_size = window_size
|
816 |
+
self.shift_size = shift_size
|
817 |
+
self.mlp_ratio = mlp_ratio
|
818 |
+
if min(self.input_resolution) <= self.window_size:
|
819 |
+
# if window size is larger than input resolution, we don't partition windows
|
820 |
+
self.shift_size = 0
|
821 |
+
self.window_size = min(self.input_resolution)
|
822 |
+
assert (
|
823 |
+
0 <= self.shift_size < self.window_size
|
824 |
+
), "shift_size must in 0-window_size"
|
825 |
+
|
826 |
+
if self.shift_size > 0:
|
827 |
+
down_ratio = 1
|
828 |
+
self.attn = WindowAttention(
|
829 |
+
dim,
|
830 |
+
window_size=to_2tuple(self.window_size),
|
831 |
+
num_heads=num_heads,
|
832 |
+
down_ratio=down_ratio,
|
833 |
+
qkv_bias=qkv_bias,
|
834 |
+
qk_scale=qk_scale,
|
835 |
+
attn_drop=attn_drop,
|
836 |
+
proj_drop=drop,
|
837 |
+
)
|
838 |
+
|
839 |
+
self.fuse = FullyConnectedLayer(
|
840 |
+
in_features=dim * 2, out_features=dim, activation="lrelu"
|
841 |
+
)
|
842 |
+
|
843 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
844 |
+
self.mlp = Mlp(
|
845 |
+
in_features=dim,
|
846 |
+
hidden_features=mlp_hidden_dim,
|
847 |
+
act_layer=act_layer,
|
848 |
+
drop=drop,
|
849 |
+
)
|
850 |
+
|
851 |
+
if self.shift_size > 0:
|
852 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
853 |
+
else:
|
854 |
+
attn_mask = None
|
855 |
+
|
856 |
+
self.register_buffer("attn_mask", attn_mask)
|
857 |
+
|
858 |
+
def calculate_mask(self, x_size):
|
859 |
+
# calculate attention mask for SW-MSA
|
860 |
+
H, W = x_size
|
861 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
862 |
+
h_slices = (
|
863 |
+
slice(0, -self.window_size),
|
864 |
+
slice(-self.window_size, -self.shift_size),
|
865 |
+
slice(-self.shift_size, None),
|
866 |
+
)
|
867 |
+
w_slices = (
|
868 |
+
slice(0, -self.window_size),
|
869 |
+
slice(-self.window_size, -self.shift_size),
|
870 |
+
slice(-self.shift_size, None),
|
871 |
+
)
|
872 |
+
cnt = 0
|
873 |
+
for h in h_slices:
|
874 |
+
for w in w_slices:
|
875 |
+
img_mask[:, h, w, :] = cnt
|
876 |
+
cnt += 1
|
877 |
+
|
878 |
+
mask_windows = window_partition(
|
879 |
+
img_mask, self.window_size
|
880 |
+
) # nW, window_size, window_size, 1
|
881 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
882 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
883 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
884 |
+
attn_mask == 0, float(0.0)
|
885 |
+
)
|
886 |
+
|
887 |
+
return attn_mask
|
888 |
+
|
889 |
+
def forward(self, x, x_size, mask=None):
|
890 |
+
# H, W = self.input_resolution
|
891 |
+
H, W = x_size
|
892 |
+
B, L, C = x.shape
|
893 |
+
# assert L == H * W, "input feature has wrong size"
|
894 |
+
|
895 |
+
shortcut = x
|
896 |
+
x = x.view(B, H, W, C)
|
897 |
+
if mask is not None:
|
898 |
+
mask = mask.view(B, H, W, 1)
|
899 |
+
|
900 |
+
# cyclic shift
|
901 |
+
if self.shift_size > 0:
|
902 |
+
shifted_x = torch.roll(
|
903 |
+
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
904 |
+
)
|
905 |
+
if mask is not None:
|
906 |
+
shifted_mask = torch.roll(
|
907 |
+
mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
908 |
+
)
|
909 |
+
else:
|
910 |
+
shifted_x = x
|
911 |
+
if mask is not None:
|
912 |
+
shifted_mask = mask
|
913 |
+
|
914 |
+
# partition windows
|
915 |
+
x_windows = window_partition(
|
916 |
+
shifted_x, self.window_size
|
917 |
+
) # nW*B, window_size, window_size, C
|
918 |
+
x_windows = x_windows.view(
|
919 |
+
-1, self.window_size * self.window_size, C
|
920 |
+
) # nW*B, window_size*window_size, C
|
921 |
+
if mask is not None:
|
922 |
+
mask_windows = window_partition(shifted_mask, self.window_size)
|
923 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
|
924 |
+
else:
|
925 |
+
mask_windows = None
|
926 |
+
|
927 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
928 |
+
if self.input_resolution == x_size:
|
929 |
+
attn_windows, mask_windows = self.attn(
|
930 |
+
x_windows, mask_windows, mask=self.attn_mask
|
931 |
+
) # nW*B, window_size*window_size, C
|
932 |
+
else:
|
933 |
+
attn_windows, mask_windows = self.attn(
|
934 |
+
x_windows,
|
935 |
+
mask_windows,
|
936 |
+
mask=self.calculate_mask(x_size).to(x.dtype).to(x.device),
|
937 |
+
) # nW*B, window_size*window_size, C
|
938 |
+
|
939 |
+
# merge windows
|
940 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
941 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
942 |
+
if mask is not None:
|
943 |
+
mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
|
944 |
+
shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
|
945 |
+
|
946 |
+
# reverse cyclic shift
|
947 |
+
if self.shift_size > 0:
|
948 |
+
x = torch.roll(
|
949 |
+
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
|
950 |
+
)
|
951 |
+
if mask is not None:
|
952 |
+
mask = torch.roll(
|
953 |
+
shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
|
954 |
+
)
|
955 |
+
else:
|
956 |
+
x = shifted_x
|
957 |
+
if mask is not None:
|
958 |
+
mask = shifted_mask
|
959 |
+
x = x.view(B, H * W, C)
|
960 |
+
if mask is not None:
|
961 |
+
mask = mask.view(B, H * W, 1)
|
962 |
+
|
963 |
+
# FFN
|
964 |
+
x = self.fuse(torch.cat([shortcut, x], dim=-1))
|
965 |
+
x = self.mlp(x)
|
966 |
+
|
967 |
+
return x, mask
|
968 |
+
|
969 |
+
|
970 |
+
class PatchMerging(nn.Module):
|
971 |
+
def __init__(self, in_channels, out_channels, down=2):
|
972 |
+
super().__init__()
|
973 |
+
self.conv = Conv2dLayerPartial(
|
974 |
+
in_channels=in_channels,
|
975 |
+
out_channels=out_channels,
|
976 |
+
kernel_size=3,
|
977 |
+
activation="lrelu",
|
978 |
+
down=down,
|
979 |
+
)
|
980 |
+
self.down = down
|
981 |
+
|
982 |
+
def forward(self, x, x_size, mask=None):
|
983 |
+
x = token2feature(x, x_size)
|
984 |
+
if mask is not None:
|
985 |
+
mask = token2feature(mask, x_size)
|
986 |
+
x, mask = self.conv(x, mask)
|
987 |
+
if self.down != 1:
|
988 |
+
ratio = 1 / self.down
|
989 |
+
x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
|
990 |
+
x = feature2token(x)
|
991 |
+
if mask is not None:
|
992 |
+
mask = feature2token(mask)
|
993 |
+
return x, x_size, mask
|
994 |
+
|
995 |
+
|
996 |
+
class PatchUpsampling(nn.Module):
|
997 |
+
def __init__(self, in_channels, out_channels, up=2):
|
998 |
+
super().__init__()
|
999 |
+
self.conv = Conv2dLayerPartial(
|
1000 |
+
in_channels=in_channels,
|
1001 |
+
out_channels=out_channels,
|
1002 |
+
kernel_size=3,
|
1003 |
+
activation="lrelu",
|
1004 |
+
up=up,
|
1005 |
+
)
|
1006 |
+
self.up = up
|
1007 |
+
|
1008 |
+
def forward(self, x, x_size, mask=None):
|
1009 |
+
x = token2feature(x, x_size)
|
1010 |
+
if mask is not None:
|
1011 |
+
mask = token2feature(mask, x_size)
|
1012 |
+
x, mask = self.conv(x, mask)
|
1013 |
+
if self.up != 1:
|
1014 |
+
x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
|
1015 |
+
x = feature2token(x)
|
1016 |
+
if mask is not None:
|
1017 |
+
mask = feature2token(mask)
|
1018 |
+
return x, x_size, mask
|
1019 |
+
|
1020 |
+
|
1021 |
+
class BasicLayer(nn.Module):
|
1022 |
+
"""A basic Swin Transformer layer for one stage.
|
1023 |
+
Args:
|
1024 |
+
dim (int): Number of input channels.
|
1025 |
+
input_resolution (tuple[int]): Input resolution.
|
1026 |
+
depth (int): Number of blocks.
|
1027 |
+
num_heads (int): Number of attention heads.
|
1028 |
+
window_size (int): Local window size.
|
1029 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
1030 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
1031 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
1032 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
1033 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
1034 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
1035 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
1036 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
1037 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
1038 |
+
"""
|
1039 |
+
|
1040 |
+
def __init__(
|
1041 |
+
self,
|
1042 |
+
dim,
|
1043 |
+
input_resolution,
|
1044 |
+
depth,
|
1045 |
+
num_heads,
|
1046 |
+
window_size,
|
1047 |
+
down_ratio=1,
|
1048 |
+
mlp_ratio=2.0,
|
1049 |
+
qkv_bias=True,
|
1050 |
+
qk_scale=None,
|
1051 |
+
drop=0.0,
|
1052 |
+
attn_drop=0.0,
|
1053 |
+
drop_path=0.0,
|
1054 |
+
norm_layer=nn.LayerNorm,
|
1055 |
+
downsample=None,
|
1056 |
+
use_checkpoint=False,
|
1057 |
+
):
|
1058 |
+
super().__init__()
|
1059 |
+
self.dim = dim
|
1060 |
+
self.input_resolution = input_resolution
|
1061 |
+
self.depth = depth
|
1062 |
+
self.use_checkpoint = use_checkpoint
|
1063 |
+
|
1064 |
+
# patch merging layer
|
1065 |
+
if downsample is not None:
|
1066 |
+
# self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
1067 |
+
self.downsample = downsample
|
1068 |
+
else:
|
1069 |
+
self.downsample = None
|
1070 |
+
|
1071 |
+
# build blocks
|
1072 |
+
self.blocks = nn.ModuleList(
|
1073 |
+
[
|
1074 |
+
SwinTransformerBlock(
|
1075 |
+
dim=dim,
|
1076 |
+
input_resolution=input_resolution,
|
1077 |
+
num_heads=num_heads,
|
1078 |
+
down_ratio=down_ratio,
|
1079 |
+
window_size=window_size,
|
1080 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
1081 |
+
mlp_ratio=mlp_ratio,
|
1082 |
+
qkv_bias=qkv_bias,
|
1083 |
+
qk_scale=qk_scale,
|
1084 |
+
drop=drop,
|
1085 |
+
attn_drop=attn_drop,
|
1086 |
+
drop_path=drop_path[i]
|
1087 |
+
if isinstance(drop_path, list)
|
1088 |
+
else drop_path,
|
1089 |
+
norm_layer=norm_layer,
|
1090 |
+
)
|
1091 |
+
for i in range(depth)
|
1092 |
+
]
|
1093 |
+
)
|
1094 |
+
|
1095 |
+
self.conv = Conv2dLayerPartial(
|
1096 |
+
in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu"
|
1097 |
+
)
|
1098 |
+
|
1099 |
+
def forward(self, x, x_size, mask=None):
|
1100 |
+
if self.downsample is not None:
|
1101 |
+
x, x_size, mask = self.downsample(x, x_size, mask)
|
1102 |
+
identity = x
|
1103 |
+
for blk in self.blocks:
|
1104 |
+
if self.use_checkpoint:
|
1105 |
+
x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
|
1106 |
+
else:
|
1107 |
+
x, mask = blk(x, x_size, mask)
|
1108 |
+
if mask is not None:
|
1109 |
+
mask = token2feature(mask, x_size)
|
1110 |
+
x, mask = self.conv(token2feature(x, x_size), mask)
|
1111 |
+
x = feature2token(x) + identity
|
1112 |
+
if mask is not None:
|
1113 |
+
mask = feature2token(mask)
|
1114 |
+
return x, x_size, mask
|
1115 |
+
|
1116 |
+
|
1117 |
+
class ToToken(nn.Module):
|
1118 |
+
def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
|
1119 |
+
super().__init__()
|
1120 |
+
|
1121 |
+
self.proj = Conv2dLayerPartial(
|
1122 |
+
in_channels=in_channels,
|
1123 |
+
out_channels=dim,
|
1124 |
+
kernel_size=kernel_size,
|
1125 |
+
activation="lrelu",
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
def forward(self, x, mask):
|
1129 |
+
x, mask = self.proj(x, mask)
|
1130 |
+
|
1131 |
+
return x, mask
|
1132 |
+
|
1133 |
+
|
1134 |
+
class EncFromRGB(nn.Module):
|
1135 |
+
def __init__(
|
1136 |
+
self, in_channels, out_channels, activation
|
1137 |
+
): # res = 2, ..., resolution_log2
|
1138 |
+
super().__init__()
|
1139 |
+
self.conv0 = Conv2dLayer(
|
1140 |
+
in_channels=in_channels,
|
1141 |
+
out_channels=out_channels,
|
1142 |
+
kernel_size=1,
|
1143 |
+
activation=activation,
|
1144 |
+
)
|
1145 |
+
self.conv1 = Conv2dLayer(
|
1146 |
+
in_channels=out_channels,
|
1147 |
+
out_channels=out_channels,
|
1148 |
+
kernel_size=3,
|
1149 |
+
activation=activation,
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
def forward(self, x):
|
1153 |
+
x = self.conv0(x)
|
1154 |
+
x = self.conv1(x)
|
1155 |
+
|
1156 |
+
return x
|
1157 |
+
|
1158 |
+
|
1159 |
+
class ConvBlockDown(nn.Module):
|
1160 |
+
def __init__(
|
1161 |
+
self, in_channels, out_channels, activation
|
1162 |
+
): # res = 2, ..., resolution_log
|
1163 |
+
super().__init__()
|
1164 |
+
|
1165 |
+
self.conv0 = Conv2dLayer(
|
1166 |
+
in_channels=in_channels,
|
1167 |
+
out_channels=out_channels,
|
1168 |
+
kernel_size=3,
|
1169 |
+
activation=activation,
|
1170 |
+
down=2,
|
1171 |
+
)
|
1172 |
+
self.conv1 = Conv2dLayer(
|
1173 |
+
in_channels=out_channels,
|
1174 |
+
out_channels=out_channels,
|
1175 |
+
kernel_size=3,
|
1176 |
+
activation=activation,
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
def forward(self, x):
|
1180 |
+
x = self.conv0(x)
|
1181 |
+
x = self.conv1(x)
|
1182 |
+
|
1183 |
+
return x
|
1184 |
+
|
1185 |
+
|
1186 |
+
def token2feature(x, x_size):
|
1187 |
+
B, N, C = x.shape
|
1188 |
+
h, w = x_size
|
1189 |
+
x = x.permute(0, 2, 1).reshape(B, C, h, w)
|
1190 |
+
return x
|
1191 |
+
|
1192 |
+
|
1193 |
+
def feature2token(x):
|
1194 |
+
B, C, H, W = x.shape
|
1195 |
+
x = x.view(B, C, -1).transpose(1, 2)
|
1196 |
+
return x
|
1197 |
+
|
1198 |
+
|
1199 |
+
class Encoder(nn.Module):
|
1200 |
+
def __init__(
|
1201 |
+
self,
|
1202 |
+
res_log2,
|
1203 |
+
img_channels,
|
1204 |
+
activation,
|
1205 |
+
patch_size=5,
|
1206 |
+
channels=16,
|
1207 |
+
drop_path_rate=0.1,
|
1208 |
+
):
|
1209 |
+
super().__init__()
|
1210 |
+
|
1211 |
+
self.resolution = []
|
1212 |
+
|
1213 |
+
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
|
1214 |
+
res = 2 ** i
|
1215 |
+
self.resolution.append(res)
|
1216 |
+
if i == res_log2:
|
1217 |
+
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
|
1218 |
+
else:
|
1219 |
+
block = ConvBlockDown(nf(i + 1), nf(i), activation)
|
1220 |
+
setattr(self, "EncConv_Block_%dx%d" % (res, res), block)
|
1221 |
+
|
1222 |
+
def forward(self, x):
|
1223 |
+
out = {}
|
1224 |
+
for res in self.resolution:
|
1225 |
+
res_log2 = int(np.log2(res))
|
1226 |
+
x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x)
|
1227 |
+
out[res_log2] = x
|
1228 |
+
|
1229 |
+
return out
|
1230 |
+
|
1231 |
+
|
1232 |
+
class ToStyle(nn.Module):
|
1233 |
+
def __init__(self, in_channels, out_channels, activation, drop_rate):
|
1234 |
+
super().__init__()
|
1235 |
+
self.conv = nn.Sequential(
|
1236 |
+
Conv2dLayer(
|
1237 |
+
in_channels=in_channels,
|
1238 |
+
out_channels=in_channels,
|
1239 |
+
kernel_size=3,
|
1240 |
+
activation=activation,
|
1241 |
+
down=2,
|
1242 |
+
),
|
1243 |
+
Conv2dLayer(
|
1244 |
+
in_channels=in_channels,
|
1245 |
+
out_channels=in_channels,
|
1246 |
+
kernel_size=3,
|
1247 |
+
activation=activation,
|
1248 |
+
down=2,
|
1249 |
+
),
|
1250 |
+
Conv2dLayer(
|
1251 |
+
in_channels=in_channels,
|
1252 |
+
out_channels=in_channels,
|
1253 |
+
kernel_size=3,
|
1254 |
+
activation=activation,
|
1255 |
+
down=2,
|
1256 |
+
),
|
1257 |
+
)
|
1258 |
+
|
1259 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
1260 |
+
self.fc = FullyConnectedLayer(
|
1261 |
+
in_features=in_channels, out_features=out_channels, activation=activation
|
1262 |
+
)
|
1263 |
+
# self.dropout = nn.Dropout(drop_rate)
|
1264 |
+
|
1265 |
+
def forward(self, x):
|
1266 |
+
x = self.conv(x)
|
1267 |
+
x = self.pool(x)
|
1268 |
+
x = self.fc(x.flatten(start_dim=1))
|
1269 |
+
# x = self.dropout(x)
|
1270 |
+
|
1271 |
+
return x
|
1272 |
+
|
1273 |
+
|
1274 |
+
class DecBlockFirstV2(nn.Module):
|
1275 |
+
def __init__(
|
1276 |
+
self,
|
1277 |
+
res,
|
1278 |
+
in_channels,
|
1279 |
+
out_channels,
|
1280 |
+
activation,
|
1281 |
+
style_dim,
|
1282 |
+
use_noise,
|
1283 |
+
demodulate,
|
1284 |
+
img_channels,
|
1285 |
+
):
|
1286 |
+
super().__init__()
|
1287 |
+
self.res = res
|
1288 |
+
|
1289 |
+
self.conv0 = Conv2dLayer(
|
1290 |
+
in_channels=in_channels,
|
1291 |
+
out_channels=in_channels,
|
1292 |
+
kernel_size=3,
|
1293 |
+
activation=activation,
|
1294 |
+
)
|
1295 |
+
self.conv1 = StyleConv(
|
1296 |
+
in_channels=in_channels,
|
1297 |
+
out_channels=out_channels,
|
1298 |
+
style_dim=style_dim,
|
1299 |
+
resolution=2 ** res,
|
1300 |
+
kernel_size=3,
|
1301 |
+
use_noise=use_noise,
|
1302 |
+
activation=activation,
|
1303 |
+
demodulate=demodulate,
|
1304 |
+
)
|
1305 |
+
self.toRGB = ToRGB(
|
1306 |
+
in_channels=out_channels,
|
1307 |
+
out_channels=img_channels,
|
1308 |
+
style_dim=style_dim,
|
1309 |
+
kernel_size=1,
|
1310 |
+
demodulate=False,
|
1311 |
+
)
|
1312 |
+
|
1313 |
+
def forward(self, x, ws, gs, E_features, noise_mode="random"):
|
1314 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
1315 |
+
x = self.conv0(x)
|
1316 |
+
x = x + E_features[self.res]
|
1317 |
+
style = get_style_code(ws[:, 0], gs)
|
1318 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1319 |
+
style = get_style_code(ws[:, 1], gs)
|
1320 |
+
img = self.toRGB(x, style, skip=None)
|
1321 |
+
|
1322 |
+
return x, img
|
1323 |
+
|
1324 |
+
|
1325 |
+
class DecBlock(nn.Module):
|
1326 |
+
def __init__(
|
1327 |
+
self,
|
1328 |
+
res,
|
1329 |
+
in_channels,
|
1330 |
+
out_channels,
|
1331 |
+
activation,
|
1332 |
+
style_dim,
|
1333 |
+
use_noise,
|
1334 |
+
demodulate,
|
1335 |
+
img_channels,
|
1336 |
+
): # res = 4, ..., resolution_log2
|
1337 |
+
super().__init__()
|
1338 |
+
self.res = res
|
1339 |
+
|
1340 |
+
self.conv0 = StyleConv(
|
1341 |
+
in_channels=in_channels,
|
1342 |
+
out_channels=out_channels,
|
1343 |
+
style_dim=style_dim,
|
1344 |
+
resolution=2 ** res,
|
1345 |
+
kernel_size=3,
|
1346 |
+
up=2,
|
1347 |
+
use_noise=use_noise,
|
1348 |
+
activation=activation,
|
1349 |
+
demodulate=demodulate,
|
1350 |
+
)
|
1351 |
+
self.conv1 = StyleConv(
|
1352 |
+
in_channels=out_channels,
|
1353 |
+
out_channels=out_channels,
|
1354 |
+
style_dim=style_dim,
|
1355 |
+
resolution=2 ** res,
|
1356 |
+
kernel_size=3,
|
1357 |
+
use_noise=use_noise,
|
1358 |
+
activation=activation,
|
1359 |
+
demodulate=demodulate,
|
1360 |
+
)
|
1361 |
+
self.toRGB = ToRGB(
|
1362 |
+
in_channels=out_channels,
|
1363 |
+
out_channels=img_channels,
|
1364 |
+
style_dim=style_dim,
|
1365 |
+
kernel_size=1,
|
1366 |
+
demodulate=False,
|
1367 |
+
)
|
1368 |
+
|
1369 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
|
1370 |
+
style = get_style_code(ws[:, self.res * 2 - 9], gs)
|
1371 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1372 |
+
x = x + E_features[self.res]
|
1373 |
+
style = get_style_code(ws[:, self.res * 2 - 8], gs)
|
1374 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1375 |
+
style = get_style_code(ws[:, self.res * 2 - 7], gs)
|
1376 |
+
img = self.toRGB(x, style, skip=img)
|
1377 |
+
|
1378 |
+
return x, img
|
1379 |
+
|
1380 |
+
|
1381 |
+
class Decoder(nn.Module):
|
1382 |
+
def __init__(
|
1383 |
+
self, res_log2, activation, style_dim, use_noise, demodulate, img_channels
|
1384 |
+
):
|
1385 |
+
super().__init__()
|
1386 |
+
self.Dec_16x16 = DecBlockFirstV2(
|
1387 |
+
4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels
|
1388 |
+
)
|
1389 |
+
for res in range(5, res_log2 + 1):
|
1390 |
+
setattr(
|
1391 |
+
self,
|
1392 |
+
"Dec_%dx%d" % (2 ** res, 2 ** res),
|
1393 |
+
DecBlock(
|
1394 |
+
res,
|
1395 |
+
nf(res - 1),
|
1396 |
+
nf(res),
|
1397 |
+
activation,
|
1398 |
+
style_dim,
|
1399 |
+
use_noise,
|
1400 |
+
demodulate,
|
1401 |
+
img_channels,
|
1402 |
+
),
|
1403 |
+
)
|
1404 |
+
self.res_log2 = res_log2
|
1405 |
+
|
1406 |
+
def forward(self, x, ws, gs, E_features, noise_mode="random"):
|
1407 |
+
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
|
1408 |
+
for res in range(5, self.res_log2 + 1):
|
1409 |
+
block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res))
|
1410 |
+
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
|
1411 |
+
|
1412 |
+
return img
|
1413 |
+
|
1414 |
+
|
1415 |
+
class DecStyleBlock(nn.Module):
|
1416 |
+
def __init__(
|
1417 |
+
self,
|
1418 |
+
res,
|
1419 |
+
in_channels,
|
1420 |
+
out_channels,
|
1421 |
+
activation,
|
1422 |
+
style_dim,
|
1423 |
+
use_noise,
|
1424 |
+
demodulate,
|
1425 |
+
img_channels,
|
1426 |
+
):
|
1427 |
+
super().__init__()
|
1428 |
+
self.res = res
|
1429 |
+
|
1430 |
+
self.conv0 = StyleConv(
|
1431 |
+
in_channels=in_channels,
|
1432 |
+
out_channels=out_channels,
|
1433 |
+
style_dim=style_dim,
|
1434 |
+
resolution=2 ** res,
|
1435 |
+
kernel_size=3,
|
1436 |
+
up=2,
|
1437 |
+
use_noise=use_noise,
|
1438 |
+
activation=activation,
|
1439 |
+
demodulate=demodulate,
|
1440 |
+
)
|
1441 |
+
self.conv1 = StyleConv(
|
1442 |
+
in_channels=out_channels,
|
1443 |
+
out_channels=out_channels,
|
1444 |
+
style_dim=style_dim,
|
1445 |
+
resolution=2 ** res,
|
1446 |
+
kernel_size=3,
|
1447 |
+
use_noise=use_noise,
|
1448 |
+
activation=activation,
|
1449 |
+
demodulate=demodulate,
|
1450 |
+
)
|
1451 |
+
self.toRGB = ToRGB(
|
1452 |
+
in_channels=out_channels,
|
1453 |
+
out_channels=img_channels,
|
1454 |
+
style_dim=style_dim,
|
1455 |
+
kernel_size=1,
|
1456 |
+
demodulate=False,
|
1457 |
+
)
|
1458 |
+
|
1459 |
+
def forward(self, x, img, style, skip, noise_mode="random"):
|
1460 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
1461 |
+
x = x + skip
|
1462 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
1463 |
+
img = self.toRGB(x, style, skip=img)
|
1464 |
+
|
1465 |
+
return x, img
|
1466 |
+
|
1467 |
+
|
1468 |
+
class FirstStage(nn.Module):
|
1469 |
+
def __init__(
|
1470 |
+
self,
|
1471 |
+
img_channels,
|
1472 |
+
img_resolution=256,
|
1473 |
+
dim=180,
|
1474 |
+
w_dim=512,
|
1475 |
+
use_noise=False,
|
1476 |
+
demodulate=True,
|
1477 |
+
activation="lrelu",
|
1478 |
+
):
|
1479 |
+
super().__init__()
|
1480 |
+
res = 64
|
1481 |
+
|
1482 |
+
self.conv_first = Conv2dLayerPartial(
|
1483 |
+
in_channels=img_channels + 1,
|
1484 |
+
out_channels=dim,
|
1485 |
+
kernel_size=3,
|
1486 |
+
activation=activation,
|
1487 |
+
)
|
1488 |
+
self.enc_conv = nn.ModuleList()
|
1489 |
+
down_time = int(np.log2(img_resolution // res))
|
1490 |
+
# 根据图片尺寸构建 swim transformer 的层数
|
1491 |
+
for i in range(down_time): # from input size to 64
|
1492 |
+
self.enc_conv.append(
|
1493 |
+
Conv2dLayerPartial(
|
1494 |
+
in_channels=dim,
|
1495 |
+
out_channels=dim,
|
1496 |
+
kernel_size=3,
|
1497 |
+
down=2,
|
1498 |
+
activation=activation,
|
1499 |
+
)
|
1500 |
+
)
|
1501 |
+
|
1502 |
+
# from 64 -> 16 -> 64
|
1503 |
+
depths = [2, 3, 4, 3, 2]
|
1504 |
+
ratios = [1, 1 / 2, 1 / 2, 2, 2]
|
1505 |
+
num_heads = 6
|
1506 |
+
window_sizes = [8, 16, 16, 16, 8]
|
1507 |
+
drop_path_rate = 0.1
|
1508 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
1509 |
+
|
1510 |
+
self.tran = nn.ModuleList()
|
1511 |
+
for i, depth in enumerate(depths):
|
1512 |
+
res = int(res * ratios[i])
|
1513 |
+
if ratios[i] < 1:
|
1514 |
+
merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
|
1515 |
+
elif ratios[i] > 1:
|
1516 |
+
merge = PatchUpsampling(dim, dim, up=ratios[i])
|
1517 |
+
else:
|
1518 |
+
merge = None
|
1519 |
+
self.tran.append(
|
1520 |
+
BasicLayer(
|
1521 |
+
dim=dim,
|
1522 |
+
input_resolution=[res, res],
|
1523 |
+
depth=depth,
|
1524 |
+
num_heads=num_heads,
|
1525 |
+
window_size=window_sizes[i],
|
1526 |
+
drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
|
1527 |
+
downsample=merge,
|
1528 |
+
)
|
1529 |
+
)
|
1530 |
+
|
1531 |
+
# global style
|
1532 |
+
down_conv = []
|
1533 |
+
for i in range(int(np.log2(16))):
|
1534 |
+
down_conv.append(
|
1535 |
+
Conv2dLayer(
|
1536 |
+
in_channels=dim,
|
1537 |
+
out_channels=dim,
|
1538 |
+
kernel_size=3,
|
1539 |
+
down=2,
|
1540 |
+
activation=activation,
|
1541 |
+
)
|
1542 |
+
)
|
1543 |
+
down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
|
1544 |
+
self.down_conv = nn.Sequential(*down_conv)
|
1545 |
+
self.to_style = FullyConnectedLayer(
|
1546 |
+
in_features=dim, out_features=dim * 2, activation=activation
|
1547 |
+
)
|
1548 |
+
self.ws_style = FullyConnectedLayer(
|
1549 |
+
in_features=w_dim, out_features=dim, activation=activation
|
1550 |
+
)
|
1551 |
+
self.to_square = FullyConnectedLayer(
|
1552 |
+
in_features=dim, out_features=16 * 16, activation=activation
|
1553 |
+
)
|
1554 |
+
|
1555 |
+
style_dim = dim * 3
|
1556 |
+
self.dec_conv = nn.ModuleList()
|
1557 |
+
for i in range(down_time): # from 64 to input size
|
1558 |
+
res = res * 2
|
1559 |
+
self.dec_conv.append(
|
1560 |
+
DecStyleBlock(
|
1561 |
+
res,
|
1562 |
+
dim,
|
1563 |
+
dim,
|
1564 |
+
activation,
|
1565 |
+
style_dim,
|
1566 |
+
use_noise,
|
1567 |
+
demodulate,
|
1568 |
+
img_channels,
|
1569 |
+
)
|
1570 |
+
)
|
1571 |
+
|
1572 |
+
def forward(self, images_in, masks_in, ws, noise_mode="random"):
|
1573 |
+
x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
|
1574 |
+
|
1575 |
+
skips = []
|
1576 |
+
x, mask = self.conv_first(x, masks_in) # input size
|
1577 |
+
skips.append(x)
|
1578 |
+
for i, block in enumerate(self.enc_conv): # input size to 64
|
1579 |
+
x, mask = block(x, mask)
|
1580 |
+
if i != len(self.enc_conv) - 1:
|
1581 |
+
skips.append(x)
|
1582 |
+
|
1583 |
+
x_size = x.size()[-2:]
|
1584 |
+
x = feature2token(x)
|
1585 |
+
mask = feature2token(mask)
|
1586 |
+
mid = len(self.tran) // 2
|
1587 |
+
for i, block in enumerate(self.tran): # 64 to 16
|
1588 |
+
if i < mid:
|
1589 |
+
x, x_size, mask = block(x, x_size, mask)
|
1590 |
+
skips.append(x)
|
1591 |
+
elif i > mid:
|
1592 |
+
x, x_size, mask = block(x, x_size, None)
|
1593 |
+
x = x + skips[mid - i]
|
1594 |
+
else:
|
1595 |
+
x, x_size, mask = block(x, x_size, None)
|
1596 |
+
|
1597 |
+
mul_map = torch.ones_like(x) * 0.5
|
1598 |
+
mul_map = F.dropout(mul_map, training=True)
|
1599 |
+
ws = self.ws_style(ws[:, -1])
|
1600 |
+
add_n = self.to_square(ws).unsqueeze(1)
|
1601 |
+
add_n = (
|
1602 |
+
F.interpolate(
|
1603 |
+
add_n, size=x.size(1), mode="linear", align_corners=False
|
1604 |
+
)
|
1605 |
+
.squeeze(1)
|
1606 |
+
.unsqueeze(-1)
|
1607 |
+
)
|
1608 |
+
x = x * mul_map + add_n * (1 - mul_map)
|
1609 |
+
gs = self.to_style(
|
1610 |
+
self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)
|
1611 |
+
)
|
1612 |
+
style = torch.cat([gs, ws], dim=1)
|
1613 |
+
|
1614 |
+
x = token2feature(x, x_size).contiguous()
|
1615 |
+
img = None
|
1616 |
+
for i, block in enumerate(self.dec_conv):
|
1617 |
+
x, img = block(
|
1618 |
+
x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode
|
1619 |
+
)
|
1620 |
+
|
1621 |
+
# ensemble
|
1622 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1623 |
+
|
1624 |
+
return img
|
1625 |
+
|
1626 |
+
|
1627 |
+
class SynthesisNet(nn.Module):
|
1628 |
+
def __init__(
|
1629 |
+
self,
|
1630 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1631 |
+
img_resolution, # Output image resolution.
|
1632 |
+
img_channels=3, # Number of color channels.
|
1633 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1634 |
+
channel_decay=1.0,
|
1635 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1636 |
+
activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
|
1637 |
+
drop_rate=0.5,
|
1638 |
+
use_noise=False,
|
1639 |
+
demodulate=True,
|
1640 |
+
):
|
1641 |
+
super().__init__()
|
1642 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1643 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1644 |
+
|
1645 |
+
self.num_layers = resolution_log2 * 2 - 3 * 2
|
1646 |
+
self.img_resolution = img_resolution
|
1647 |
+
self.resolution_log2 = resolution_log2
|
1648 |
+
|
1649 |
+
# first stage
|
1650 |
+
self.first_stage = FirstStage(
|
1651 |
+
img_channels,
|
1652 |
+
img_resolution=img_resolution,
|
1653 |
+
w_dim=w_dim,
|
1654 |
+
use_noise=False,
|
1655 |
+
demodulate=demodulate,
|
1656 |
+
)
|
1657 |
+
|
1658 |
+
# second stage
|
1659 |
+
self.enc = Encoder(
|
1660 |
+
resolution_log2, img_channels, activation, patch_size=5, channels=16
|
1661 |
+
)
|
1662 |
+
self.to_square = FullyConnectedLayer(
|
1663 |
+
in_features=w_dim, out_features=16 * 16, activation=activation
|
1664 |
+
)
|
1665 |
+
self.to_style = ToStyle(
|
1666 |
+
in_channels=nf(4),
|
1667 |
+
out_channels=nf(2) * 2,
|
1668 |
+
activation=activation,
|
1669 |
+
drop_rate=drop_rate,
|
1670 |
+
)
|
1671 |
+
style_dim = w_dim + nf(2) * 2
|
1672 |
+
self.dec = Decoder(
|
1673 |
+
resolution_log2, activation, style_dim, use_noise, demodulate, img_channels
|
1674 |
+
)
|
1675 |
+
|
1676 |
+
def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False):
|
1677 |
+
out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
|
1678 |
+
|
1679 |
+
# encoder
|
1680 |
+
x = images_in * masks_in + out_stg1 * (1 - masks_in)
|
1681 |
+
x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
|
1682 |
+
E_features = self.enc(x)
|
1683 |
+
|
1684 |
+
fea_16 = E_features[4]
|
1685 |
+
mul_map = torch.ones_like(fea_16) * 0.5
|
1686 |
+
mul_map = F.dropout(mul_map, training=True)
|
1687 |
+
add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
|
1688 |
+
add_n = F.interpolate(
|
1689 |
+
add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False
|
1690 |
+
)
|
1691 |
+
fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
|
1692 |
+
E_features[4] = fea_16
|
1693 |
+
|
1694 |
+
# style
|
1695 |
+
gs = self.to_style(fea_16)
|
1696 |
+
|
1697 |
+
# decoder
|
1698 |
+
img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
|
1699 |
+
|
1700 |
+
# ensemble
|
1701 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
1702 |
+
|
1703 |
+
if not return_stg1:
|
1704 |
+
return img
|
1705 |
+
else:
|
1706 |
+
return img, out_stg1
|
1707 |
+
|
1708 |
+
|
1709 |
+
class Generator(nn.Module):
|
1710 |
+
def __init__(
|
1711 |
+
self,
|
1712 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
1713 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
1714 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
1715 |
+
img_resolution, # resolution of generated image
|
1716 |
+
img_channels, # Number of input color channels.
|
1717 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
1718 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
1719 |
+
):
|
1720 |
+
super().__init__()
|
1721 |
+
self.z_dim = z_dim
|
1722 |
+
self.c_dim = c_dim
|
1723 |
+
self.w_dim = w_dim
|
1724 |
+
self.img_resolution = img_resolution
|
1725 |
+
self.img_channels = img_channels
|
1726 |
+
|
1727 |
+
self.synthesis = SynthesisNet(
|
1728 |
+
w_dim=w_dim,
|
1729 |
+
img_resolution=img_resolution,
|
1730 |
+
img_channels=img_channels,
|
1731 |
+
**synthesis_kwargs,
|
1732 |
+
)
|
1733 |
+
self.mapping = MappingNet(
|
1734 |
+
z_dim=z_dim,
|
1735 |
+
c_dim=c_dim,
|
1736 |
+
w_dim=w_dim,
|
1737 |
+
num_ws=self.synthesis.num_layers,
|
1738 |
+
**mapping_kwargs,
|
1739 |
+
)
|
1740 |
+
|
1741 |
+
def forward(
|
1742 |
+
self,
|
1743 |
+
images_in,
|
1744 |
+
masks_in,
|
1745 |
+
z,
|
1746 |
+
c,
|
1747 |
+
truncation_psi=1,
|
1748 |
+
truncation_cutoff=None,
|
1749 |
+
skip_w_avg_update=False,
|
1750 |
+
noise_mode="none",
|
1751 |
+
return_stg1=False,
|
1752 |
+
):
|
1753 |
+
ws = self.mapping(
|
1754 |
+
z,
|
1755 |
+
c,
|
1756 |
+
truncation_psi=truncation_psi,
|
1757 |
+
truncation_cutoff=truncation_cutoff,
|
1758 |
+
skip_w_avg_update=skip_w_avg_update,
|
1759 |
+
)
|
1760 |
+
img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
|
1761 |
+
return img
|
1762 |
+
|
1763 |
+
|
1764 |
+
class Discriminator(torch.nn.Module):
|
1765 |
+
def __init__(
|
1766 |
+
self,
|
1767 |
+
c_dim, # Conditioning label (C) dimensionality.
|
1768 |
+
img_resolution, # Input resolution.
|
1769 |
+
img_channels, # Number of input color channels.
|
1770 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
1771 |
+
channel_max=512, # Maximum number of channels in any layer.
|
1772 |
+
channel_decay=1,
|
1773 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
1774 |
+
activation="lrelu",
|
1775 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
1776 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
1777 |
+
):
|
1778 |
+
super().__init__()
|
1779 |
+
self.c_dim = c_dim
|
1780 |
+
self.img_resolution = img_resolution
|
1781 |
+
self.img_channels = img_channels
|
1782 |
+
|
1783 |
+
resolution_log2 = int(np.log2(img_resolution))
|
1784 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
1785 |
+
self.resolution_log2 = resolution_log2
|
1786 |
+
|
1787 |
+
if cmap_dim == None:
|
1788 |
+
cmap_dim = nf(2)
|
1789 |
+
if c_dim == 0:
|
1790 |
+
cmap_dim = 0
|
1791 |
+
self.cmap_dim = cmap_dim
|
1792 |
+
|
1793 |
+
if c_dim > 0:
|
1794 |
+
self.mapping = MappingNet(
|
1795 |
+
z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
|
1796 |
+
)
|
1797 |
+
|
1798 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
1799 |
+
for res in range(resolution_log2, 2, -1):
|
1800 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
1801 |
+
|
1802 |
+
if mbstd_num_channels > 0:
|
1803 |
+
Dis.append(
|
1804 |
+
MinibatchStdLayer(
|
1805 |
+
group_size=mbstd_group_size, num_channels=mbstd_num_channels
|
1806 |
+
)
|
1807 |
+
)
|
1808 |
+
Dis.append(
|
1809 |
+
Conv2dLayer(
|
1810 |
+
nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
|
1811 |
+
)
|
1812 |
+
)
|
1813 |
+
self.Dis = nn.Sequential(*Dis)
|
1814 |
+
|
1815 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
1816 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
1817 |
+
|
1818 |
+
# for 64x64
|
1819 |
+
Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
|
1820 |
+
for res in range(resolution_log2, 2, -1):
|
1821 |
+
Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
|
1822 |
+
|
1823 |
+
if mbstd_num_channels > 0:
|
1824 |
+
Dis_stg1.append(
|
1825 |
+
MinibatchStdLayer(
|
1826 |
+
group_size=mbstd_group_size, num_channels=mbstd_num_channels
|
1827 |
+
)
|
1828 |
+
)
|
1829 |
+
Dis_stg1.append(
|
1830 |
+
Conv2dLayer(
|
1831 |
+
nf(2) // 2 + mbstd_num_channels,
|
1832 |
+
nf(2) // 2,
|
1833 |
+
kernel_size=3,
|
1834 |
+
activation=activation,
|
1835 |
+
)
|
1836 |
+
)
|
1837 |
+
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
|
1838 |
+
|
1839 |
+
self.fc0_stg1 = FullyConnectedLayer(
|
1840 |
+
nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation
|
1841 |
+
)
|
1842 |
+
self.fc1_stg1 = FullyConnectedLayer(
|
1843 |
+
nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim
|
1844 |
+
)
|
1845 |
+
|
1846 |
+
def forward(self, images_in, masks_in, images_stg1, c):
|
1847 |
+
x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
|
1848 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
1849 |
+
|
1850 |
+
x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
|
1851 |
+
x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
|
1852 |
+
|
1853 |
+
if self.c_dim > 0:
|
1854 |
+
cmap = self.mapping(None, c)
|
1855 |
+
|
1856 |
+
if self.cmap_dim > 0:
|
1857 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
1858 |
+
x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (
|
1859 |
+
1 / np.sqrt(self.cmap_dim)
|
1860 |
+
)
|
1861 |
+
|
1862 |
+
return x, x_stg1
|
1863 |
+
|
1864 |
+
|
1865 |
+
MAT_MODEL_URL = os.environ.get(
|
1866 |
+
"MAT_MODEL_URL",
|
1867 |
+
"https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
|
1868 |
+
)
|
1869 |
+
|
1870 |
+
MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed")
|
1871 |
+
|
1872 |
+
|
1873 |
+
class MAT(InpaintModel):
|
1874 |
+
name = "mat"
|
1875 |
+
min_size = 512
|
1876 |
+
pad_mod = 512
|
1877 |
+
pad_to_square = True
|
1878 |
+
|
1879 |
+
def init_model(self, device, **kwargs):
|
1880 |
+
seed = 240 # pick up a random number
|
1881 |
+
set_seed(seed)
|
1882 |
+
|
1883 |
+
fp16 = not kwargs.get("no_half", False)
|
1884 |
+
use_gpu = "cuda" in str(device) and torch.cuda.is_available()
|
1885 |
+
self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
1886 |
+
|
1887 |
+
G = Generator(
|
1888 |
+
z_dim=512,
|
1889 |
+
c_dim=0,
|
1890 |
+
w_dim=512,
|
1891 |
+
img_resolution=512,
|
1892 |
+
img_channels=3,
|
1893 |
+
mapping_kwargs={"torch_dtype": self.torch_dtype},
|
1894 |
+
).to(self.torch_dtype)
|
1895 |
+
# fmt: off
|
1896 |
+
self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
|
1897 |
+
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
|
1898 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
|
1899 |
+
# fmt: on
|
1900 |
+
|
1901 |
+
@staticmethod
|
1902 |
+
def is_downloaded() -> bool:
|
1903 |
+
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
1904 |
+
|
1905 |
+
def forward(self, image, mask, config: Config):
|
1906 |
+
"""Input images and output images have same size
|
1907 |
+
images: [H, W, C] RGB
|
1908 |
+
masks: [H, W] mask area == 255
|
1909 |
+
return: BGR IMAGE
|
1910 |
+
"""
|
1911 |
+
|
1912 |
+
image = norm_img(image) # [0, 1]
|
1913 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
1914 |
+
|
1915 |
+
mask = (mask > 127) * 255
|
1916 |
+
mask = 255 - mask
|
1917 |
+
mask = norm_img(mask)
|
1918 |
+
|
1919 |
+
image = (
|
1920 |
+
torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
|
1921 |
+
)
|
1922 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
|
1923 |
+
|
1924 |
+
output = self.model(
|
1925 |
+
image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
|
1926 |
+
)
|
1927 |
+
output = (
|
1928 |
+
(output.permute(0, 2, 3, 1) * 127.5 + 127.5)
|
1929 |
+
.round()
|
1930 |
+
.clamp(0, 255)
|
1931 |
+
.to(torch.uint8)
|
1932 |
+
)
|
1933 |
+
output = output[0].cpu().numpy()
|
1934 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
1935 |
+
return cur_res
|
lama_cleaner/model/opencv2.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from lama_cleaner.model.base import InpaintModel
|
3 |
+
from lama_cleaner.schema import Config
|
4 |
+
|
5 |
+
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
6 |
+
|
7 |
+
|
8 |
+
class OpenCV2(InpaintModel):
|
9 |
+
name = "cv2"
|
10 |
+
pad_mod = 1
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def is_downloaded() -> bool:
|
14 |
+
return True
|
15 |
+
|
16 |
+
def forward(self, image, mask, config: Config):
|
17 |
+
"""Input image and output image have same size
|
18 |
+
image: [H, W, C] RGB
|
19 |
+
mask: [H, W, 1]
|
20 |
+
return: BGR IMAGE
|
21 |
+
"""
|
22 |
+
cur_res = cv2.inpaint(
|
23 |
+
image[:, :, ::-1],
|
24 |
+
mask,
|
25 |
+
inpaintRadius=config.cv2_radius,
|
26 |
+
flags=flag_map[config.cv2_flag],
|
27 |
+
)
|
28 |
+
return cur_res
|
lama_cleaner/model/paint_by_example.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import PIL.Image
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from lama_cleaner.model.base import DiffusionInpaintModel
|
9 |
+
from lama_cleaner.model.utils import set_seed
|
10 |
+
from lama_cleaner.schema import Config
|
11 |
+
|
12 |
+
|
13 |
+
class PaintByExample(DiffusionInpaintModel):
|
14 |
+
name = "paint_by_example"
|
15 |
+
pad_mod = 8
|
16 |
+
min_size = 512
|
17 |
+
|
18 |
+
def init_model(self, device: torch.device, **kwargs):
|
19 |
+
fp16 = not kwargs.get('no_half', False)
|
20 |
+
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
21 |
+
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
22 |
+
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
23 |
+
|
24 |
+
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
25 |
+
logger.info("Disable Paint By Example Model NSFW checker")
|
26 |
+
model_kwargs.update(dict(
|
27 |
+
safety_checker=None,
|
28 |
+
requires_safety_checker=False
|
29 |
+
))
|
30 |
+
|
31 |
+
self.model = DiffusionPipeline.from_pretrained(
|
32 |
+
"Fantasy-Studio/Paint-by-Example",
|
33 |
+
torch_dtype=torch_dtype,
|
34 |
+
**model_kwargs
|
35 |
+
)
|
36 |
+
|
37 |
+
self.model.enable_attention_slicing()
|
38 |
+
if kwargs.get('enable_xformers', False):
|
39 |
+
self.model.enable_xformers_memory_efficient_attention()
|
40 |
+
|
41 |
+
# TODO: gpu_id
|
42 |
+
if kwargs.get('cpu_offload', False) and use_gpu:
|
43 |
+
self.model.image_encoder = self.model.image_encoder.to(device)
|
44 |
+
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
45 |
+
else:
|
46 |
+
self.model = self.model.to(device)
|
47 |
+
|
48 |
+
def forward(self, image, mask, config: Config):
|
49 |
+
"""Input image and output image have same size
|
50 |
+
image: [H, W, C] RGB
|
51 |
+
mask: [H, W, 1] 255 means area to repaint
|
52 |
+
return: BGR IMAGE
|
53 |
+
"""
|
54 |
+
output = self.model(
|
55 |
+
image=PIL.Image.fromarray(image),
|
56 |
+
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
57 |
+
example_image=config.paint_by_example_example_image,
|
58 |
+
num_inference_steps=config.paint_by_example_steps,
|
59 |
+
output_type='np.array',
|
60 |
+
generator=torch.manual_seed(config.paint_by_example_seed)
|
61 |
+
).images[0]
|
62 |
+
|
63 |
+
output = (output * 255).round().astype("uint8")
|
64 |
+
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
65 |
+
return output
|
66 |
+
|
67 |
+
def forward_post_process(self, result, image, mask, config):
|
68 |
+
if config.paint_by_example_match_histograms:
|
69 |
+
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
70 |
+
|
71 |
+
if config.paint_by_example_mask_blur != 0:
|
72 |
+
k = 2 * config.paint_by_example_mask_blur + 1
|
73 |
+
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
74 |
+
return result, image, mask
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def is_downloaded() -> bool:
|
78 |
+
# model will be downloaded when app start, and can't switch in frontend settings
|
79 |
+
return True
|
lama_cleaner/model/pipeline/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .pipeline_stable_diffusion_controlnet_inpaint import (
|
2 |
+
StableDiffusionControlNetInpaintPipeline,
|
3 |
+
)
|
lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import PIL.Image
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
|
22 |
+
|
23 |
+
EXAMPLE_DOC_STRING = """
|
24 |
+
Examples:
|
25 |
+
```py
|
26 |
+
>>> # !pip install opencv-python transformers accelerate
|
27 |
+
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
|
28 |
+
>>> from diffusers.utils import load_image
|
29 |
+
>>> import numpy as np
|
30 |
+
>>> import torch
|
31 |
+
|
32 |
+
>>> import cv2
|
33 |
+
>>> from PIL import Image
|
34 |
+
>>> # download an image
|
35 |
+
>>> image = load_image(
|
36 |
+
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
37 |
+
... )
|
38 |
+
>>> image = np.array(image)
|
39 |
+
>>> mask_image = load_image(
|
40 |
+
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
41 |
+
... )
|
42 |
+
>>> mask_image = np.array(mask_image)
|
43 |
+
>>> # get canny image
|
44 |
+
>>> canny_image = cv2.Canny(image, 100, 200)
|
45 |
+
>>> canny_image = canny_image[:, :, None]
|
46 |
+
>>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
|
47 |
+
>>> canny_image = Image.fromarray(canny_image)
|
48 |
+
|
49 |
+
>>> # load control net and stable diffusion v1-5
|
50 |
+
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
51 |
+
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
52 |
+
... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
|
53 |
+
... )
|
54 |
+
|
55 |
+
>>> # speed up diffusion process with faster scheduler and memory optimization
|
56 |
+
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
57 |
+
>>> # remove following line if xformers is not installed
|
58 |
+
>>> pipe.enable_xformers_memory_efficient_attention()
|
59 |
+
|
60 |
+
>>> pipe.enable_model_cpu_offload()
|
61 |
+
|
62 |
+
>>> # generate image
|
63 |
+
>>> generator = torch.manual_seed(0)
|
64 |
+
>>> image = pipe(
|
65 |
+
... "futuristic-looking doggo",
|
66 |
+
... num_inference_steps=20,
|
67 |
+
... generator=generator,
|
68 |
+
... image=image,
|
69 |
+
... control_image=canny_image,
|
70 |
+
... mask_image=mask_image
|
71 |
+
... ).images[0]
|
72 |
+
```
|
73 |
+
"""
|
74 |
+
|
75 |
+
|
76 |
+
def prepare_mask_and_masked_image(image, mask):
|
77 |
+
"""
|
78 |
+
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
79 |
+
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
80 |
+
``image`` and ``1`` for the ``mask``.
|
81 |
+
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
82 |
+
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
83 |
+
Args:
|
84 |
+
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
85 |
+
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
86 |
+
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
87 |
+
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
88 |
+
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
89 |
+
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
90 |
+
Raises:
|
91 |
+
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
92 |
+
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
93 |
+
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
94 |
+
(ot the other way around).
|
95 |
+
Returns:
|
96 |
+
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
97 |
+
dimensions: ``batch x channels x height x width``.
|
98 |
+
"""
|
99 |
+
if isinstance(image, torch.Tensor):
|
100 |
+
if not isinstance(mask, torch.Tensor):
|
101 |
+
raise TypeError(
|
102 |
+
f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
|
103 |
+
)
|
104 |
+
|
105 |
+
# Batch single image
|
106 |
+
if image.ndim == 3:
|
107 |
+
assert (
|
108 |
+
image.shape[0] == 3
|
109 |
+
), "Image outside a batch should be of shape (3, H, W)"
|
110 |
+
image = image.unsqueeze(0)
|
111 |
+
|
112 |
+
# Batch and add channel dim for single mask
|
113 |
+
if mask.ndim == 2:
|
114 |
+
mask = mask.unsqueeze(0).unsqueeze(0)
|
115 |
+
|
116 |
+
# Batch single mask or add channel dim
|
117 |
+
if mask.ndim == 3:
|
118 |
+
# Single batched mask, no channel dim or single mask not batched but channel dim
|
119 |
+
if mask.shape[0] == 1:
|
120 |
+
mask = mask.unsqueeze(0)
|
121 |
+
|
122 |
+
# Batched masks no channel dim
|
123 |
+
else:
|
124 |
+
mask = mask.unsqueeze(1)
|
125 |
+
|
126 |
+
assert (
|
127 |
+
image.ndim == 4 and mask.ndim == 4
|
128 |
+
), "Image and Mask must have 4 dimensions"
|
129 |
+
assert (
|
130 |
+
image.shape[-2:] == mask.shape[-2:]
|
131 |
+
), "Image and Mask must have the same spatial dimensions"
|
132 |
+
assert (
|
133 |
+
image.shape[0] == mask.shape[0]
|
134 |
+
), "Image and Mask must have the same batch size"
|
135 |
+
|
136 |
+
# Check image is in [-1, 1]
|
137 |
+
if image.min() < -1 or image.max() > 1:
|
138 |
+
raise ValueError("Image should be in [-1, 1] range")
|
139 |
+
|
140 |
+
# Check mask is in [0, 1]
|
141 |
+
if mask.min() < 0 or mask.max() > 1:
|
142 |
+
raise ValueError("Mask should be in [0, 1] range")
|
143 |
+
|
144 |
+
# Binarize mask
|
145 |
+
mask[mask < 0.5] = 0
|
146 |
+
mask[mask >= 0.5] = 1
|
147 |
+
|
148 |
+
# Image as float32
|
149 |
+
image = image.to(dtype=torch.float32)
|
150 |
+
elif isinstance(mask, torch.Tensor):
|
151 |
+
raise TypeError(
|
152 |
+
f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
# preprocess image
|
156 |
+
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
157 |
+
image = [image]
|
158 |
+
|
159 |
+
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
160 |
+
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
161 |
+
image = np.concatenate(image, axis=0)
|
162 |
+
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
163 |
+
image = np.concatenate([i[None, :] for i in image], axis=0)
|
164 |
+
|
165 |
+
image = image.transpose(0, 3, 1, 2)
|
166 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
167 |
+
|
168 |
+
# preprocess mask
|
169 |
+
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
170 |
+
mask = [mask]
|
171 |
+
|
172 |
+
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
173 |
+
mask = np.concatenate(
|
174 |
+
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
175 |
+
)
|
176 |
+
mask = mask.astype(np.float32) / 255.0
|
177 |
+
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
178 |
+
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
179 |
+
|
180 |
+
mask[mask < 0.5] = 0
|
181 |
+
mask[mask >= 0.5] = 1
|
182 |
+
mask = torch.from_numpy(mask)
|
183 |
+
|
184 |
+
masked_image = image * (mask < 0.5)
|
185 |
+
|
186 |
+
return mask, masked_image
|
187 |
+
|
188 |
+
|
189 |
+
class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
|
190 |
+
r"""
|
191 |
+
Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
|
192 |
+
|
193 |
+
This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
|
194 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
195 |
+
|
196 |
+
Args:
|
197 |
+
vae ([`AutoencoderKL`]):
|
198 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
199 |
+
text_encoder ([`CLIPTextModel`]):
|
200 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
201 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
202 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
203 |
+
tokenizer (`CLIPTokenizer`):
|
204 |
+
Tokenizer of class
|
205 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
206 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
207 |
+
controlnet ([`ControlNetModel`]):
|
208 |
+
Provides additional conditioning to the unet during the denoising process
|
209 |
+
scheduler ([`SchedulerMixin`]):
|
210 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
211 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
212 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
213 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
214 |
+
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
215 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
216 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def prepare_mask_latents(
|
220 |
+
self,
|
221 |
+
mask,
|
222 |
+
masked_image,
|
223 |
+
batch_size,
|
224 |
+
height,
|
225 |
+
width,
|
226 |
+
dtype,
|
227 |
+
device,
|
228 |
+
generator,
|
229 |
+
do_classifier_free_guidance,
|
230 |
+
):
|
231 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
232 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
233 |
+
# and half precision
|
234 |
+
mask = torch.nn.functional.interpolate(
|
235 |
+
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
236 |
+
)
|
237 |
+
mask = mask.to(device=device, dtype=dtype)
|
238 |
+
|
239 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
240 |
+
|
241 |
+
# encode the mask image into latents space so we can concatenate it to the latents
|
242 |
+
if isinstance(generator, list):
|
243 |
+
masked_image_latents = [
|
244 |
+
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
|
245 |
+
generator=generator[i]
|
246 |
+
)
|
247 |
+
for i in range(batch_size)
|
248 |
+
]
|
249 |
+
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
250 |
+
else:
|
251 |
+
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
|
252 |
+
generator=generator
|
253 |
+
)
|
254 |
+
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
|
255 |
+
|
256 |
+
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
257 |
+
if mask.shape[0] < batch_size:
|
258 |
+
if not batch_size % mask.shape[0] == 0:
|
259 |
+
raise ValueError(
|
260 |
+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
261 |
+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
262 |
+
" of masks that you pass is divisible by the total requested batch size."
|
263 |
+
)
|
264 |
+
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
265 |
+
if masked_image_latents.shape[0] < batch_size:
|
266 |
+
if not batch_size % masked_image_latents.shape[0] == 0:
|
267 |
+
raise ValueError(
|
268 |
+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
269 |
+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
270 |
+
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
271 |
+
)
|
272 |
+
masked_image_latents = masked_image_latents.repeat(
|
273 |
+
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
274 |
+
)
|
275 |
+
|
276 |
+
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
277 |
+
masked_image_latents = (
|
278 |
+
torch.cat([masked_image_latents] * 2)
|
279 |
+
if do_classifier_free_guidance
|
280 |
+
else masked_image_latents
|
281 |
+
)
|
282 |
+
|
283 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
284 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
285 |
+
return mask, masked_image_latents
|
286 |
+
|
287 |
+
@torch.no_grad()
|
288 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
289 |
+
def __call__(
|
290 |
+
self,
|
291 |
+
prompt: Union[str, List[str]] = None,
|
292 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
293 |
+
control_image: Union[
|
294 |
+
torch.FloatTensor,
|
295 |
+
PIL.Image.Image,
|
296 |
+
List[torch.FloatTensor],
|
297 |
+
List[PIL.Image.Image],
|
298 |
+
] = None,
|
299 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
300 |
+
height: Optional[int] = None,
|
301 |
+
width: Optional[int] = None,
|
302 |
+
num_inference_steps: int = 50,
|
303 |
+
guidance_scale: float = 7.5,
|
304 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
305 |
+
num_images_per_prompt: Optional[int] = 1,
|
306 |
+
eta: float = 0.0,
|
307 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
308 |
+
latents: Optional[torch.FloatTensor] = None,
|
309 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
310 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
311 |
+
output_type: Optional[str] = "pil",
|
312 |
+
return_dict: bool = True,
|
313 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
314 |
+
callback_steps: int = 1,
|
315 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
316 |
+
controlnet_conditioning_scale: float = 1.0,
|
317 |
+
):
|
318 |
+
r"""
|
319 |
+
Function invoked when calling the pipeline for generation.
|
320 |
+
Args:
|
321 |
+
prompt (`str` or `List[str]`, *optional*):
|
322 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
323 |
+
instead.
|
324 |
+
image (`PIL.Image.Image`):
|
325 |
+
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
|
326 |
+
be masked out with `mask_image` and repainted according to `prompt`.
|
327 |
+
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
|
328 |
+
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
329 |
+
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
|
330 |
+
also be accepted as an image. The control image is automatically resized to fit the output image.
|
331 |
+
mask_image (`PIL.Image.Image`):
|
332 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
333 |
+
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
334 |
+
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
335 |
+
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
336 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
337 |
+
The height in pixels of the generated image.
|
338 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
339 |
+
The width in pixels of the generated image.
|
340 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
341 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
342 |
+
expense of slower inference.
|
343 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
344 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
345 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
346 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
347 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
348 |
+
usually at the expense of lower image quality.
|
349 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
350 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
351 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
352 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
353 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
354 |
+
The number of images to generate per prompt.
|
355 |
+
eta (`float`, *optional*, defaults to 0.0):
|
356 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
357 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
358 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
359 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
360 |
+
to make generation deterministic.
|
361 |
+
latents (`torch.FloatTensor`, *optional*):
|
362 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
363 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
364 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
365 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
366 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
367 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
368 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
369 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
370 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
371 |
+
argument.
|
372 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
373 |
+
The output format of the generate image. Choose between
|
374 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
375 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
376 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
377 |
+
plain tuple.
|
378 |
+
callback (`Callable`, *optional*):
|
379 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
380 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
381 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
382 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
383 |
+
called at every step.
|
384 |
+
cross_attention_kwargs (`dict`, *optional*):
|
385 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
386 |
+
`self.processor` in
|
387 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
388 |
+
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
389 |
+
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
390 |
+
to the residual in the original unet.
|
391 |
+
Examples:
|
392 |
+
Returns:
|
393 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
394 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
395 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
396 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
397 |
+
(nsfw) content, according to the `safety_checker`.
|
398 |
+
"""
|
399 |
+
# 0. Default height and width to unet
|
400 |
+
height, width = self._default_height_width(height, width, control_image)
|
401 |
+
|
402 |
+
# 1. Check inputs. Raise error if not correct
|
403 |
+
self.check_inputs(
|
404 |
+
prompt,
|
405 |
+
control_image,
|
406 |
+
height,
|
407 |
+
width,
|
408 |
+
callback_steps,
|
409 |
+
negative_prompt,
|
410 |
+
prompt_embeds,
|
411 |
+
negative_prompt_embeds,
|
412 |
+
)
|
413 |
+
|
414 |
+
# 2. Define call parameters
|
415 |
+
if prompt is not None and isinstance(prompt, str):
|
416 |
+
batch_size = 1
|
417 |
+
elif prompt is not None and isinstance(prompt, list):
|
418 |
+
batch_size = len(prompt)
|
419 |
+
else:
|
420 |
+
batch_size = prompt_embeds.shape[0]
|
421 |
+
|
422 |
+
device = self._execution_device
|
423 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
424 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
425 |
+
# corresponds to doing no classifier free guidance.
|
426 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
427 |
+
|
428 |
+
# 3. Encode input prompt
|
429 |
+
prompt_embeds = self._encode_prompt(
|
430 |
+
prompt,
|
431 |
+
device,
|
432 |
+
num_images_per_prompt,
|
433 |
+
do_classifier_free_guidance,
|
434 |
+
negative_prompt,
|
435 |
+
prompt_embeds=prompt_embeds,
|
436 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
437 |
+
)
|
438 |
+
|
439 |
+
# 4. Prepare image
|
440 |
+
control_image = self.prepare_image(
|
441 |
+
control_image,
|
442 |
+
width,
|
443 |
+
height,
|
444 |
+
batch_size * num_images_per_prompt,
|
445 |
+
num_images_per_prompt,
|
446 |
+
device,
|
447 |
+
self.controlnet.dtype,
|
448 |
+
)
|
449 |
+
|
450 |
+
if do_classifier_free_guidance:
|
451 |
+
control_image = torch.cat([control_image] * 2)
|
452 |
+
|
453 |
+
# 5. Prepare timesteps
|
454 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
455 |
+
timesteps = self.scheduler.timesteps
|
456 |
+
|
457 |
+
# 6. Prepare latent variables
|
458 |
+
num_channels_latents = self.controlnet.config.in_channels
|
459 |
+
latents = self.prepare_latents(
|
460 |
+
batch_size * num_images_per_prompt,
|
461 |
+
num_channels_latents,
|
462 |
+
height,
|
463 |
+
width,
|
464 |
+
prompt_embeds.dtype,
|
465 |
+
device,
|
466 |
+
generator,
|
467 |
+
latents,
|
468 |
+
)
|
469 |
+
|
470 |
+
# EXTRA: prepare mask latents
|
471 |
+
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
472 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
473 |
+
mask,
|
474 |
+
masked_image,
|
475 |
+
batch_size * num_images_per_prompt,
|
476 |
+
height,
|
477 |
+
width,
|
478 |
+
prompt_embeds.dtype,
|
479 |
+
device,
|
480 |
+
generator,
|
481 |
+
do_classifier_free_guidance,
|
482 |
+
)
|
483 |
+
|
484 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
485 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
486 |
+
|
487 |
+
# 8. Denoising loop
|
488 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
489 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
490 |
+
for i, t in enumerate(timesteps):
|
491 |
+
# expand the latents if we are doing classifier free guidance
|
492 |
+
latent_model_input = (
|
493 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
494 |
+
)
|
495 |
+
latent_model_input = self.scheduler.scale_model_input(
|
496 |
+
latent_model_input, t
|
497 |
+
)
|
498 |
+
|
499 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
500 |
+
latent_model_input,
|
501 |
+
t,
|
502 |
+
encoder_hidden_states=prompt_embeds,
|
503 |
+
controlnet_cond=control_image,
|
504 |
+
return_dict=False,
|
505 |
+
)
|
506 |
+
|
507 |
+
down_block_res_samples = [
|
508 |
+
down_block_res_sample * controlnet_conditioning_scale
|
509 |
+
for down_block_res_sample in down_block_res_samples
|
510 |
+
]
|
511 |
+
mid_block_res_sample *= controlnet_conditioning_scale
|
512 |
+
|
513 |
+
# predict the noise residual
|
514 |
+
latent_model_input = torch.cat(
|
515 |
+
[latent_model_input, mask, masked_image_latents], dim=1
|
516 |
+
)
|
517 |
+
noise_pred = self.unet(
|
518 |
+
latent_model_input,
|
519 |
+
t,
|
520 |
+
encoder_hidden_states=prompt_embeds,
|
521 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
522 |
+
down_block_additional_residuals=down_block_res_samples,
|
523 |
+
mid_block_additional_residual=mid_block_res_sample,
|
524 |
+
).sample
|
525 |
+
|
526 |
+
# perform guidance
|
527 |
+
if do_classifier_free_guidance:
|
528 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
529 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
530 |
+
noise_pred_text - noise_pred_uncond
|
531 |
+
)
|
532 |
+
|
533 |
+
# compute the previous noisy sample x_t -> x_t-1
|
534 |
+
latents = self.scheduler.step(
|
535 |
+
noise_pred, t, latents, **extra_step_kwargs
|
536 |
+
).prev_sample
|
537 |
+
|
538 |
+
# call the callback, if provided
|
539 |
+
if i == len(timesteps) - 1 or (
|
540 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
541 |
+
):
|
542 |
+
progress_bar.update()
|
543 |
+
if callback is not None and i % callback_steps == 0:
|
544 |
+
callback(i, t, latents)
|
545 |
+
|
546 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
547 |
+
# manually for max memory savings
|
548 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
549 |
+
self.unet.to("cpu")
|
550 |
+
self.controlnet.to("cpu")
|
551 |
+
torch.cuda.empty_cache()
|
552 |
+
|
553 |
+
if output_type == "latent":
|
554 |
+
image = latents
|
555 |
+
has_nsfw_concept = None
|
556 |
+
elif output_type == "pil":
|
557 |
+
# 8. Post-processing
|
558 |
+
image = self.decode_latents(latents)
|
559 |
+
|
560 |
+
# 9. Run safety checker
|
561 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
562 |
+
image, device, prompt_embeds.dtype
|
563 |
+
)
|
564 |
+
|
565 |
+
# 10. Convert to PIL
|
566 |
+
image = self.numpy_to_pil(image)
|
567 |
+
else:
|
568 |
+
# 8. Post-processing
|
569 |
+
image = self.decode_latents(latents)
|
570 |
+
|
571 |
+
# 9. Run safety checker
|
572 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
573 |
+
image, device, prompt_embeds.dtype
|
574 |
+
)
|
575 |
+
|
576 |
+
# Offload last model to CPU
|
577 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
578 |
+
self.final_offload_hook.offload()
|
579 |
+
|
580 |
+
if not return_dict:
|
581 |
+
return (image, has_nsfw_concept)
|
582 |
+
|
583 |
+
return StableDiffusionPipelineOutput(
|
584 |
+
images=image, nsfw_content_detected=has_nsfw_concept
|
585 |
+
)
|
lama_cleaner/model/plms_sampler.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class PLMSSampler(object):
|
9 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
10 |
+
super().__init__()
|
11 |
+
self.model = model
|
12 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
13 |
+
self.schedule = schedule
|
14 |
+
|
15 |
+
def register_buffer(self, name, attr):
|
16 |
+
setattr(self, name, attr)
|
17 |
+
|
18 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
19 |
+
if ddim_eta != 0:
|
20 |
+
raise ValueError('ddim_eta must be 0 for PLMS')
|
21 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
22 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
23 |
+
alphas_cumprod = self.model.alphas_cumprod
|
24 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
25 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
26 |
+
|
27 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
28 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
29 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
30 |
+
|
31 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
32 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
33 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
34 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
35 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
36 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
37 |
+
|
38 |
+
# ddim sampling parameters
|
39 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
40 |
+
ddim_timesteps=self.ddim_timesteps,
|
41 |
+
eta=ddim_eta, verbose=verbose)
|
42 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
43 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
44 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
45 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
46 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
47 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
48 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
49 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
50 |
+
|
51 |
+
@torch.no_grad()
|
52 |
+
def sample(self,
|
53 |
+
steps,
|
54 |
+
batch_size,
|
55 |
+
shape,
|
56 |
+
conditioning=None,
|
57 |
+
callback=None,
|
58 |
+
normals_sequence=None,
|
59 |
+
img_callback=None,
|
60 |
+
quantize_x0=False,
|
61 |
+
eta=0.,
|
62 |
+
mask=None,
|
63 |
+
x0=None,
|
64 |
+
temperature=1.,
|
65 |
+
noise_dropout=0.,
|
66 |
+
score_corrector=None,
|
67 |
+
corrector_kwargs=None,
|
68 |
+
verbose=False,
|
69 |
+
x_T=None,
|
70 |
+
log_every_t=100,
|
71 |
+
unconditional_guidance_scale=1.,
|
72 |
+
unconditional_conditioning=None,
|
73 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
74 |
+
**kwargs
|
75 |
+
):
|
76 |
+
if conditioning is not None:
|
77 |
+
if isinstance(conditioning, dict):
|
78 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
79 |
+
if cbs != batch_size:
|
80 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
81 |
+
else:
|
82 |
+
if conditioning.shape[0] != batch_size:
|
83 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
84 |
+
|
85 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
86 |
+
# sampling
|
87 |
+
C, H, W = shape
|
88 |
+
size = (batch_size, C, H, W)
|
89 |
+
print(f'Data shape for PLMS sampling is {size}')
|
90 |
+
|
91 |
+
samples = self.plms_sampling(conditioning, size,
|
92 |
+
callback=callback,
|
93 |
+
img_callback=img_callback,
|
94 |
+
quantize_denoised=quantize_x0,
|
95 |
+
mask=mask, x0=x0,
|
96 |
+
ddim_use_original_steps=False,
|
97 |
+
noise_dropout=noise_dropout,
|
98 |
+
temperature=temperature,
|
99 |
+
score_corrector=score_corrector,
|
100 |
+
corrector_kwargs=corrector_kwargs,
|
101 |
+
x_T=x_T,
|
102 |
+
log_every_t=log_every_t,
|
103 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
104 |
+
unconditional_conditioning=unconditional_conditioning,
|
105 |
+
)
|
106 |
+
return samples
|
107 |
+
|
108 |
+
@torch.no_grad()
|
109 |
+
def plms_sampling(self, cond, shape,
|
110 |
+
x_T=None, ddim_use_original_steps=False,
|
111 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
112 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
113 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
114 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, ):
|
115 |
+
device = self.model.betas.device
|
116 |
+
b = shape[0]
|
117 |
+
if x_T is None:
|
118 |
+
img = torch.randn(shape, device=device)
|
119 |
+
else:
|
120 |
+
img = x_T
|
121 |
+
|
122 |
+
if timesteps is None:
|
123 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
124 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
125 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
126 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
127 |
+
|
128 |
+
time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
129 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
130 |
+
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
131 |
+
|
132 |
+
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
133 |
+
old_eps = []
|
134 |
+
|
135 |
+
for i, step in enumerate(iterator):
|
136 |
+
index = total_steps - i - 1
|
137 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
138 |
+
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
139 |
+
|
140 |
+
if mask is not None:
|
141 |
+
assert x0 is not None
|
142 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
143 |
+
img = img_orig * mask + (1. - mask) * img
|
144 |
+
|
145 |
+
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
146 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
147 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
148 |
+
corrector_kwargs=corrector_kwargs,
|
149 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
150 |
+
unconditional_conditioning=unconditional_conditioning,
|
151 |
+
old_eps=old_eps, t_next=ts_next)
|
152 |
+
img, pred_x0, e_t = outs
|
153 |
+
old_eps.append(e_t)
|
154 |
+
if len(old_eps) >= 4:
|
155 |
+
old_eps.pop(0)
|
156 |
+
if callback: callback(i)
|
157 |
+
if img_callback: img_callback(pred_x0, i)
|
158 |
+
|
159 |
+
return img
|
160 |
+
|
161 |
+
@torch.no_grad()
|
162 |
+
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
163 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
164 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
165 |
+
b, *_, device = *x.shape, x.device
|
166 |
+
|
167 |
+
def get_model_output(x, t):
|
168 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
169 |
+
e_t = self.model.apply_model(x, t, c)
|
170 |
+
else:
|
171 |
+
x_in = torch.cat([x] * 2)
|
172 |
+
t_in = torch.cat([t] * 2)
|
173 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
174 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
175 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
176 |
+
|
177 |
+
if score_corrector is not None:
|
178 |
+
assert self.model.parameterization == "eps"
|
179 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
180 |
+
|
181 |
+
return e_t
|
182 |
+
|
183 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
184 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
185 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
186 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
187 |
+
|
188 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
189 |
+
# select parameters corresponding to the currently considered timestep
|
190 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
191 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
192 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
193 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
194 |
+
|
195 |
+
# current prediction for x_0
|
196 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
197 |
+
if quantize_denoised:
|
198 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
199 |
+
# direction pointing to x_t
|
200 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
201 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
202 |
+
if noise_dropout > 0.:
|
203 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
204 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
205 |
+
return x_prev, pred_x0
|
206 |
+
|
207 |
+
e_t = get_model_output(x, t)
|
208 |
+
if len(old_eps) == 0:
|
209 |
+
# Pseudo Improved Euler (2nd order)
|
210 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
211 |
+
e_t_next = get_model_output(x_prev, t_next)
|
212 |
+
e_t_prime = (e_t + e_t_next) / 2
|
213 |
+
elif len(old_eps) == 1:
|
214 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
215 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
216 |
+
elif len(old_eps) == 2:
|
217 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
218 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
219 |
+
elif len(old_eps) >= 3:
|
220 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
221 |
+
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
222 |
+
|
223 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
224 |
+
|
225 |
+
return x_prev, pred_x0, e_t
|