Spaces:
Running
on
Zero
Running
on
Zero
MohamedRashad
commited on
Commit
•
6dd488f
1
Parent(s):
42373a7
Upload code
Browse files- .gitignore +168 -0
- LICENSE +201 -0
- diffusers_helper/cat_cond.py +24 -0
- diffusers_helper/code_cond.py +34 -0
- diffusers_helper/k_diffusion.py +145 -0
- diffusers_helper/utils.py +136 -0
- diffusers_vdm/attention.py +385 -0
- diffusers_vdm/basics.py +148 -0
- diffusers_vdm/dynamic_tsnr_sampler.py +177 -0
- diffusers_vdm/improved_clip_vision.py +58 -0
- diffusers_vdm/pipeline.py +188 -0
- diffusers_vdm/projection.py +160 -0
- diffusers_vdm/unet.py +650 -0
- diffusers_vdm/utils.py +43 -0
- diffusers_vdm/vae.py +826 -0
- gradio_app.py +324 -0
- imgs/1.jpg +0 -0
- imgs/2.jpg +0 -0
- imgs/3.jpg +0 -0
- memory_management.py +67 -0
- requirements.txt +19 -0
- wd14tagger.py +105 -0
.gitignore
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hf_token.txt
|
2 |
+
hf_download/
|
3 |
+
results/
|
4 |
+
*.csv
|
5 |
+
*.onnx
|
6 |
+
|
7 |
+
# Byte-compiled / optimized / DLL files
|
8 |
+
__pycache__/
|
9 |
+
*.py[cod]
|
10 |
+
*$py.class
|
11 |
+
|
12 |
+
# C extensions
|
13 |
+
*.so
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
share/python-wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
MANIFEST
|
34 |
+
|
35 |
+
# PyInstaller
|
36 |
+
# Usually these files are written by a python script from a template
|
37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
38 |
+
*.manifest
|
39 |
+
*.spec
|
40 |
+
|
41 |
+
# Installer logs
|
42 |
+
pip-log.txt
|
43 |
+
pip-delete-this-directory.txt
|
44 |
+
|
45 |
+
# Unit test / coverage reports
|
46 |
+
htmlcov/
|
47 |
+
.tox/
|
48 |
+
.nox/
|
49 |
+
.coverage
|
50 |
+
.coverage.*
|
51 |
+
.cache
|
52 |
+
nosetests.xml
|
53 |
+
coverage.xml
|
54 |
+
*.cover
|
55 |
+
*.py,cover
|
56 |
+
.hypothesis/
|
57 |
+
.pytest_cache/
|
58 |
+
cover/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
.pybuilder/
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
# For a library or package, you might want to ignore these files since the code is
|
93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
94 |
+
# .python-version
|
95 |
+
|
96 |
+
# pipenv
|
97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
100 |
+
# install all needed dependencies.
|
101 |
+
#Pipfile.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
.idea/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
diffusers_helper/cat_cond.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def unet_add_concat_conds(unet, new_channels=4):
|
5 |
+
with torch.no_grad():
|
6 |
+
new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
7 |
+
new_conv_in.weight.zero_()
|
8 |
+
new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
|
9 |
+
new_conv_in.bias = unet.conv_in.bias
|
10 |
+
unet.conv_in = new_conv_in
|
11 |
+
|
12 |
+
unet_original_forward = unet.forward
|
13 |
+
|
14 |
+
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
15 |
+
cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
|
16 |
+
c_concat = cross_attention_kwargs.pop('concat_conds')
|
17 |
+
kwargs['cross_attention_kwargs'] = cross_attention_kwargs
|
18 |
+
|
19 |
+
c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
|
20 |
+
new_sample = torch.cat([sample, c_concat], dim=1)
|
21 |
+
return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
|
22 |
+
|
23 |
+
unet.forward = hooked_unet_forward
|
24 |
+
return
|
diffusers_helper/code_cond.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
4 |
+
|
5 |
+
|
6 |
+
def unet_add_coded_conds(unet, added_number_count=1):
|
7 |
+
unet.add_time_proj = Timesteps(256, True, 0)
|
8 |
+
unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280)
|
9 |
+
|
10 |
+
def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
|
11 |
+
coded_conds = added_cond_kwargs.get("coded_conds")
|
12 |
+
batch_size = coded_conds.shape[0]
|
13 |
+
time_embeds = unet.add_time_proj(coded_conds.flatten())
|
14 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
15 |
+
time_embeds = time_embeds.to(emb)
|
16 |
+
aug_emb = unet.add_embedding(time_embeds)
|
17 |
+
return aug_emb
|
18 |
+
|
19 |
+
unet.get_aug_embed = get_aug_embed
|
20 |
+
|
21 |
+
unet_original_forward = unet.forward
|
22 |
+
|
23 |
+
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
24 |
+
cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
|
25 |
+
coded_conds = cross_attention_kwargs.pop('coded_conds')
|
26 |
+
kwargs['cross_attention_kwargs'] = cross_attention_kwargs
|
27 |
+
|
28 |
+
coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device)
|
29 |
+
kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds)
|
30 |
+
return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)
|
31 |
+
|
32 |
+
unet.forward = hooked_unet_forward
|
33 |
+
|
34 |
+
return
|
diffusers_helper/k_diffusion.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
@torch.no_grad()
|
8 |
+
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None):
|
9 |
+
"""DPM-Solver++(2M)."""
|
10 |
+
extra_args = {} if extra_args is None else extra_args
|
11 |
+
s_in = x.new_ones([x.shape[0]])
|
12 |
+
sigma_fn = lambda t: t.neg().exp()
|
13 |
+
t_fn = lambda sigma: sigma.log().neg()
|
14 |
+
old_denoised = None
|
15 |
+
|
16 |
+
bar = tqdm if progress_tqdm is None else progress_tqdm
|
17 |
+
|
18 |
+
for i in bar(range(len(sigmas) - 1)):
|
19 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
20 |
+
if callback is not None:
|
21 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
22 |
+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
23 |
+
h = t_next - t
|
24 |
+
if old_denoised is None or sigmas[i + 1] == 0:
|
25 |
+
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
26 |
+
else:
|
27 |
+
h_last = t - t_fn(sigmas[i - 1])
|
28 |
+
r = h_last / h
|
29 |
+
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
30 |
+
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
31 |
+
old_denoised = denoised
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class KModel:
|
36 |
+
def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False):
|
37 |
+
if linear:
|
38 |
+
betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64)
|
39 |
+
else:
|
40 |
+
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
|
41 |
+
|
42 |
+
alphas = 1. - betas
|
43 |
+
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
44 |
+
|
45 |
+
self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
46 |
+
self.log_sigmas = self.sigmas.log()
|
47 |
+
self.sigma_data = 1.0
|
48 |
+
self.unet = unet
|
49 |
+
return
|
50 |
+
|
51 |
+
@property
|
52 |
+
def sigma_min(self):
|
53 |
+
return self.sigmas[0]
|
54 |
+
|
55 |
+
@property
|
56 |
+
def sigma_max(self):
|
57 |
+
return self.sigmas[-1]
|
58 |
+
|
59 |
+
def timestep(self, sigma):
|
60 |
+
log_sigma = sigma.log()
|
61 |
+
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
62 |
+
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
|
63 |
+
|
64 |
+
def get_sigmas_karras(self, n, rho=7.):
|
65 |
+
ramp = torch.linspace(0, 1, n)
|
66 |
+
min_inv_rho = self.sigma_min ** (1 / rho)
|
67 |
+
max_inv_rho = self.sigma_max ** (1 / rho)
|
68 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
69 |
+
return torch.cat([sigmas, sigmas.new_zeros([1])])
|
70 |
+
|
71 |
+
def __call__(self, x, sigma, **extra_args):
|
72 |
+
x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
|
73 |
+
x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype)
|
74 |
+
t = self.timestep(sigma)
|
75 |
+
cfg_scale = extra_args['cfg_scale']
|
76 |
+
eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
|
77 |
+
eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
|
78 |
+
noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
|
79 |
+
return x - noise_pred * sigma[:, None, None, None]
|
80 |
+
|
81 |
+
|
82 |
+
class KDiffusionSampler:
|
83 |
+
def __init__(self, unet, **kwargs):
|
84 |
+
self.unet = unet
|
85 |
+
self.k_model = KModel(unet=unet, **kwargs)
|
86 |
+
|
87 |
+
@torch.inference_mode()
|
88 |
+
def __call__(
|
89 |
+
self,
|
90 |
+
initial_latent = None,
|
91 |
+
strength = 1.0,
|
92 |
+
num_inference_steps = 25,
|
93 |
+
guidance_scale = 5.0,
|
94 |
+
batch_size = 1,
|
95 |
+
generator = None,
|
96 |
+
prompt_embeds = None,
|
97 |
+
negative_prompt_embeds = None,
|
98 |
+
cross_attention_kwargs = None,
|
99 |
+
same_noise_in_batch = False,
|
100 |
+
progress_tqdm = None,
|
101 |
+
):
|
102 |
+
|
103 |
+
device = self.unet.device
|
104 |
+
|
105 |
+
# Sigmas
|
106 |
+
|
107 |
+
sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength))
|
108 |
+
sigmas = sigmas[-(num_inference_steps + 1):].to(device)
|
109 |
+
|
110 |
+
# Initial latents
|
111 |
+
|
112 |
+
if same_noise_in_batch:
|
113 |
+
noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1)
|
114 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
|
115 |
+
else:
|
116 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
|
117 |
+
noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)
|
118 |
+
|
119 |
+
latents = initial_latent + noise * sigmas[0].to(initial_latent)
|
120 |
+
|
121 |
+
# Batch
|
122 |
+
|
123 |
+
latents = latents.to(device)
|
124 |
+
prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
|
125 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
|
126 |
+
|
127 |
+
# Feeds
|
128 |
+
|
129 |
+
sampler_kwargs = dict(
|
130 |
+
cfg_scale=guidance_scale,
|
131 |
+
positive=dict(
|
132 |
+
encoder_hidden_states=prompt_embeds,
|
133 |
+
cross_attention_kwargs=cross_attention_kwargs
|
134 |
+
),
|
135 |
+
negative=dict(
|
136 |
+
encoder_hidden_states=negative_prompt_embeds,
|
137 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
138 |
+
)
|
139 |
+
)
|
140 |
+
|
141 |
+
# Sample
|
142 |
+
|
143 |
+
results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
|
144 |
+
|
145 |
+
return results
|
diffusers_helper/utils.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
import einops
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
import safetensors.torch as sf
|
10 |
+
|
11 |
+
|
12 |
+
def write_to_json(data, file_path):
|
13 |
+
temp_file_path = file_path + ".tmp"
|
14 |
+
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
|
15 |
+
json.dump(data, temp_file, indent=4)
|
16 |
+
os.replace(temp_file_path, file_path)
|
17 |
+
return
|
18 |
+
|
19 |
+
|
20 |
+
def read_from_json(file_path):
|
21 |
+
with open(file_path, 'rt', encoding='utf-8') as file:
|
22 |
+
data = json.load(file)
|
23 |
+
return data
|
24 |
+
|
25 |
+
|
26 |
+
def get_active_parameters(m):
|
27 |
+
return {k:v for k, v in m.named_parameters() if v.requires_grad}
|
28 |
+
|
29 |
+
|
30 |
+
def cast_training_params(m, dtype=torch.float32):
|
31 |
+
for param in m.parameters():
|
32 |
+
if param.requires_grad:
|
33 |
+
param.data = param.to(dtype)
|
34 |
+
return
|
35 |
+
|
36 |
+
|
37 |
+
def set_attr_recursive(obj, attr, value):
|
38 |
+
attrs = attr.split(".")
|
39 |
+
for name in attrs[:-1]:
|
40 |
+
obj = getattr(obj, name)
|
41 |
+
setattr(obj, attrs[-1], value)
|
42 |
+
return
|
43 |
+
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
|
47 |
+
assert a.shape == b.shape, "Tensors must have the same shape"
|
48 |
+
batch_size = a.size(0)
|
49 |
+
|
50 |
+
if mask_a is None:
|
51 |
+
mask_a = torch.rand(batch_size) < probability_a
|
52 |
+
|
53 |
+
mask_a = mask_a.to(a.device)
|
54 |
+
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
|
55 |
+
result = torch.where(mask_a, a, b)
|
56 |
+
return result
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def zero_module(module):
|
61 |
+
for p in module.parameters():
|
62 |
+
p.detach().zero_()
|
63 |
+
return module
|
64 |
+
|
65 |
+
|
66 |
+
def load_last_state(model, folder='accelerator_output'):
|
67 |
+
file_pattern = os.path.join(folder, '**', 'model.safetensors')
|
68 |
+
files = glob.glob(file_pattern, recursive=True)
|
69 |
+
|
70 |
+
if not files:
|
71 |
+
print("No model.safetensors files found in the specified folder.")
|
72 |
+
return
|
73 |
+
|
74 |
+
newest_file = max(files, key=os.path.getmtime)
|
75 |
+
state_dict = sf.load_file(newest_file)
|
76 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
77 |
+
|
78 |
+
if missing_keys:
|
79 |
+
print("Missing keys:", missing_keys)
|
80 |
+
if unexpected_keys:
|
81 |
+
print("Unexpected keys:", unexpected_keys)
|
82 |
+
|
83 |
+
print("Loaded model state from:", newest_file)
|
84 |
+
return
|
85 |
+
|
86 |
+
|
87 |
+
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
|
88 |
+
tags = tags_str.split(', ')
|
89 |
+
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
|
90 |
+
prompt = ', '.join(tags)
|
91 |
+
return prompt
|
92 |
+
|
93 |
+
|
94 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10):
|
95 |
+
b, c, t, h, w = x.shape
|
96 |
+
|
97 |
+
per_row = b
|
98 |
+
for p in [6, 5, 4, 3, 2]:
|
99 |
+
if b % p == 0:
|
100 |
+
per_row = p
|
101 |
+
break
|
102 |
+
|
103 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
104 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
105 |
+
x = x.detach().cpu().to(torch.uint8)
|
106 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
107 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
def save_bcthw_as_png(x, output_filename):
|
112 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
113 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
114 |
+
x = x.detach().cpu().to(torch.uint8)
|
115 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
116 |
+
torchvision.io.write_png(x, output_filename)
|
117 |
+
return output_filename
|
118 |
+
|
119 |
+
|
120 |
+
def add_tensors_with_padding(tensor1, tensor2):
|
121 |
+
if tensor1.shape == tensor2.shape:
|
122 |
+
return tensor1 + tensor2
|
123 |
+
|
124 |
+
shape1 = tensor1.shape
|
125 |
+
shape2 = tensor2.shape
|
126 |
+
|
127 |
+
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
|
128 |
+
|
129 |
+
padded_tensor1 = torch.zeros(new_shape)
|
130 |
+
padded_tensor2 = torch.zeros(new_shape)
|
131 |
+
|
132 |
+
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
|
133 |
+
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
|
134 |
+
|
135 |
+
result = padded_tensor1 + padded_tensor2
|
136 |
+
return result
|
diffusers_vdm/attention.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import xformers.ops
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from functools import partial
|
8 |
+
from diffusers_vdm.basics import zero_module, checkpoint, default, make_temporal_window
|
9 |
+
|
10 |
+
|
11 |
+
def sdp(q, k, v, heads):
|
12 |
+
b, _, C = q.shape
|
13 |
+
dim_head = C // heads
|
14 |
+
|
15 |
+
q, k, v = map(
|
16 |
+
lambda t: t.unsqueeze(3)
|
17 |
+
.reshape(b, t.shape[1], heads, dim_head)
|
18 |
+
.permute(0, 2, 1, 3)
|
19 |
+
.reshape(b * heads, t.shape[1], dim_head)
|
20 |
+
.contiguous(),
|
21 |
+
(q, k, v),
|
22 |
+
)
|
23 |
+
|
24 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
25 |
+
|
26 |
+
out = (
|
27 |
+
out.unsqueeze(0)
|
28 |
+
.reshape(b, heads, out.shape[1], dim_head)
|
29 |
+
.permute(0, 2, 1, 3)
|
30 |
+
.reshape(b, out.shape[1], heads * dim_head)
|
31 |
+
)
|
32 |
+
|
33 |
+
return out
|
34 |
+
|
35 |
+
|
36 |
+
class RelativePosition(nn.Module):
|
37 |
+
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
|
38 |
+
|
39 |
+
def __init__(self, num_units, max_relative_position):
|
40 |
+
super().__init__()
|
41 |
+
self.num_units = num_units
|
42 |
+
self.max_relative_position = max_relative_position
|
43 |
+
self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
|
44 |
+
nn.init.xavier_uniform_(self.embeddings_table)
|
45 |
+
|
46 |
+
def forward(self, length_q, length_k):
|
47 |
+
device = self.embeddings_table.device
|
48 |
+
range_vec_q = torch.arange(length_q, device=device)
|
49 |
+
range_vec_k = torch.arange(length_k, device=device)
|
50 |
+
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
|
51 |
+
distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
|
52 |
+
final_mat = distance_mat_clipped + self.max_relative_position
|
53 |
+
final_mat = final_mat.long()
|
54 |
+
embeddings = self.embeddings_table[final_mat]
|
55 |
+
return embeddings
|
56 |
+
|
57 |
+
|
58 |
+
class CrossAttention(nn.Module):
|
59 |
+
|
60 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
|
61 |
+
relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False,
|
62 |
+
image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False,
|
63 |
+
text_context_len=77, temporal_window_for_spatial_self_attention=False):
|
64 |
+
super().__init__()
|
65 |
+
inner_dim = dim_head * heads
|
66 |
+
context_dim = default(context_dim, query_dim)
|
67 |
+
|
68 |
+
self.scale = dim_head**-0.5
|
69 |
+
self.heads = heads
|
70 |
+
self.dim_head = dim_head
|
71 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
72 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
73 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
74 |
+
|
75 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
76 |
+
|
77 |
+
self.is_temporal_attention = temporal_length is not None
|
78 |
+
|
79 |
+
self.relative_position = relative_position
|
80 |
+
if self.relative_position:
|
81 |
+
assert self.is_temporal_attention
|
82 |
+
self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
|
83 |
+
self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
|
84 |
+
|
85 |
+
self.video_length = video_length
|
86 |
+
self.temporal_window_for_spatial_self_attention = temporal_window_for_spatial_self_attention
|
87 |
+
self.temporal_window_type = 'prv'
|
88 |
+
|
89 |
+
self.image_cross_attention = image_cross_attention
|
90 |
+
self.image_cross_attention_scale = image_cross_attention_scale
|
91 |
+
self.text_context_len = text_context_len
|
92 |
+
self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
|
93 |
+
if self.image_cross_attention:
|
94 |
+
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
95 |
+
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
96 |
+
if image_cross_attention_scale_learnable:
|
97 |
+
self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) )
|
98 |
+
|
99 |
+
def forward(self, x, context=None, mask=None):
|
100 |
+
if self.is_temporal_attention:
|
101 |
+
return self.temporal_forward(x, context=context, mask=mask)
|
102 |
+
else:
|
103 |
+
return self.spatial_forward(x, context=context, mask=mask)
|
104 |
+
|
105 |
+
def temporal_forward(self, x, context=None, mask=None):
|
106 |
+
assert mask is None, 'Attention mask not implemented!'
|
107 |
+
assert context is None, 'Temporal attention only supports self attention!'
|
108 |
+
|
109 |
+
q = self.to_q(x)
|
110 |
+
k = self.to_k(x)
|
111 |
+
v = self.to_v(x)
|
112 |
+
|
113 |
+
out = sdp(q, k, v, self.heads)
|
114 |
+
|
115 |
+
return self.to_out(out)
|
116 |
+
|
117 |
+
def spatial_forward(self, x, context=None, mask=None):
|
118 |
+
assert mask is None, 'Attention mask not implemented!'
|
119 |
+
|
120 |
+
spatial_self_attn = (context is None)
|
121 |
+
k_ip, v_ip, out_ip = None, None, None
|
122 |
+
|
123 |
+
q = self.to_q(x)
|
124 |
+
context = default(context, x)
|
125 |
+
|
126 |
+
if spatial_self_attn:
|
127 |
+
k = self.to_k(context)
|
128 |
+
v = self.to_v(context)
|
129 |
+
|
130 |
+
if self.temporal_window_for_spatial_self_attention:
|
131 |
+
k = make_temporal_window(k, t=self.video_length, method=self.temporal_window_type)
|
132 |
+
v = make_temporal_window(v, t=self.video_length, method=self.temporal_window_type)
|
133 |
+
elif self.image_cross_attention:
|
134 |
+
context, context_image = context
|
135 |
+
k = self.to_k(context)
|
136 |
+
v = self.to_v(context)
|
137 |
+
k_ip = self.to_k_ip(context_image)
|
138 |
+
v_ip = self.to_v_ip(context_image)
|
139 |
+
else:
|
140 |
+
raise NotImplementedError('Traditional prompt-only attention without IP-Adapter is illegal now.')
|
141 |
+
|
142 |
+
out = sdp(q, k, v, self.heads)
|
143 |
+
|
144 |
+
if k_ip is not None:
|
145 |
+
out_ip = sdp(q, k_ip, v_ip, self.heads)
|
146 |
+
|
147 |
+
if self.image_cross_attention_scale_learnable:
|
148 |
+
out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha) + 1)
|
149 |
+
else:
|
150 |
+
out = out + self.image_cross_attention_scale * out_ip
|
151 |
+
|
152 |
+
return self.to_out(out)
|
153 |
+
|
154 |
+
|
155 |
+
class BasicTransformerBlock(nn.Module):
|
156 |
+
|
157 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
158 |
+
disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77):
|
159 |
+
super().__init__()
|
160 |
+
attn_cls = CrossAttention if attention_cls is None else attention_cls
|
161 |
+
self.disable_self_attn = disable_self_attn
|
162 |
+
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
163 |
+
context_dim=context_dim if self.disable_self_attn else None, video_length=video_length)
|
164 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
165 |
+
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len)
|
166 |
+
self.image_cross_attention = image_cross_attention
|
167 |
+
|
168 |
+
self.norm1 = nn.LayerNorm(dim)
|
169 |
+
self.norm2 = nn.LayerNorm(dim)
|
170 |
+
self.norm3 = nn.LayerNorm(dim)
|
171 |
+
self.checkpoint = checkpoint
|
172 |
+
|
173 |
+
|
174 |
+
def forward(self, x, context=None, mask=None, **kwargs):
|
175 |
+
## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
|
176 |
+
input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
|
177 |
+
if context is not None:
|
178 |
+
input_tuple = (x, context)
|
179 |
+
if mask is not None:
|
180 |
+
forward_mask = partial(self._forward, mask=mask)
|
181 |
+
return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
|
182 |
+
return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint)
|
183 |
+
|
184 |
+
|
185 |
+
def _forward(self, x, context=None, mask=None):
|
186 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
|
187 |
+
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
|
188 |
+
x = self.ff(self.norm3(x)) + x
|
189 |
+
return x
|
190 |
+
|
191 |
+
|
192 |
+
class SpatialTransformer(nn.Module):
|
193 |
+
"""
|
194 |
+
Transformer block for image-like data in spatial axis.
|
195 |
+
First, project the input (aka embedding)
|
196 |
+
and reshape to b, t, d.
|
197 |
+
Then apply standard transformer action.
|
198 |
+
Finally, reshape to image
|
199 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
200 |
+
"""
|
201 |
+
|
202 |
+
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
|
203 |
+
use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None,
|
204 |
+
image_cross_attention=False, image_cross_attention_scale_learnable=False):
|
205 |
+
super().__init__()
|
206 |
+
self.in_channels = in_channels
|
207 |
+
inner_dim = n_heads * d_head
|
208 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
209 |
+
if not use_linear:
|
210 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
211 |
+
else:
|
212 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
213 |
+
|
214 |
+
attention_cls = None
|
215 |
+
self.transformer_blocks = nn.ModuleList([
|
216 |
+
BasicTransformerBlock(
|
217 |
+
inner_dim,
|
218 |
+
n_heads,
|
219 |
+
d_head,
|
220 |
+
dropout=dropout,
|
221 |
+
context_dim=context_dim,
|
222 |
+
disable_self_attn=disable_self_attn,
|
223 |
+
checkpoint=use_checkpoint,
|
224 |
+
attention_cls=attention_cls,
|
225 |
+
video_length=video_length,
|
226 |
+
image_cross_attention=image_cross_attention,
|
227 |
+
image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
|
228 |
+
) for d in range(depth)
|
229 |
+
])
|
230 |
+
if not use_linear:
|
231 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
232 |
+
else:
|
233 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
234 |
+
self.use_linear = use_linear
|
235 |
+
|
236 |
+
|
237 |
+
def forward(self, x, context=None, **kwargs):
|
238 |
+
b, c, h, w = x.shape
|
239 |
+
x_in = x
|
240 |
+
x = self.norm(x)
|
241 |
+
if not self.use_linear:
|
242 |
+
x = self.proj_in(x)
|
243 |
+
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
244 |
+
if self.use_linear:
|
245 |
+
x = self.proj_in(x)
|
246 |
+
for i, block in enumerate(self.transformer_blocks):
|
247 |
+
x = block(x, context=context, **kwargs)
|
248 |
+
if self.use_linear:
|
249 |
+
x = self.proj_out(x)
|
250 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
251 |
+
if not self.use_linear:
|
252 |
+
x = self.proj_out(x)
|
253 |
+
return x + x_in
|
254 |
+
|
255 |
+
|
256 |
+
class TemporalTransformer(nn.Module):
|
257 |
+
"""
|
258 |
+
Transformer block for image-like data in temporal axis.
|
259 |
+
First, reshape to b, t, d.
|
260 |
+
Then apply standard transformer action.
|
261 |
+
Finally, reshape to image
|
262 |
+
"""
|
263 |
+
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
|
264 |
+
use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1,
|
265 |
+
relative_position=False, temporal_length=None):
|
266 |
+
super().__init__()
|
267 |
+
self.only_self_att = only_self_att
|
268 |
+
self.relative_position = relative_position
|
269 |
+
self.causal_attention = causal_attention
|
270 |
+
self.causal_block_size = causal_block_size
|
271 |
+
|
272 |
+
self.in_channels = in_channels
|
273 |
+
inner_dim = n_heads * d_head
|
274 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
275 |
+
self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
276 |
+
if not use_linear:
|
277 |
+
self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
278 |
+
else:
|
279 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
280 |
+
|
281 |
+
if relative_position:
|
282 |
+
assert(temporal_length is not None)
|
283 |
+
attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
|
284 |
+
else:
|
285 |
+
attention_cls = partial(CrossAttention, temporal_length=temporal_length)
|
286 |
+
if self.causal_attention:
|
287 |
+
assert(temporal_length is not None)
|
288 |
+
self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
|
289 |
+
|
290 |
+
if self.only_self_att:
|
291 |
+
context_dim = None
|
292 |
+
self.transformer_blocks = nn.ModuleList([
|
293 |
+
BasicTransformerBlock(
|
294 |
+
inner_dim,
|
295 |
+
n_heads,
|
296 |
+
d_head,
|
297 |
+
dropout=dropout,
|
298 |
+
context_dim=context_dim,
|
299 |
+
attention_cls=attention_cls,
|
300 |
+
checkpoint=use_checkpoint) for d in range(depth)
|
301 |
+
])
|
302 |
+
if not use_linear:
|
303 |
+
self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
304 |
+
else:
|
305 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
306 |
+
self.use_linear = use_linear
|
307 |
+
|
308 |
+
def forward(self, x, context=None):
|
309 |
+
b, c, t, h, w = x.shape
|
310 |
+
x_in = x
|
311 |
+
x = self.norm(x)
|
312 |
+
x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
|
313 |
+
if not self.use_linear:
|
314 |
+
x = self.proj_in(x)
|
315 |
+
x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
|
316 |
+
if self.use_linear:
|
317 |
+
x = self.proj_in(x)
|
318 |
+
|
319 |
+
temp_mask = None
|
320 |
+
if self.causal_attention:
|
321 |
+
# slice the from mask map
|
322 |
+
temp_mask = self.mask[:,:t,:t].to(x.device)
|
323 |
+
|
324 |
+
if temp_mask is not None:
|
325 |
+
mask = temp_mask.to(x.device)
|
326 |
+
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
|
327 |
+
else:
|
328 |
+
mask = None
|
329 |
+
|
330 |
+
if self.only_self_att:
|
331 |
+
## note: if no context is given, cross-attention defaults to self-attention
|
332 |
+
for i, block in enumerate(self.transformer_blocks):
|
333 |
+
x = block(x, mask=mask)
|
334 |
+
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
335 |
+
else:
|
336 |
+
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
|
337 |
+
context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
|
338 |
+
for i, block in enumerate(self.transformer_blocks):
|
339 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
340 |
+
for j in range(b):
|
341 |
+
context_j = repeat(
|
342 |
+
context[j],
|
343 |
+
't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
|
344 |
+
## note: causal mask will not applied in cross-attention case
|
345 |
+
x[j] = block(x[j], context=context_j)
|
346 |
+
|
347 |
+
if self.use_linear:
|
348 |
+
x = self.proj_out(x)
|
349 |
+
x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
|
350 |
+
if not self.use_linear:
|
351 |
+
x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
|
352 |
+
x = self.proj_out(x)
|
353 |
+
x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
|
354 |
+
|
355 |
+
return x + x_in
|
356 |
+
|
357 |
+
|
358 |
+
class GEGLU(nn.Module):
|
359 |
+
def __init__(self, dim_in, dim_out):
|
360 |
+
super().__init__()
|
361 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
362 |
+
|
363 |
+
def forward(self, x):
|
364 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
365 |
+
return x * F.gelu(gate)
|
366 |
+
|
367 |
+
|
368 |
+
class FeedForward(nn.Module):
|
369 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
370 |
+
super().__init__()
|
371 |
+
inner_dim = int(dim * mult)
|
372 |
+
dim_out = default(dim_out, dim)
|
373 |
+
project_in = nn.Sequential(
|
374 |
+
nn.Linear(dim, inner_dim),
|
375 |
+
nn.GELU()
|
376 |
+
) if not glu else GEGLU(dim, inner_dim)
|
377 |
+
|
378 |
+
self.net = nn.Sequential(
|
379 |
+
project_in,
|
380 |
+
nn.Dropout(dropout),
|
381 |
+
nn.Linear(inner_dim, dim_out)
|
382 |
+
)
|
383 |
+
|
384 |
+
def forward(self, x):
|
385 |
+
return self.net(x)
|
diffusers_vdm/basics.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import einops
|
14 |
+
|
15 |
+
from inspect import isfunction
|
16 |
+
|
17 |
+
|
18 |
+
def zero_module(module):
|
19 |
+
"""
|
20 |
+
Zero out the parameters of a module and return it.
|
21 |
+
"""
|
22 |
+
for p in module.parameters():
|
23 |
+
p.detach().zero_()
|
24 |
+
return module
|
25 |
+
|
26 |
+
def scale_module(module, scale):
|
27 |
+
"""
|
28 |
+
Scale the parameters of a module and return it.
|
29 |
+
"""
|
30 |
+
for p in module.parameters():
|
31 |
+
p.detach().mul_(scale)
|
32 |
+
return module
|
33 |
+
|
34 |
+
|
35 |
+
def conv_nd(dims, *args, **kwargs):
|
36 |
+
"""
|
37 |
+
Create a 1D, 2D, or 3D convolution module.
|
38 |
+
"""
|
39 |
+
if dims == 1:
|
40 |
+
return nn.Conv1d(*args, **kwargs)
|
41 |
+
elif dims == 2:
|
42 |
+
return nn.Conv2d(*args, **kwargs)
|
43 |
+
elif dims == 3:
|
44 |
+
return nn.Conv3d(*args, **kwargs)
|
45 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
46 |
+
|
47 |
+
|
48 |
+
def linear(*args, **kwargs):
|
49 |
+
"""
|
50 |
+
Create a linear module.
|
51 |
+
"""
|
52 |
+
return nn.Linear(*args, **kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
56 |
+
"""
|
57 |
+
Create a 1D, 2D, or 3D average pooling module.
|
58 |
+
"""
|
59 |
+
if dims == 1:
|
60 |
+
return nn.AvgPool1d(*args, **kwargs)
|
61 |
+
elif dims == 2:
|
62 |
+
return nn.AvgPool2d(*args, **kwargs)
|
63 |
+
elif dims == 3:
|
64 |
+
return nn.AvgPool3d(*args, **kwargs)
|
65 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
66 |
+
|
67 |
+
|
68 |
+
def nonlinearity(type='silu'):
|
69 |
+
if type == 'silu':
|
70 |
+
return nn.SiLU()
|
71 |
+
elif type == 'leaky_relu':
|
72 |
+
return nn.LeakyReLU()
|
73 |
+
|
74 |
+
|
75 |
+
def normalization(channels, num_groups=32):
|
76 |
+
"""
|
77 |
+
Make a standard normalization layer.
|
78 |
+
:param channels: number of input channels.
|
79 |
+
:return: an nn.Module for normalization.
|
80 |
+
"""
|
81 |
+
return nn.GroupNorm(num_groups, channels)
|
82 |
+
|
83 |
+
|
84 |
+
def default(val, d):
|
85 |
+
if exists(val):
|
86 |
+
return val
|
87 |
+
return d() if isfunction(d) else d
|
88 |
+
|
89 |
+
|
90 |
+
def exists(val):
|
91 |
+
return val is not None
|
92 |
+
|
93 |
+
|
94 |
+
def extract_into_tensor(a, t, x_shape):
|
95 |
+
b, *_ = t.shape
|
96 |
+
out = a.gather(-1, t)
|
97 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
98 |
+
|
99 |
+
|
100 |
+
def make_temporal_window(x, t, method):
|
101 |
+
assert method in ['roll', 'prv', 'first']
|
102 |
+
|
103 |
+
if method == 'roll':
|
104 |
+
m = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
105 |
+
l = torch.roll(m, shifts=1, dims=1)
|
106 |
+
r = torch.roll(m, shifts=-1, dims=1)
|
107 |
+
|
108 |
+
recon = torch.cat([l, m, r], dim=2)
|
109 |
+
del l, m, r
|
110 |
+
|
111 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
112 |
+
return recon
|
113 |
+
|
114 |
+
if method == 'prv':
|
115 |
+
x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
116 |
+
prv = torch.cat([x[:, :1], x[:, :-1]], dim=1)
|
117 |
+
|
118 |
+
recon = torch.cat([x, prv], dim=2)
|
119 |
+
del x, prv
|
120 |
+
|
121 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
122 |
+
return recon
|
123 |
+
|
124 |
+
if method == 'first':
|
125 |
+
x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
|
126 |
+
prv = x[:, [0], :, :].repeat(1, t, 1, 1)
|
127 |
+
|
128 |
+
recon = torch.cat([x, prv], dim=2)
|
129 |
+
del x, prv
|
130 |
+
|
131 |
+
recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
|
132 |
+
return recon
|
133 |
+
|
134 |
+
|
135 |
+
def checkpoint(func, inputs, params, flag):
|
136 |
+
"""
|
137 |
+
Evaluate a function without caching intermediate activations, allowing for
|
138 |
+
reduced memory at the expense of extra compute in the backward pass.
|
139 |
+
:param func: the function to evaluate.
|
140 |
+
:param inputs: the argument sequence to pass to `func`.
|
141 |
+
:param params: a sequence of parameters `func` depends on but does not
|
142 |
+
explicitly take as arguments.
|
143 |
+
:param flag: if False, disable gradient checkpointing.
|
144 |
+
"""
|
145 |
+
if flag:
|
146 |
+
return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
|
147 |
+
else:
|
148 |
+
return func(*inputs)
|
diffusers_vdm/dynamic_tsnr_sampler.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# everything that can improve v-prediction model
|
2 |
+
# dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
|
3 |
+
# written by lvmin at stanford 2024
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
from functools import partial
|
10 |
+
from diffusers_vdm.basics import extract_into_tensor
|
11 |
+
|
12 |
+
|
13 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
14 |
+
|
15 |
+
|
16 |
+
def rescale_zero_terminal_snr(betas):
|
17 |
+
# Convert betas to alphas_bar_sqrt
|
18 |
+
alphas = 1.0 - betas
|
19 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
20 |
+
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
|
21 |
+
|
22 |
+
# Store old values.
|
23 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
|
24 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
|
25 |
+
|
26 |
+
# Shift so the last timestep is zero.
|
27 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
28 |
+
|
29 |
+
# Scale so the first timestep is back to the old value.
|
30 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
31 |
+
|
32 |
+
# Convert alphas_bar_sqrt to betas
|
33 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
34 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
35 |
+
alphas = np.concatenate([alphas_bar[0:1], alphas])
|
36 |
+
betas = 1 - alphas
|
37 |
+
|
38 |
+
return betas
|
39 |
+
|
40 |
+
|
41 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
42 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
43 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
44 |
+
|
45 |
+
# rescale the results from guidance (fixes overexposure)
|
46 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
47 |
+
|
48 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
49 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
50 |
+
|
51 |
+
return noise_cfg
|
52 |
+
|
53 |
+
|
54 |
+
class SamplerDynamicTSNR(torch.nn.Module):
|
55 |
+
@torch.no_grad()
|
56 |
+
def __init__(self, unet, terminal_scale=0.7):
|
57 |
+
super().__init__()
|
58 |
+
self.unet = unet
|
59 |
+
|
60 |
+
self.is_v = True
|
61 |
+
self.n_timestep = 1000
|
62 |
+
self.guidance_rescale = 0.7
|
63 |
+
|
64 |
+
linear_start = 0.00085
|
65 |
+
linear_end = 0.012
|
66 |
+
|
67 |
+
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
|
68 |
+
betas = rescale_zero_terminal_snr(betas)
|
69 |
+
alphas = 1. - betas
|
70 |
+
|
71 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
72 |
+
|
73 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
|
74 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
|
75 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))
|
76 |
+
|
77 |
+
# Dynamic TSNR
|
78 |
+
turning_step = 400
|
79 |
+
scale_arr = np.concatenate([
|
80 |
+
np.linspace(1.0, terminal_scale, turning_step),
|
81 |
+
np.full(self.n_timestep - turning_step, terminal_scale)
|
82 |
+
])
|
83 |
+
self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))
|
84 |
+
|
85 |
+
def predict_eps_from_z_and_v(self, x_t, t, v):
|
86 |
+
return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t
|
87 |
+
|
88 |
+
def predict_start_from_z_and_v(self, x_t, t, v):
|
89 |
+
return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v
|
90 |
+
|
91 |
+
def q_sample(self, x0, t, noise):
|
92 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
93 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
94 |
+
|
95 |
+
def get_v(self, x0, t, noise):
|
96 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
|
97 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)
|
98 |
+
|
99 |
+
def dynamic_x0_rescale(self, x0, t):
|
100 |
+
return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
def get_ground_truth(self, x0, noise, t):
|
104 |
+
x0 = self.dynamic_x0_rescale(x0, t)
|
105 |
+
xt = self.q_sample(x0, t, noise)
|
106 |
+
target = self.get_v(x0, t, noise) if self.is_v else noise
|
107 |
+
return xt, target
|
108 |
+
|
109 |
+
def get_uniform_trailing_steps(self, steps):
|
110 |
+
c = self.n_timestep / steps
|
111 |
+
ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
|
112 |
+
steps_out = ddim_timesteps - 1
|
113 |
+
return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)
|
114 |
+
|
115 |
+
@torch.no_grad()
|
116 |
+
def forward(self, latent_shape, steps, extra_args, progress_tqdm=None):
|
117 |
+
bar = tqdm if progress_tqdm is None else progress_tqdm
|
118 |
+
|
119 |
+
eta = 1.0
|
120 |
+
|
121 |
+
timesteps = self.get_uniform_trailing_steps(steps)
|
122 |
+
timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
|
123 |
+
|
124 |
+
x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)
|
125 |
+
|
126 |
+
alphas = self.alphas_cumprod[timesteps]
|
127 |
+
alphas_prev = self.alphas_cumprod[timesteps_prev]
|
128 |
+
scale_arr = self.scale_arr[timesteps]
|
129 |
+
scale_arr_prev = self.scale_arr[timesteps_prev]
|
130 |
+
|
131 |
+
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
132 |
+
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
133 |
+
|
134 |
+
s_in = x.new_ones((x.shape[0]))
|
135 |
+
s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
|
136 |
+
for i in bar(range(len(timesteps))):
|
137 |
+
index = len(timesteps) - 1 - i
|
138 |
+
t = timesteps[index].item()
|
139 |
+
|
140 |
+
model_output = self.model_apply(x, t * s_in, **extra_args)
|
141 |
+
|
142 |
+
if self.is_v:
|
143 |
+
e_t = self.predict_eps_from_z_and_v(x, t, model_output)
|
144 |
+
else:
|
145 |
+
e_t = model_output
|
146 |
+
|
147 |
+
a_prev = alphas_prev[index].item() * s_x
|
148 |
+
sigma_t = sigmas[index].item() * s_x
|
149 |
+
|
150 |
+
if self.is_v:
|
151 |
+
pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
|
152 |
+
else:
|
153 |
+
a_t = alphas[index].item() * s_x
|
154 |
+
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
155 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
156 |
+
|
157 |
+
# dynamic rescale
|
158 |
+
scale_t = scale_arr[index].item() * s_x
|
159 |
+
prev_scale_t = scale_arr_prev[index].item() * s_x
|
160 |
+
rescale = (prev_scale_t / scale_t)
|
161 |
+
pred_x0 = pred_x0 * rescale
|
162 |
+
|
163 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
164 |
+
noise = sigma_t * torch.randn_like(x)
|
165 |
+
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
166 |
+
|
167 |
+
return x
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def model_apply(self, x, t, **extra_args):
|
171 |
+
x = x.to(device=self.unet.device, dtype=self.unet.dtype)
|
172 |
+
cfg_scale = extra_args['cfg_scale']
|
173 |
+
p = self.unet(x, t, **extra_args['positive'])
|
174 |
+
n = self.unet(x, t, **extra_args['negative'])
|
175 |
+
o = n + cfg_scale * (p - n)
|
176 |
+
o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
|
177 |
+
return o_better
|
diffusers_vdm/improved_clip_vision.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel
|
2 |
+
# The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import types
|
6 |
+
import einops
|
7 |
+
|
8 |
+
from abc import ABCMeta
|
9 |
+
from transformers import CLIPVisionModelWithProjection
|
10 |
+
|
11 |
+
|
12 |
+
def preprocess(image):
|
13 |
+
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
|
14 |
+
std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
|
15 |
+
|
16 |
+
scale = 16 / min(image.shape[2], image.shape[3])
|
17 |
+
image = torch.nn.functional.interpolate(
|
18 |
+
image,
|
19 |
+
size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
|
20 |
+
mode="bicubic",
|
21 |
+
antialias=True
|
22 |
+
)
|
23 |
+
|
24 |
+
return (image - mean) / std
|
25 |
+
|
26 |
+
|
27 |
+
def arbitrary_positional_encoding(p, H, W):
|
28 |
+
weight = p.weight
|
29 |
+
cls = weight[:1]
|
30 |
+
pos = weight[1:]
|
31 |
+
pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
|
32 |
+
pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
|
33 |
+
pos = einops.rearrange(pos, '1 C H W -> (H W) C')
|
34 |
+
weight = torch.cat([cls, pos])[None]
|
35 |
+
return weight
|
36 |
+
|
37 |
+
|
38 |
+
def improved_clipvision_embedding_forward(self, pixel_values):
|
39 |
+
pixel_values = pixel_values * 0.5 + 0.5
|
40 |
+
pixel_values = preprocess(pixel_values)
|
41 |
+
batch_size = pixel_values.shape[0]
|
42 |
+
target_dtype = self.patch_embedding.weight.dtype
|
43 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
44 |
+
B, C, H, W = patch_embeds.shape
|
45 |
+
patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
|
46 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
47 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
48 |
+
embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
|
49 |
+
return embeddings
|
50 |
+
|
51 |
+
|
52 |
+
class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
|
53 |
+
def __init__(self, config):
|
54 |
+
super().__init__(config)
|
55 |
+
self.vision_model.embeddings.forward = types.MethodType(
|
56 |
+
improved_clipvision_embedding_forward,
|
57 |
+
self.vision_model.embeddings
|
58 |
+
)
|
diffusers_vdm/pipeline.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import einops
|
4 |
+
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
7 |
+
from huggingface_hub import snapshot_download
|
8 |
+
from diffusers_vdm.vae import VideoAutoencoderKL
|
9 |
+
from diffusers_vdm.projection import Resampler
|
10 |
+
from diffusers_vdm.unet import UNet3DModel
|
11 |
+
from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
|
12 |
+
from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR
|
13 |
+
|
14 |
+
|
15 |
+
class LatentVideoDiffusionPipeline(DiffusionPipeline):
|
16 |
+
def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.loading_components = dict(
|
20 |
+
vae=vae,
|
21 |
+
text_encoder=text_encoder,
|
22 |
+
tokenizer=tokenizer,
|
23 |
+
unet=unet,
|
24 |
+
image_encoder=image_encoder,
|
25 |
+
image_projection=image_projection
|
26 |
+
)
|
27 |
+
|
28 |
+
for k, v in self.loading_components.items():
|
29 |
+
setattr(self, k, v)
|
30 |
+
|
31 |
+
if fp16:
|
32 |
+
self.vae.half()
|
33 |
+
self.text_encoder.half()
|
34 |
+
self.unet.half()
|
35 |
+
self.image_encoder.half()
|
36 |
+
self.image_projection.half()
|
37 |
+
|
38 |
+
self.vae.requires_grad_(False)
|
39 |
+
self.text_encoder.requires_grad_(False)
|
40 |
+
self.image_encoder.requires_grad_(False)
|
41 |
+
|
42 |
+
self.vae.eval()
|
43 |
+
self.text_encoder.eval()
|
44 |
+
self.image_encoder.eval()
|
45 |
+
|
46 |
+
if eval:
|
47 |
+
self.unet.eval()
|
48 |
+
self.image_projection.eval()
|
49 |
+
else:
|
50 |
+
self.unet.train()
|
51 |
+
self.image_projection.train()
|
52 |
+
|
53 |
+
def to(self, *args, **kwargs):
|
54 |
+
for k, v in self.loading_components.items():
|
55 |
+
if hasattr(v, 'to'):
|
56 |
+
v.to(*args, **kwargs)
|
57 |
+
return self
|
58 |
+
|
59 |
+
def save_pretrained(self, save_directory, **kwargs):
|
60 |
+
for k, v in self.loading_components.items():
|
61 |
+
folder = os.path.join(save_directory, k)
|
62 |
+
os.makedirs(folder, exist_ok=True)
|
63 |
+
v.save_pretrained(folder)
|
64 |
+
return
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
|
68 |
+
local_folder = snapshot_download(repo_id=repo_id, token=token)
|
69 |
+
return cls(
|
70 |
+
tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
|
71 |
+
text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
|
72 |
+
image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
|
73 |
+
vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
|
74 |
+
image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
|
75 |
+
unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
|
76 |
+
fp16=fp16,
|
77 |
+
eval=eval
|
78 |
+
)
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
def encode_cropped_prompt_77tokens(self, prompt: str):
|
82 |
+
cond_ids = self.tokenizer(prompt,
|
83 |
+
padding="max_length",
|
84 |
+
max_length=self.tokenizer.model_max_length,
|
85 |
+
truncation=True,
|
86 |
+
return_tensors="pt").input_ids.to(self.text_encoder.device)
|
87 |
+
cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
|
88 |
+
return cond
|
89 |
+
|
90 |
+
@torch.inference_mode()
|
91 |
+
def encode_clip_vision(self, frames):
|
92 |
+
b, c, t, h, w = frames.shape
|
93 |
+
frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
|
94 |
+
clipvision_embed = self.image_encoder(frames).last_hidden_state
|
95 |
+
clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
|
96 |
+
return clipvision_embed
|
97 |
+
|
98 |
+
@torch.inference_mode()
|
99 |
+
def encode_latents(self, videos, return_hidden_states=True):
|
100 |
+
b, c, t, h, w = videos.shape
|
101 |
+
x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
|
102 |
+
encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
|
103 |
+
z = encoder_posterior.mode() * self.vae.scale_factor
|
104 |
+
z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
105 |
+
|
106 |
+
if not return_hidden_states:
|
107 |
+
return z
|
108 |
+
|
109 |
+
hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
|
110 |
+
hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states] # only need first and last
|
111 |
+
|
112 |
+
return z, hidden_states
|
113 |
+
|
114 |
+
@torch.inference_mode()
|
115 |
+
def decode_latents(self, latents, hidden_states):
|
116 |
+
B, C, T, H, W = latents.shape
|
117 |
+
latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
|
118 |
+
latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
|
119 |
+
pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
|
120 |
+
pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
|
121 |
+
return pixels
|
122 |
+
|
123 |
+
@torch.inference_mode()
|
124 |
+
def __call__(
|
125 |
+
self,
|
126 |
+
batch_size: int = 1,
|
127 |
+
steps: int = 50,
|
128 |
+
guidance_scale: float = 5.0,
|
129 |
+
positive_text_cond = None,
|
130 |
+
negative_text_cond = None,
|
131 |
+
positive_image_cond = None,
|
132 |
+
negative_image_cond = None,
|
133 |
+
concat_cond = None,
|
134 |
+
fs = 3,
|
135 |
+
progress_tqdm = None,
|
136 |
+
):
|
137 |
+
unet_is_training = self.unet.training
|
138 |
+
|
139 |
+
if unet_is_training:
|
140 |
+
self.unet.eval()
|
141 |
+
|
142 |
+
device = self.unet.device
|
143 |
+
dtype = self.unet.dtype
|
144 |
+
dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)
|
145 |
+
|
146 |
+
# Batch
|
147 |
+
|
148 |
+
concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype) # b, c, t, h, w
|
149 |
+
positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
|
150 |
+
negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
|
151 |
+
positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) # b, t, l, c
|
152 |
+
negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)
|
153 |
+
|
154 |
+
if isinstance(fs, torch.Tensor):
|
155 |
+
fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device) # b
|
156 |
+
else:
|
157 |
+
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # b
|
158 |
+
|
159 |
+
# Initial latents
|
160 |
+
|
161 |
+
latent_shape = concat_cond.shape
|
162 |
+
|
163 |
+
# Feeds
|
164 |
+
|
165 |
+
sampler_kwargs = dict(
|
166 |
+
cfg_scale=guidance_scale,
|
167 |
+
positive=dict(
|
168 |
+
context_text=positive_text_cond,
|
169 |
+
context_img=positive_image_cond,
|
170 |
+
fs=fs,
|
171 |
+
concat_cond=concat_cond
|
172 |
+
),
|
173 |
+
negative=dict(
|
174 |
+
context_text=negative_text_cond,
|
175 |
+
context_img=negative_image_cond,
|
176 |
+
fs=fs,
|
177 |
+
concat_cond=concat_cond
|
178 |
+
)
|
179 |
+
)
|
180 |
+
|
181 |
+
# Sample
|
182 |
+
|
183 |
+
results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
|
184 |
+
|
185 |
+
if unet_is_training:
|
186 |
+
self.unet.train()
|
187 |
+
|
188 |
+
return results
|
diffusers_vdm/projection.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
4 |
+
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import einops
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from huggingface_hub import PyTorchModelHubMixin
|
12 |
+
|
13 |
+
|
14 |
+
class ImageProjModel(nn.Module):
|
15 |
+
"""Projection Model"""
|
16 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
17 |
+
super().__init__()
|
18 |
+
self.cross_attention_dim = cross_attention_dim
|
19 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
20 |
+
self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
21 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
22 |
+
|
23 |
+
def forward(self, image_embeds):
|
24 |
+
#embeds = image_embeds
|
25 |
+
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
|
26 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
27 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
28 |
+
return clip_extra_context_tokens
|
29 |
+
|
30 |
+
|
31 |
+
# FFN
|
32 |
+
def FeedForward(dim, mult=4):
|
33 |
+
inner_dim = int(dim * mult)
|
34 |
+
return nn.Sequential(
|
35 |
+
nn.LayerNorm(dim),
|
36 |
+
nn.Linear(dim, inner_dim, bias=False),
|
37 |
+
nn.GELU(),
|
38 |
+
nn.Linear(inner_dim, dim, bias=False),
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def reshape_tensor(x, heads):
|
43 |
+
bs, length, width = x.shape
|
44 |
+
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
45 |
+
x = x.view(bs, length, heads, -1)
|
46 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
47 |
+
x = x.transpose(1, 2)
|
48 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
49 |
+
x = x.reshape(bs, heads, length, -1)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class PerceiverAttention(nn.Module):
|
54 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
55 |
+
super().__init__()
|
56 |
+
self.scale = dim_head**-0.5
|
57 |
+
self.dim_head = dim_head
|
58 |
+
self.heads = heads
|
59 |
+
inner_dim = dim_head * heads
|
60 |
+
|
61 |
+
self.norm1 = nn.LayerNorm(dim)
|
62 |
+
self.norm2 = nn.LayerNorm(dim)
|
63 |
+
|
64 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
65 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
66 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
67 |
+
|
68 |
+
|
69 |
+
def forward(self, x, latents):
|
70 |
+
"""
|
71 |
+
Args:
|
72 |
+
x (torch.Tensor): image features
|
73 |
+
shape (b, n1, D)
|
74 |
+
latent (torch.Tensor): latent features
|
75 |
+
shape (b, n2, D)
|
76 |
+
"""
|
77 |
+
x = self.norm1(x)
|
78 |
+
latents = self.norm2(latents)
|
79 |
+
|
80 |
+
b, l, _ = latents.shape
|
81 |
+
|
82 |
+
q = self.to_q(latents)
|
83 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
84 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
85 |
+
|
86 |
+
q = reshape_tensor(q, self.heads)
|
87 |
+
k = reshape_tensor(k, self.heads)
|
88 |
+
v = reshape_tensor(v, self.heads)
|
89 |
+
|
90 |
+
# attention
|
91 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
92 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
93 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
94 |
+
out = weight @ v
|
95 |
+
|
96 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
97 |
+
|
98 |
+
return self.to_out(out)
|
99 |
+
|
100 |
+
|
101 |
+
class Resampler(nn.Module, PyTorchModelHubMixin):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
dim=1024,
|
105 |
+
depth=8,
|
106 |
+
dim_head=64,
|
107 |
+
heads=16,
|
108 |
+
num_queries=8,
|
109 |
+
embedding_dim=768,
|
110 |
+
output_dim=1024,
|
111 |
+
ff_mult=4,
|
112 |
+
video_length=16,
|
113 |
+
input_frames_length=2,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.num_queries = num_queries
|
117 |
+
self.video_length = video_length
|
118 |
+
|
119 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries * video_length, dim) / dim**0.5)
|
120 |
+
self.input_pos = nn.Parameter(torch.zeros(1, input_frames_length, 1, embedding_dim))
|
121 |
+
|
122 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
123 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
124 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
125 |
+
|
126 |
+
self.layers = nn.ModuleList([])
|
127 |
+
for _ in range(depth):
|
128 |
+
self.layers.append(
|
129 |
+
nn.ModuleList(
|
130 |
+
[
|
131 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
132 |
+
FeedForward(dim=dim, mult=ff_mult),
|
133 |
+
]
|
134 |
+
)
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
139 |
+
|
140 |
+
x = x + self.input_pos
|
141 |
+
x = einops.rearrange(x, 'b ti d c -> b (ti d) c')
|
142 |
+
x = self.proj_in(x)
|
143 |
+
|
144 |
+
for attn, ff in self.layers:
|
145 |
+
latents = attn(x, latents) + latents
|
146 |
+
latents = ff(latents) + latents
|
147 |
+
|
148 |
+
latents = self.proj_out(latents)
|
149 |
+
latents = self.norm_out(latents)
|
150 |
+
|
151 |
+
latents = einops.rearrange(latents, 'b (to l) c -> b to l c', to=self.video_length)
|
152 |
+
return latents
|
153 |
+
|
154 |
+
@property
|
155 |
+
def device(self):
|
156 |
+
return next(self.parameters()).device
|
157 |
+
|
158 |
+
@property
|
159 |
+
def dtype(self):
|
160 |
+
return next(self.parameters()).dtype
|
diffusers_vdm/unet.py
ADDED
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/AILab-CVC/VideoCrafter
|
2 |
+
# https://github.com/Doubiiu/DynamiCrafter
|
3 |
+
# https://github.com/ToonCrafter/ToonCrafter
|
4 |
+
# Then edited by lllyasviel
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from abc import abstractmethod
|
8 |
+
import torch
|
9 |
+
import math
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from diffusers_vdm.basics import checkpoint
|
14 |
+
from diffusers_vdm.basics import (
|
15 |
+
zero_module,
|
16 |
+
conv_nd,
|
17 |
+
linear,
|
18 |
+
avg_pool_nd,
|
19 |
+
normalization
|
20 |
+
)
|
21 |
+
from diffusers_vdm.attention import SpatialTransformer, TemporalTransformer
|
22 |
+
from huggingface_hub import PyTorchModelHubMixin
|
23 |
+
|
24 |
+
|
25 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
26 |
+
"""
|
27 |
+
Create sinusoidal timestep embeddings.
|
28 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
29 |
+
These may be fractional.
|
30 |
+
:param dim: the dimension of the output.
|
31 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
32 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
33 |
+
"""
|
34 |
+
if not repeat_only:
|
35 |
+
half = dim // 2
|
36 |
+
freqs = torch.exp(
|
37 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
38 |
+
).to(device=timesteps.device)
|
39 |
+
args = timesteps[:, None].float() * freqs[None]
|
40 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
41 |
+
if dim % 2:
|
42 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
43 |
+
else:
|
44 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
45 |
+
return embedding
|
46 |
+
|
47 |
+
|
48 |
+
class TimestepBlock(nn.Module):
|
49 |
+
"""
|
50 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
51 |
+
"""
|
52 |
+
|
53 |
+
@abstractmethod
|
54 |
+
def forward(self, x, emb):
|
55 |
+
"""
|
56 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
57 |
+
"""
|
58 |
+
|
59 |
+
|
60 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
61 |
+
"""
|
62 |
+
A sequential module that passes timestep embeddings to the children that
|
63 |
+
support it as an extra input.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def forward(self, x, emb, context=None, batch_size=None):
|
67 |
+
for layer in self:
|
68 |
+
if isinstance(layer, TimestepBlock):
|
69 |
+
x = layer(x, emb, batch_size=batch_size)
|
70 |
+
elif isinstance(layer, SpatialTransformer):
|
71 |
+
x = layer(x, context)
|
72 |
+
elif isinstance(layer, TemporalTransformer):
|
73 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
|
74 |
+
x = layer(x, context)
|
75 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
76 |
+
else:
|
77 |
+
x = layer(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class Downsample(nn.Module):
|
82 |
+
"""
|
83 |
+
A downsampling layer with an optional convolution.
|
84 |
+
:param channels: channels in the inputs and outputs.
|
85 |
+
:param use_conv: a bool determining if a convolution is applied.
|
86 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
87 |
+
downsampling occurs in the inner-two dimensions.
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
91 |
+
super().__init__()
|
92 |
+
self.channels = channels
|
93 |
+
self.out_channels = out_channels or channels
|
94 |
+
self.use_conv = use_conv
|
95 |
+
self.dims = dims
|
96 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
97 |
+
if use_conv:
|
98 |
+
self.op = conv_nd(
|
99 |
+
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
assert self.channels == self.out_channels
|
103 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
assert x.shape[1] == self.channels
|
107 |
+
return self.op(x)
|
108 |
+
|
109 |
+
|
110 |
+
class Upsample(nn.Module):
|
111 |
+
"""
|
112 |
+
An upsampling layer with an optional convolution.
|
113 |
+
:param channels: channels in the inputs and outputs.
|
114 |
+
:param use_conv: a bool determining if a convolution is applied.
|
115 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
116 |
+
upsampling occurs in the inner-two dimensions.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
120 |
+
super().__init__()
|
121 |
+
self.channels = channels
|
122 |
+
self.out_channels = out_channels or channels
|
123 |
+
self.use_conv = use_conv
|
124 |
+
self.dims = dims
|
125 |
+
if use_conv:
|
126 |
+
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
assert x.shape[1] == self.channels
|
130 |
+
if self.dims == 3:
|
131 |
+
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
|
132 |
+
else:
|
133 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
134 |
+
if self.use_conv:
|
135 |
+
x = self.conv(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class ResBlock(TimestepBlock):
|
140 |
+
"""
|
141 |
+
A residual block that can optionally change the number of channels.
|
142 |
+
:param channels: the number of input channels.
|
143 |
+
:param emb_channels: the number of timestep embedding channels.
|
144 |
+
:param dropout: the rate of dropout.
|
145 |
+
:param out_channels: if specified, the number of out channels.
|
146 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
147 |
+
convolution instead of a smaller 1x1 convolution to change the
|
148 |
+
channels in the skip connection.
|
149 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
150 |
+
:param up: if True, use this block for upsampling.
|
151 |
+
:param down: if True, use this block for downsampling.
|
152 |
+
:param use_temporal_conv: if True, use the temporal convolution.
|
153 |
+
:param use_image_dataset: if True, the temporal parameters will not be optimized.
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
channels,
|
159 |
+
emb_channels,
|
160 |
+
dropout,
|
161 |
+
out_channels=None,
|
162 |
+
use_scale_shift_norm=False,
|
163 |
+
dims=2,
|
164 |
+
use_checkpoint=False,
|
165 |
+
use_conv=False,
|
166 |
+
up=False,
|
167 |
+
down=False,
|
168 |
+
use_temporal_conv=False,
|
169 |
+
tempspatial_aware=False
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
self.channels = channels
|
173 |
+
self.emb_channels = emb_channels
|
174 |
+
self.dropout = dropout
|
175 |
+
self.out_channels = out_channels or channels
|
176 |
+
self.use_conv = use_conv
|
177 |
+
self.use_checkpoint = use_checkpoint
|
178 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
179 |
+
self.use_temporal_conv = use_temporal_conv
|
180 |
+
|
181 |
+
self.in_layers = nn.Sequential(
|
182 |
+
normalization(channels),
|
183 |
+
nn.SiLU(),
|
184 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
185 |
+
)
|
186 |
+
|
187 |
+
self.updown = up or down
|
188 |
+
|
189 |
+
if up:
|
190 |
+
self.h_upd = Upsample(channels, False, dims)
|
191 |
+
self.x_upd = Upsample(channels, False, dims)
|
192 |
+
elif down:
|
193 |
+
self.h_upd = Downsample(channels, False, dims)
|
194 |
+
self.x_upd = Downsample(channels, False, dims)
|
195 |
+
else:
|
196 |
+
self.h_upd = self.x_upd = nn.Identity()
|
197 |
+
|
198 |
+
self.emb_layers = nn.Sequential(
|
199 |
+
nn.SiLU(),
|
200 |
+
nn.Linear(
|
201 |
+
emb_channels,
|
202 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
203 |
+
),
|
204 |
+
)
|
205 |
+
self.out_layers = nn.Sequential(
|
206 |
+
normalization(self.out_channels),
|
207 |
+
nn.SiLU(),
|
208 |
+
nn.Dropout(p=dropout),
|
209 |
+
zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
|
210 |
+
)
|
211 |
+
|
212 |
+
if self.out_channels == channels:
|
213 |
+
self.skip_connection = nn.Identity()
|
214 |
+
elif use_conv:
|
215 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
|
216 |
+
else:
|
217 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
218 |
+
|
219 |
+
if self.use_temporal_conv:
|
220 |
+
self.temopral_conv = TemporalConvBlock(
|
221 |
+
self.out_channels,
|
222 |
+
self.out_channels,
|
223 |
+
dropout=0.1,
|
224 |
+
spatial_aware=tempspatial_aware
|
225 |
+
)
|
226 |
+
|
227 |
+
def forward(self, x, emb, batch_size=None):
|
228 |
+
"""
|
229 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
230 |
+
:param x: an [N x C x ...] Tensor of features.
|
231 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
232 |
+
:return: an [N x C x ...] Tensor of outputs.
|
233 |
+
"""
|
234 |
+
input_tuple = (x, emb)
|
235 |
+
if batch_size:
|
236 |
+
forward_batchsize = partial(self._forward, batch_size=batch_size)
|
237 |
+
return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
|
238 |
+
return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
|
239 |
+
|
240 |
+
def _forward(self, x, emb, batch_size=None):
|
241 |
+
if self.updown:
|
242 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
243 |
+
h = in_rest(x)
|
244 |
+
h = self.h_upd(h)
|
245 |
+
x = self.x_upd(x)
|
246 |
+
h = in_conv(h)
|
247 |
+
else:
|
248 |
+
h = self.in_layers(x)
|
249 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
250 |
+
while len(emb_out.shape) < len(h.shape):
|
251 |
+
emb_out = emb_out[..., None]
|
252 |
+
if self.use_scale_shift_norm:
|
253 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
254 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
255 |
+
h = out_norm(h) * (1 + scale) + shift
|
256 |
+
h = out_rest(h)
|
257 |
+
else:
|
258 |
+
h = h + emb_out
|
259 |
+
h = self.out_layers(h)
|
260 |
+
h = self.skip_connection(x) + h
|
261 |
+
|
262 |
+
if self.use_temporal_conv and batch_size:
|
263 |
+
h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
|
264 |
+
h = self.temopral_conv(h)
|
265 |
+
h = rearrange(h, 'b c t h w -> (b t) c h w')
|
266 |
+
return h
|
267 |
+
|
268 |
+
|
269 |
+
class TemporalConvBlock(nn.Module):
|
270 |
+
"""
|
271 |
+
Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
|
275 |
+
super(TemporalConvBlock, self).__init__()
|
276 |
+
if out_channels is None:
|
277 |
+
out_channels = in_channels
|
278 |
+
self.in_channels = in_channels
|
279 |
+
self.out_channels = out_channels
|
280 |
+
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
|
281 |
+
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
|
282 |
+
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
|
283 |
+
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
|
284 |
+
|
285 |
+
# conv layers
|
286 |
+
self.conv1 = nn.Sequential(
|
287 |
+
nn.GroupNorm(32, in_channels), nn.SiLU(),
|
288 |
+
nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
|
289 |
+
self.conv2 = nn.Sequential(
|
290 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
291 |
+
nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
|
292 |
+
self.conv3 = nn.Sequential(
|
293 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
294 |
+
nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
|
295 |
+
self.conv4 = nn.Sequential(
|
296 |
+
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
|
297 |
+
nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
|
298 |
+
|
299 |
+
# zero out the last layer params,so the conv block is identity
|
300 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
301 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
identity = x
|
305 |
+
x = self.conv1(x)
|
306 |
+
x = self.conv2(x)
|
307 |
+
x = self.conv3(x)
|
308 |
+
x = self.conv4(x)
|
309 |
+
|
310 |
+
return identity + x
|
311 |
+
|
312 |
+
|
313 |
+
class UNet3DModel(nn.Module, PyTorchModelHubMixin):
|
314 |
+
"""
|
315 |
+
The full UNet model with attention and timestep embedding.
|
316 |
+
:param in_channels: in_channels in the input Tensor.
|
317 |
+
:param model_channels: base channel count for the model.
|
318 |
+
:param out_channels: channels in the output Tensor.
|
319 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
320 |
+
:param attention_resolutions: a collection of downsample rates at which
|
321 |
+
attention will take place. May be a set, list, or tuple.
|
322 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
323 |
+
will be used.
|
324 |
+
:param dropout: the dropout probability.
|
325 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
326 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
327 |
+
downsampling.
|
328 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
329 |
+
:param num_classes: if specified (as an int), then this model will be
|
330 |
+
class-conditional with `num_classes` classes.
|
331 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
332 |
+
:param num_heads: the number of attention heads in each attention layer.
|
333 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
334 |
+
a fixed channel width per attention head.
|
335 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
336 |
+
of heads for upsampling. Deprecated.
|
337 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
338 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
339 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
340 |
+
increased efficiency.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(self,
|
344 |
+
in_channels,
|
345 |
+
model_channels,
|
346 |
+
out_channels,
|
347 |
+
num_res_blocks,
|
348 |
+
attention_resolutions,
|
349 |
+
dropout=0.0,
|
350 |
+
channel_mult=(1, 2, 4, 8),
|
351 |
+
conv_resample=True,
|
352 |
+
dims=2,
|
353 |
+
context_dim=None,
|
354 |
+
use_scale_shift_norm=False,
|
355 |
+
resblock_updown=False,
|
356 |
+
num_heads=-1,
|
357 |
+
num_head_channels=-1,
|
358 |
+
transformer_depth=1,
|
359 |
+
use_linear=False,
|
360 |
+
temporal_conv=False,
|
361 |
+
tempspatial_aware=False,
|
362 |
+
temporal_attention=True,
|
363 |
+
use_relative_position=True,
|
364 |
+
use_causal_attention=False,
|
365 |
+
temporal_length=None,
|
366 |
+
addition_attention=False,
|
367 |
+
temporal_selfatt_only=True,
|
368 |
+
image_cross_attention=False,
|
369 |
+
image_cross_attention_scale_learnable=False,
|
370 |
+
default_fs=4,
|
371 |
+
fs_condition=False,
|
372 |
+
):
|
373 |
+
super(UNet3DModel, self).__init__()
|
374 |
+
if num_heads == -1:
|
375 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
376 |
+
if num_head_channels == -1:
|
377 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
378 |
+
|
379 |
+
self.in_channels = in_channels
|
380 |
+
self.model_channels = model_channels
|
381 |
+
self.out_channels = out_channels
|
382 |
+
self.num_res_blocks = num_res_blocks
|
383 |
+
self.attention_resolutions = attention_resolutions
|
384 |
+
self.dropout = dropout
|
385 |
+
self.channel_mult = channel_mult
|
386 |
+
self.conv_resample = conv_resample
|
387 |
+
self.temporal_attention = temporal_attention
|
388 |
+
time_embed_dim = model_channels * 4
|
389 |
+
self.use_checkpoint = use_checkpoint = False # moved to self.enable_gradient_checkpointing()
|
390 |
+
temporal_self_att_only = True
|
391 |
+
self.addition_attention = addition_attention
|
392 |
+
self.temporal_length = temporal_length
|
393 |
+
self.image_cross_attention = image_cross_attention
|
394 |
+
self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
|
395 |
+
self.default_fs = default_fs
|
396 |
+
self.fs_condition = fs_condition
|
397 |
+
|
398 |
+
## Time embedding blocks
|
399 |
+
self.time_embed = nn.Sequential(
|
400 |
+
linear(model_channels, time_embed_dim),
|
401 |
+
nn.SiLU(),
|
402 |
+
linear(time_embed_dim, time_embed_dim),
|
403 |
+
)
|
404 |
+
if fs_condition:
|
405 |
+
self.fps_embedding = nn.Sequential(
|
406 |
+
linear(model_channels, time_embed_dim),
|
407 |
+
nn.SiLU(),
|
408 |
+
linear(time_embed_dim, time_embed_dim),
|
409 |
+
)
|
410 |
+
nn.init.zeros_(self.fps_embedding[-1].weight)
|
411 |
+
nn.init.zeros_(self.fps_embedding[-1].bias)
|
412 |
+
## Input Block
|
413 |
+
self.input_blocks = nn.ModuleList(
|
414 |
+
[
|
415 |
+
TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
416 |
+
]
|
417 |
+
)
|
418 |
+
if self.addition_attention:
|
419 |
+
self.init_attn = TimestepEmbedSequential(
|
420 |
+
TemporalTransformer(
|
421 |
+
model_channels,
|
422 |
+
n_heads=8,
|
423 |
+
d_head=num_head_channels,
|
424 |
+
depth=transformer_depth,
|
425 |
+
context_dim=context_dim,
|
426 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
|
427 |
+
causal_attention=False, relative_position=use_relative_position,
|
428 |
+
temporal_length=temporal_length))
|
429 |
+
|
430 |
+
input_block_chans = [model_channels]
|
431 |
+
ch = model_channels
|
432 |
+
ds = 1
|
433 |
+
for level, mult in enumerate(channel_mult):
|
434 |
+
for _ in range(num_res_blocks):
|
435 |
+
layers = [
|
436 |
+
ResBlock(ch, time_embed_dim, dropout,
|
437 |
+
out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
|
438 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
439 |
+
use_temporal_conv=temporal_conv
|
440 |
+
)
|
441 |
+
]
|
442 |
+
ch = mult * model_channels
|
443 |
+
if ds in attention_resolutions:
|
444 |
+
if num_head_channels == -1:
|
445 |
+
dim_head = ch // num_heads
|
446 |
+
else:
|
447 |
+
num_heads = ch // num_head_channels
|
448 |
+
dim_head = num_head_channels
|
449 |
+
layers.append(
|
450 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
451 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
452 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False,
|
453 |
+
video_length=temporal_length,
|
454 |
+
image_cross_attention=self.image_cross_attention,
|
455 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
|
456 |
+
)
|
457 |
+
)
|
458 |
+
if self.temporal_attention:
|
459 |
+
layers.append(
|
460 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
461 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
462 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
463 |
+
causal_attention=use_causal_attention,
|
464 |
+
relative_position=use_relative_position,
|
465 |
+
temporal_length=temporal_length
|
466 |
+
)
|
467 |
+
)
|
468 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
469 |
+
input_block_chans.append(ch)
|
470 |
+
if level != len(channel_mult) - 1:
|
471 |
+
out_ch = ch
|
472 |
+
self.input_blocks.append(
|
473 |
+
TimestepEmbedSequential(
|
474 |
+
ResBlock(ch, time_embed_dim, dropout,
|
475 |
+
out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
|
476 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
477 |
+
down=True
|
478 |
+
)
|
479 |
+
if resblock_updown
|
480 |
+
else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
481 |
+
)
|
482 |
+
)
|
483 |
+
ch = out_ch
|
484 |
+
input_block_chans.append(ch)
|
485 |
+
ds *= 2
|
486 |
+
|
487 |
+
if num_head_channels == -1:
|
488 |
+
dim_head = ch // num_heads
|
489 |
+
else:
|
490 |
+
num_heads = ch // num_head_channels
|
491 |
+
dim_head = num_head_channels
|
492 |
+
layers = [
|
493 |
+
ResBlock(ch, time_embed_dim, dropout,
|
494 |
+
dims=dims, use_checkpoint=use_checkpoint,
|
495 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
496 |
+
use_temporal_conv=temporal_conv
|
497 |
+
),
|
498 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
499 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
500 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
|
501 |
+
image_cross_attention=self.image_cross_attention,
|
502 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
|
503 |
+
)
|
504 |
+
]
|
505 |
+
if self.temporal_attention:
|
506 |
+
layers.append(
|
507 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
508 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
509 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
510 |
+
causal_attention=use_causal_attention, relative_position=use_relative_position,
|
511 |
+
temporal_length=temporal_length
|
512 |
+
)
|
513 |
+
)
|
514 |
+
layers.append(
|
515 |
+
ResBlock(ch, time_embed_dim, dropout,
|
516 |
+
dims=dims, use_checkpoint=use_checkpoint,
|
517 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
518 |
+
use_temporal_conv=temporal_conv
|
519 |
+
)
|
520 |
+
)
|
521 |
+
|
522 |
+
## Middle Block
|
523 |
+
self.middle_block = TimestepEmbedSequential(*layers)
|
524 |
+
|
525 |
+
## Output Block
|
526 |
+
self.output_blocks = nn.ModuleList([])
|
527 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
528 |
+
for i in range(num_res_blocks + 1):
|
529 |
+
ich = input_block_chans.pop()
|
530 |
+
layers = [
|
531 |
+
ResBlock(ch + ich, time_embed_dim, dropout,
|
532 |
+
out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
|
533 |
+
use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
|
534 |
+
use_temporal_conv=temporal_conv
|
535 |
+
)
|
536 |
+
]
|
537 |
+
ch = model_channels * mult
|
538 |
+
if ds in attention_resolutions:
|
539 |
+
if num_head_channels == -1:
|
540 |
+
dim_head = ch // num_heads
|
541 |
+
else:
|
542 |
+
num_heads = ch // num_head_channels
|
543 |
+
dim_head = num_head_channels
|
544 |
+
layers.append(
|
545 |
+
SpatialTransformer(ch, num_heads, dim_head,
|
546 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
547 |
+
use_checkpoint=use_checkpoint, disable_self_attn=False,
|
548 |
+
video_length=temporal_length,
|
549 |
+
image_cross_attention=self.image_cross_attention,
|
550 |
+
image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
|
551 |
+
)
|
552 |
+
)
|
553 |
+
if self.temporal_attention:
|
554 |
+
layers.append(
|
555 |
+
TemporalTransformer(ch, num_heads, dim_head,
|
556 |
+
depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
|
557 |
+
use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
|
558 |
+
causal_attention=use_causal_attention,
|
559 |
+
relative_position=use_relative_position,
|
560 |
+
temporal_length=temporal_length
|
561 |
+
)
|
562 |
+
)
|
563 |
+
if level and i == num_res_blocks:
|
564 |
+
out_ch = ch
|
565 |
+
layers.append(
|
566 |
+
ResBlock(ch, time_embed_dim, dropout,
|
567 |
+
out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
|
568 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
569 |
+
up=True
|
570 |
+
)
|
571 |
+
if resblock_updown
|
572 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
573 |
+
)
|
574 |
+
ds //= 2
|
575 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
576 |
+
|
577 |
+
self.out = nn.Sequential(
|
578 |
+
normalization(ch),
|
579 |
+
nn.SiLU(),
|
580 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
581 |
+
)
|
582 |
+
|
583 |
+
@property
|
584 |
+
def device(self):
|
585 |
+
return next(self.parameters()).device
|
586 |
+
|
587 |
+
@property
|
588 |
+
def dtype(self):
|
589 |
+
return next(self.parameters()).dtype
|
590 |
+
|
591 |
+
def forward(self, x, timesteps, context_text=None, context_img=None, concat_cond=None, fs=None, **kwargs):
|
592 |
+
b, _, t, _, _ = x.shape
|
593 |
+
|
594 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
|
595 |
+
emb = self.time_embed(t_emb)
|
596 |
+
|
597 |
+
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
598 |
+
context_img = rearrange(context_img, 'b t l c -> (b t) l c')
|
599 |
+
|
600 |
+
context = (context_text, context_img)
|
601 |
+
|
602 |
+
emb = emb.repeat_interleave(repeats=t, dim=0)
|
603 |
+
|
604 |
+
if concat_cond is not None:
|
605 |
+
x = torch.cat([x, concat_cond], dim=1)
|
606 |
+
|
607 |
+
## always in shape (b t) c h w, except for temporal layer
|
608 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
609 |
+
|
610 |
+
## combine emb
|
611 |
+
if self.fs_condition:
|
612 |
+
if fs is None:
|
613 |
+
fs = torch.tensor(
|
614 |
+
[self.default_fs] * b, dtype=torch.long, device=x.device)
|
615 |
+
fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
|
616 |
+
|
617 |
+
fs_embed = self.fps_embedding(fs_emb)
|
618 |
+
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
619 |
+
emb = emb + fs_embed
|
620 |
+
|
621 |
+
h = x
|
622 |
+
hs = []
|
623 |
+
for id, module in enumerate(self.input_blocks):
|
624 |
+
h = module(h, emb, context=context, batch_size=b)
|
625 |
+
if id == 0 and self.addition_attention:
|
626 |
+
h = self.init_attn(h, emb, context=context, batch_size=b)
|
627 |
+
hs.append(h)
|
628 |
+
|
629 |
+
h = self.middle_block(h, emb, context=context, batch_size=b)
|
630 |
+
|
631 |
+
for module in self.output_blocks:
|
632 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
633 |
+
h = module(h, emb, context=context, batch_size=b)
|
634 |
+
h = h.type(x.dtype)
|
635 |
+
y = self.out(h)
|
636 |
+
|
637 |
+
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
638 |
+
return y
|
639 |
+
|
640 |
+
def enable_gradient_checkpointing(self, enable=True, verbose=False):
|
641 |
+
for k, v in self.named_modules():
|
642 |
+
if hasattr(v, 'checkpoint'):
|
643 |
+
v.checkpoint = enable
|
644 |
+
if verbose:
|
645 |
+
print(f'{k}.checkpoint = {enable}')
|
646 |
+
if hasattr(v, 'use_checkpoint'):
|
647 |
+
v.use_checkpoint = enable
|
648 |
+
if verbose:
|
649 |
+
print(f'{k}.use_checkpoint = {enable}')
|
650 |
+
return
|
diffusers_vdm/utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import einops
|
5 |
+
import torchvision
|
6 |
+
|
7 |
+
|
8 |
+
def resize_and_center_crop(image, target_width, target_height, interpolation=cv2.INTER_AREA):
|
9 |
+
original_height, original_width = image.shape[:2]
|
10 |
+
k = max(target_height / original_height, target_width / original_width)
|
11 |
+
new_width = int(round(original_width * k))
|
12 |
+
new_height = int(round(original_height * k))
|
13 |
+
resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
|
14 |
+
x_start = (new_width - target_width) // 2
|
15 |
+
y_start = (new_height - target_height) // 2
|
16 |
+
cropped_image = resized_image[y_start:y_start + target_height, x_start:x_start + target_width]
|
17 |
+
return cropped_image
|
18 |
+
|
19 |
+
|
20 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10):
|
21 |
+
b, c, t, h, w = x.shape
|
22 |
+
|
23 |
+
per_row = b
|
24 |
+
for p in [6, 5, 4, 3, 2]:
|
25 |
+
if b % p == 0:
|
26 |
+
per_row = p
|
27 |
+
break
|
28 |
+
|
29 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
30 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
31 |
+
x = x.detach().cpu().to(torch.uint8)
|
32 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
33 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '1'})
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
def save_bcthw_as_png(x, output_filename):
|
38 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
39 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
40 |
+
x = x.detach().cpu().to(torch.uint8)
|
41 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
42 |
+
torchvision.io.write_png(x, output_filename)
|
43 |
+
return output_filename
|
diffusers_vdm/vae.py
ADDED
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# video VAE with many components from lots of repos
|
2 |
+
# collected by lvmin
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import xformers.ops
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from diffusers_vdm.basics import default, exists, zero_module, conv_nd, linear, normalization
|
11 |
+
from diffusers_vdm.unet import Upsample, Downsample
|
12 |
+
from huggingface_hub import PyTorchModelHubMixin
|
13 |
+
|
14 |
+
|
15 |
+
def chunked_attention(q, k, v, batch_chunk=0):
|
16 |
+
# if batch_chunk > 0 and not torch.is_grad_enabled():
|
17 |
+
# batch_size = q.size(0)
|
18 |
+
# chunks = [slice(i, i + batch_chunk) for i in range(0, batch_size, batch_chunk)]
|
19 |
+
#
|
20 |
+
# out_chunks = []
|
21 |
+
# for chunk in chunks:
|
22 |
+
# q_chunk = q[chunk]
|
23 |
+
# k_chunk = k[chunk]
|
24 |
+
# v_chunk = v[chunk]
|
25 |
+
#
|
26 |
+
# out_chunk = torch.nn.functional.scaled_dot_product_attention(
|
27 |
+
# q_chunk, k_chunk, v_chunk, attn_mask=None
|
28 |
+
# )
|
29 |
+
# out_chunks.append(out_chunk)
|
30 |
+
#
|
31 |
+
# out = torch.cat(out_chunks, dim=0)
|
32 |
+
# else:
|
33 |
+
# out = torch.nn.functional.scaled_dot_product_attention(
|
34 |
+
# q, k, v, attn_mask=None
|
35 |
+
# )
|
36 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
37 |
+
return out
|
38 |
+
|
39 |
+
|
40 |
+
def nonlinearity(x):
|
41 |
+
return x * torch.sigmoid(x)
|
42 |
+
|
43 |
+
|
44 |
+
def GroupNorm(in_channels, num_groups=32):
|
45 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
46 |
+
|
47 |
+
|
48 |
+
class DiagonalGaussianDistribution:
|
49 |
+
def __init__(self, parameters, deterministic=False):
|
50 |
+
self.parameters = parameters
|
51 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
52 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
53 |
+
self.deterministic = deterministic
|
54 |
+
self.std = torch.exp(0.5 * self.logvar)
|
55 |
+
self.var = torch.exp(self.logvar)
|
56 |
+
if self.deterministic:
|
57 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
58 |
+
|
59 |
+
def sample(self, noise=None):
|
60 |
+
if noise is None:
|
61 |
+
noise = torch.randn(self.mean.shape)
|
62 |
+
|
63 |
+
x = self.mean + self.std * noise.to(device=self.parameters.device)
|
64 |
+
return x
|
65 |
+
|
66 |
+
def mode(self):
|
67 |
+
return self.mean
|
68 |
+
|
69 |
+
|
70 |
+
class EncoderDownSampleBlock(nn.Module):
|
71 |
+
def __init__(self, in_channels, with_conv):
|
72 |
+
super().__init__()
|
73 |
+
self.with_conv = with_conv
|
74 |
+
self.in_channels = in_channels
|
75 |
+
if self.with_conv:
|
76 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
77 |
+
in_channels,
|
78 |
+
kernel_size=3,
|
79 |
+
stride=2,
|
80 |
+
padding=0)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
if self.with_conv:
|
84 |
+
pad = (0, 1, 0, 1)
|
85 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
86 |
+
x = self.conv(x)
|
87 |
+
else:
|
88 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
class ResnetBlock(nn.Module):
|
93 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
94 |
+
dropout, temb_channels=512):
|
95 |
+
super().__init__()
|
96 |
+
self.in_channels = in_channels
|
97 |
+
out_channels = in_channels if out_channels is None else out_channels
|
98 |
+
self.out_channels = out_channels
|
99 |
+
self.use_conv_shortcut = conv_shortcut
|
100 |
+
|
101 |
+
self.norm1 = GroupNorm(in_channels)
|
102 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
if temb_channels > 0:
|
108 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
109 |
+
out_channels)
|
110 |
+
self.norm2 = GroupNorm(out_channels)
|
111 |
+
self.dropout = torch.nn.Dropout(dropout)
|
112 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
113 |
+
out_channels,
|
114 |
+
kernel_size=3,
|
115 |
+
stride=1,
|
116 |
+
padding=1)
|
117 |
+
if self.in_channels != self.out_channels:
|
118 |
+
if self.use_conv_shortcut:
|
119 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
120 |
+
out_channels,
|
121 |
+
kernel_size=3,
|
122 |
+
stride=1,
|
123 |
+
padding=1)
|
124 |
+
else:
|
125 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
126 |
+
out_channels,
|
127 |
+
kernel_size=1,
|
128 |
+
stride=1,
|
129 |
+
padding=0)
|
130 |
+
|
131 |
+
def forward(self, x, temb):
|
132 |
+
h = x
|
133 |
+
h = self.norm1(h)
|
134 |
+
h = nonlinearity(h)
|
135 |
+
h = self.conv1(h)
|
136 |
+
|
137 |
+
if temb is not None:
|
138 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
139 |
+
|
140 |
+
h = self.norm2(h)
|
141 |
+
h = nonlinearity(h)
|
142 |
+
h = self.dropout(h)
|
143 |
+
h = self.conv2(h)
|
144 |
+
|
145 |
+
if self.in_channels != self.out_channels:
|
146 |
+
if self.use_conv_shortcut:
|
147 |
+
x = self.conv_shortcut(x)
|
148 |
+
else:
|
149 |
+
x = self.nin_shortcut(x)
|
150 |
+
|
151 |
+
return x + h
|
152 |
+
|
153 |
+
|
154 |
+
class Encoder(nn.Module):
|
155 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
|
156 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
157 |
+
resolution, z_channels, double_z=True, **kwargs):
|
158 |
+
super().__init__()
|
159 |
+
self.ch = ch
|
160 |
+
self.temb_ch = 0
|
161 |
+
self.num_resolutions = len(ch_mult)
|
162 |
+
self.num_res_blocks = num_res_blocks
|
163 |
+
self.resolution = resolution
|
164 |
+
self.in_channels = in_channels
|
165 |
+
|
166 |
+
# downsampling
|
167 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
168 |
+
self.ch,
|
169 |
+
kernel_size=3,
|
170 |
+
stride=1,
|
171 |
+
padding=1)
|
172 |
+
|
173 |
+
curr_res = resolution
|
174 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
175 |
+
self.in_ch_mult = in_ch_mult
|
176 |
+
self.down = nn.ModuleList()
|
177 |
+
for i_level in range(self.num_resolutions):
|
178 |
+
block = nn.ModuleList()
|
179 |
+
attn = nn.ModuleList()
|
180 |
+
block_in = ch * in_ch_mult[i_level]
|
181 |
+
block_out = ch * ch_mult[i_level]
|
182 |
+
for i_block in range(self.num_res_blocks):
|
183 |
+
block.append(ResnetBlock(in_channels=block_in,
|
184 |
+
out_channels=block_out,
|
185 |
+
temb_channels=self.temb_ch,
|
186 |
+
dropout=dropout))
|
187 |
+
block_in = block_out
|
188 |
+
if curr_res in attn_resolutions:
|
189 |
+
attn.append(Attention(block_in))
|
190 |
+
down = nn.Module()
|
191 |
+
down.block = block
|
192 |
+
down.attn = attn
|
193 |
+
if i_level != self.num_resolutions - 1:
|
194 |
+
down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv)
|
195 |
+
curr_res = curr_res // 2
|
196 |
+
self.down.append(down)
|
197 |
+
|
198 |
+
# middle
|
199 |
+
self.mid = nn.Module()
|
200 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
201 |
+
out_channels=block_in,
|
202 |
+
temb_channels=self.temb_ch,
|
203 |
+
dropout=dropout)
|
204 |
+
self.mid.attn_1 = Attention(block_in)
|
205 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
206 |
+
out_channels=block_in,
|
207 |
+
temb_channels=self.temb_ch,
|
208 |
+
dropout=dropout)
|
209 |
+
|
210 |
+
# end
|
211 |
+
self.norm_out = GroupNorm(block_in)
|
212 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
213 |
+
2 * z_channels if double_z else z_channels,
|
214 |
+
kernel_size=3,
|
215 |
+
stride=1,
|
216 |
+
padding=1)
|
217 |
+
|
218 |
+
def forward(self, x, return_hidden_states=False):
|
219 |
+
# timestep embedding
|
220 |
+
temb = None
|
221 |
+
|
222 |
+
# print(f'encoder-input={x.shape}')
|
223 |
+
# downsampling
|
224 |
+
hs = [self.conv_in(x)]
|
225 |
+
|
226 |
+
## if we return hidden states for decoder usage, we will store them in a list
|
227 |
+
if return_hidden_states:
|
228 |
+
hidden_states = []
|
229 |
+
# print(f'encoder-conv in feat={hs[0].shape}')
|
230 |
+
for i_level in range(self.num_resolutions):
|
231 |
+
for i_block in range(self.num_res_blocks):
|
232 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
233 |
+
# print(f'encoder-down feat={h.shape}')
|
234 |
+
if len(self.down[i_level].attn) > 0:
|
235 |
+
h = self.down[i_level].attn[i_block](h)
|
236 |
+
hs.append(h)
|
237 |
+
if return_hidden_states:
|
238 |
+
hidden_states.append(h)
|
239 |
+
if i_level != self.num_resolutions - 1:
|
240 |
+
# print(f'encoder-downsample (input)={hs[-1].shape}')
|
241 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
242 |
+
# print(f'encoder-downsample (output)={hs[-1].shape}')
|
243 |
+
if return_hidden_states:
|
244 |
+
hidden_states.append(hs[0])
|
245 |
+
# middle
|
246 |
+
h = hs[-1]
|
247 |
+
h = self.mid.block_1(h, temb)
|
248 |
+
# print(f'encoder-mid1 feat={h.shape}')
|
249 |
+
h = self.mid.attn_1(h)
|
250 |
+
h = self.mid.block_2(h, temb)
|
251 |
+
# print(f'encoder-mid2 feat={h.shape}')
|
252 |
+
|
253 |
+
# end
|
254 |
+
h = self.norm_out(h)
|
255 |
+
h = nonlinearity(h)
|
256 |
+
h = self.conv_out(h)
|
257 |
+
# print(f'end feat={h.shape}')
|
258 |
+
if return_hidden_states:
|
259 |
+
return h, hidden_states
|
260 |
+
else:
|
261 |
+
return h
|
262 |
+
|
263 |
+
|
264 |
+
class ConvCombiner(nn.Module):
|
265 |
+
def __init__(self, ch):
|
266 |
+
super().__init__()
|
267 |
+
self.conv = nn.Conv2d(ch, ch, 1, padding=0)
|
268 |
+
|
269 |
+
nn.init.zeros_(self.conv.weight)
|
270 |
+
nn.init.zeros_(self.conv.bias)
|
271 |
+
|
272 |
+
def forward(self, x, context):
|
273 |
+
## x: b c h w, context: b c 2 h w
|
274 |
+
b, c, l, h, w = context.shape
|
275 |
+
bt, c, h, w = x.shape
|
276 |
+
context = rearrange(context, "b c l h w -> (b l) c h w")
|
277 |
+
context = self.conv(context)
|
278 |
+
context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
|
279 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=bt // b)
|
280 |
+
x[:, :, 0] = x[:, :, 0] + context[:, :, 0]
|
281 |
+
x[:, :, -1] = x[:, :, -1] + context[:, :, -1]
|
282 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
283 |
+
return x
|
284 |
+
|
285 |
+
|
286 |
+
class AttentionCombiner(nn.Module):
|
287 |
+
def __init__(
|
288 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
289 |
+
):
|
290 |
+
super().__init__()
|
291 |
+
|
292 |
+
inner_dim = dim_head * heads
|
293 |
+
context_dim = default(context_dim, query_dim)
|
294 |
+
|
295 |
+
self.heads = heads
|
296 |
+
self.dim_head = dim_head
|
297 |
+
|
298 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
299 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
300 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
301 |
+
|
302 |
+
self.to_out = nn.Sequential(
|
303 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
304 |
+
)
|
305 |
+
self.attention_op = None
|
306 |
+
|
307 |
+
self.norm = GroupNorm(query_dim)
|
308 |
+
nn.init.zeros_(self.to_out[0].weight)
|
309 |
+
nn.init.zeros_(self.to_out[0].bias)
|
310 |
+
|
311 |
+
def forward(
|
312 |
+
self,
|
313 |
+
x,
|
314 |
+
context=None,
|
315 |
+
mask=None,
|
316 |
+
):
|
317 |
+
bt, c, h, w = x.shape
|
318 |
+
h_ = self.norm(x)
|
319 |
+
h_ = rearrange(h_, "b c h w -> b (h w) c")
|
320 |
+
q = self.to_q(h_)
|
321 |
+
|
322 |
+
b, c, l, h, w = context.shape
|
323 |
+
context = rearrange(context, "b c l h w -> (b l) (h w) c")
|
324 |
+
k = self.to_k(context)
|
325 |
+
v = self.to_v(context)
|
326 |
+
|
327 |
+
t = bt // b
|
328 |
+
k = repeat(k, "(b l) d c -> (b t) (l d) c", l=l, t=t)
|
329 |
+
v = repeat(v, "(b l) d c -> (b t) (l d) c", l=l, t=t)
|
330 |
+
|
331 |
+
b, _, _ = q.shape
|
332 |
+
q, k, v = map(
|
333 |
+
lambda t: t.unsqueeze(3)
|
334 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
335 |
+
.permute(0, 2, 1, 3)
|
336 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
337 |
+
.contiguous(),
|
338 |
+
(q, k, v),
|
339 |
+
)
|
340 |
+
|
341 |
+
out = chunked_attention(
|
342 |
+
q, k, v, batch_chunk=1
|
343 |
+
)
|
344 |
+
|
345 |
+
if exists(mask):
|
346 |
+
raise NotImplementedError
|
347 |
+
|
348 |
+
out = (
|
349 |
+
out.unsqueeze(0)
|
350 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
351 |
+
.permute(0, 2, 1, 3)
|
352 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
353 |
+
)
|
354 |
+
out = self.to_out(out)
|
355 |
+
out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
|
356 |
+
return x + out
|
357 |
+
|
358 |
+
|
359 |
+
class Attention(nn.Module):
|
360 |
+
def __init__(self, in_channels):
|
361 |
+
super().__init__()
|
362 |
+
self.in_channels = in_channels
|
363 |
+
|
364 |
+
self.norm = GroupNorm(in_channels)
|
365 |
+
self.q = torch.nn.Conv2d(
|
366 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
367 |
+
)
|
368 |
+
self.k = torch.nn.Conv2d(
|
369 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
370 |
+
)
|
371 |
+
self.v = torch.nn.Conv2d(
|
372 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
373 |
+
)
|
374 |
+
self.proj_out = torch.nn.Conv2d(
|
375 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
376 |
+
)
|
377 |
+
|
378 |
+
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
379 |
+
h_ = self.norm(h_)
|
380 |
+
q = self.q(h_)
|
381 |
+
k = self.k(h_)
|
382 |
+
v = self.v(h_)
|
383 |
+
|
384 |
+
# compute attention
|
385 |
+
B, C, H, W = q.shape
|
386 |
+
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
387 |
+
|
388 |
+
q, k, v = map(
|
389 |
+
lambda t: t.unsqueeze(3)
|
390 |
+
.reshape(B, t.shape[1], 1, C)
|
391 |
+
.permute(0, 2, 1, 3)
|
392 |
+
.reshape(B * 1, t.shape[1], C)
|
393 |
+
.contiguous(),
|
394 |
+
(q, k, v),
|
395 |
+
)
|
396 |
+
|
397 |
+
out = chunked_attention(
|
398 |
+
q, k, v, batch_chunk=1
|
399 |
+
)
|
400 |
+
|
401 |
+
out = (
|
402 |
+
out.unsqueeze(0)
|
403 |
+
.reshape(B, 1, out.shape[1], C)
|
404 |
+
.permute(0, 2, 1, 3)
|
405 |
+
.reshape(B, out.shape[1], C)
|
406 |
+
)
|
407 |
+
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
408 |
+
|
409 |
+
def forward(self, x, **kwargs):
|
410 |
+
h_ = x
|
411 |
+
h_ = self.attention(h_)
|
412 |
+
h_ = self.proj_out(h_)
|
413 |
+
return x + h_
|
414 |
+
|
415 |
+
|
416 |
+
class VideoDecoder(nn.Module):
|
417 |
+
def __init__(
|
418 |
+
self,
|
419 |
+
*,
|
420 |
+
ch,
|
421 |
+
out_ch,
|
422 |
+
ch_mult=(1, 2, 4, 8),
|
423 |
+
num_res_blocks,
|
424 |
+
attn_resolutions,
|
425 |
+
dropout=0.0,
|
426 |
+
resamp_with_conv=True,
|
427 |
+
in_channels,
|
428 |
+
resolution,
|
429 |
+
z_channels,
|
430 |
+
give_pre_end=False,
|
431 |
+
tanh_out=False,
|
432 |
+
use_linear_attn=False,
|
433 |
+
attn_level=[2, 3],
|
434 |
+
video_kernel_size=[3, 1, 1],
|
435 |
+
alpha: float = 0.0,
|
436 |
+
merge_strategy: str = "learned",
|
437 |
+
**kwargs,
|
438 |
+
):
|
439 |
+
super().__init__()
|
440 |
+
self.video_kernel_size = video_kernel_size
|
441 |
+
self.alpha = alpha
|
442 |
+
self.merge_strategy = merge_strategy
|
443 |
+
self.ch = ch
|
444 |
+
self.temb_ch = 0
|
445 |
+
self.num_resolutions = len(ch_mult)
|
446 |
+
self.num_res_blocks = num_res_blocks
|
447 |
+
self.resolution = resolution
|
448 |
+
self.in_channels = in_channels
|
449 |
+
self.give_pre_end = give_pre_end
|
450 |
+
self.tanh_out = tanh_out
|
451 |
+
self.attn_level = attn_level
|
452 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
453 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
454 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
455 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
456 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
457 |
+
|
458 |
+
# z to block_in
|
459 |
+
self.conv_in = torch.nn.Conv2d(
|
460 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
461 |
+
)
|
462 |
+
|
463 |
+
# middle
|
464 |
+
self.mid = nn.Module()
|
465 |
+
self.mid.block_1 = VideoResBlock(
|
466 |
+
in_channels=block_in,
|
467 |
+
out_channels=block_in,
|
468 |
+
temb_channels=self.temb_ch,
|
469 |
+
dropout=dropout,
|
470 |
+
video_kernel_size=self.video_kernel_size,
|
471 |
+
alpha=self.alpha,
|
472 |
+
merge_strategy=self.merge_strategy,
|
473 |
+
)
|
474 |
+
self.mid.attn_1 = Attention(block_in)
|
475 |
+
self.mid.block_2 = VideoResBlock(
|
476 |
+
in_channels=block_in,
|
477 |
+
out_channels=block_in,
|
478 |
+
temb_channels=self.temb_ch,
|
479 |
+
dropout=dropout,
|
480 |
+
video_kernel_size=self.video_kernel_size,
|
481 |
+
alpha=self.alpha,
|
482 |
+
merge_strategy=self.merge_strategy,
|
483 |
+
)
|
484 |
+
|
485 |
+
# upsampling
|
486 |
+
self.up = nn.ModuleList()
|
487 |
+
self.attn_refinement = nn.ModuleList()
|
488 |
+
for i_level in reversed(range(self.num_resolutions)):
|
489 |
+
block = nn.ModuleList()
|
490 |
+
attn = nn.ModuleList()
|
491 |
+
block_out = ch * ch_mult[i_level]
|
492 |
+
for i_block in range(self.num_res_blocks + 1):
|
493 |
+
block.append(
|
494 |
+
VideoResBlock(
|
495 |
+
in_channels=block_in,
|
496 |
+
out_channels=block_out,
|
497 |
+
temb_channels=self.temb_ch,
|
498 |
+
dropout=dropout,
|
499 |
+
video_kernel_size=self.video_kernel_size,
|
500 |
+
alpha=self.alpha,
|
501 |
+
merge_strategy=self.merge_strategy,
|
502 |
+
)
|
503 |
+
)
|
504 |
+
block_in = block_out
|
505 |
+
if curr_res in attn_resolutions:
|
506 |
+
attn.append(Attention(block_in))
|
507 |
+
up = nn.Module()
|
508 |
+
up.block = block
|
509 |
+
up.attn = attn
|
510 |
+
if i_level != 0:
|
511 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
512 |
+
curr_res = curr_res * 2
|
513 |
+
self.up.insert(0, up) # prepend to get consistent order
|
514 |
+
|
515 |
+
if i_level in self.attn_level:
|
516 |
+
self.attn_refinement.insert(0, AttentionCombiner(block_in))
|
517 |
+
else:
|
518 |
+
self.attn_refinement.insert(0, ConvCombiner(block_in))
|
519 |
+
# end
|
520 |
+
self.norm_out = GroupNorm(block_in)
|
521 |
+
self.attn_refinement.append(ConvCombiner(block_in))
|
522 |
+
self.conv_out = DecoderConv3D(
|
523 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1, video_kernel_size=self.video_kernel_size
|
524 |
+
)
|
525 |
+
|
526 |
+
def forward(self, z, ref_context=None, **kwargs):
|
527 |
+
## ref_context: b c 2 h w, 2 means starting and ending frame
|
528 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
529 |
+
self.last_z_shape = z.shape
|
530 |
+
# timestep embedding
|
531 |
+
temb = None
|
532 |
+
|
533 |
+
# z to block_in
|
534 |
+
h = self.conv_in(z)
|
535 |
+
|
536 |
+
# middle
|
537 |
+
h = self.mid.block_1(h, temb, **kwargs)
|
538 |
+
h = self.mid.attn_1(h, **kwargs)
|
539 |
+
h = self.mid.block_2(h, temb, **kwargs)
|
540 |
+
|
541 |
+
# upsampling
|
542 |
+
for i_level in reversed(range(self.num_resolutions)):
|
543 |
+
for i_block in range(self.num_res_blocks + 1):
|
544 |
+
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
545 |
+
if len(self.up[i_level].attn) > 0:
|
546 |
+
h = self.up[i_level].attn[i_block](h, **kwargs)
|
547 |
+
if ref_context:
|
548 |
+
h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
|
549 |
+
if i_level != 0:
|
550 |
+
h = self.up[i_level].upsample(h)
|
551 |
+
|
552 |
+
# end
|
553 |
+
if self.give_pre_end:
|
554 |
+
return h
|
555 |
+
|
556 |
+
h = self.norm_out(h)
|
557 |
+
h = nonlinearity(h)
|
558 |
+
if ref_context:
|
559 |
+
# print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
|
560 |
+
h = self.attn_refinement[-1](x=h, context=ref_context[-1])
|
561 |
+
h = self.conv_out(h, **kwargs)
|
562 |
+
if self.tanh_out:
|
563 |
+
h = torch.tanh(h)
|
564 |
+
return h
|
565 |
+
|
566 |
+
|
567 |
+
class TimeStackBlock(torch.nn.Module):
|
568 |
+
def __init__(
|
569 |
+
self,
|
570 |
+
channels: int,
|
571 |
+
emb_channels: int,
|
572 |
+
dropout: float,
|
573 |
+
out_channels: int = None,
|
574 |
+
use_conv: bool = False,
|
575 |
+
use_scale_shift_norm: bool = False,
|
576 |
+
dims: int = 2,
|
577 |
+
use_checkpoint: bool = False,
|
578 |
+
up: bool = False,
|
579 |
+
down: bool = False,
|
580 |
+
kernel_size: int = 3,
|
581 |
+
exchange_temb_dims: bool = False,
|
582 |
+
skip_t_emb: bool = False,
|
583 |
+
):
|
584 |
+
super().__init__()
|
585 |
+
self.channels = channels
|
586 |
+
self.emb_channels = emb_channels
|
587 |
+
self.dropout = dropout
|
588 |
+
self.out_channels = out_channels or channels
|
589 |
+
self.use_conv = use_conv
|
590 |
+
self.use_checkpoint = use_checkpoint
|
591 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
592 |
+
self.exchange_temb_dims = exchange_temb_dims
|
593 |
+
|
594 |
+
if isinstance(kernel_size, list):
|
595 |
+
padding = [k // 2 for k in kernel_size]
|
596 |
+
else:
|
597 |
+
padding = kernel_size // 2
|
598 |
+
|
599 |
+
self.in_layers = nn.Sequential(
|
600 |
+
normalization(channels),
|
601 |
+
nn.SiLU(),
|
602 |
+
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
603 |
+
)
|
604 |
+
|
605 |
+
self.updown = up or down
|
606 |
+
|
607 |
+
if up:
|
608 |
+
self.h_upd = Upsample(channels, False, dims)
|
609 |
+
self.x_upd = Upsample(channels, False, dims)
|
610 |
+
elif down:
|
611 |
+
self.h_upd = Downsample(channels, False, dims)
|
612 |
+
self.x_upd = Downsample(channels, False, dims)
|
613 |
+
else:
|
614 |
+
self.h_upd = self.x_upd = nn.Identity()
|
615 |
+
|
616 |
+
self.skip_t_emb = skip_t_emb
|
617 |
+
self.emb_out_channels = (
|
618 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels
|
619 |
+
)
|
620 |
+
if self.skip_t_emb:
|
621 |
+
# print(f"Skipping timestep embedding in {self.__class__.__name__}")
|
622 |
+
assert not self.use_scale_shift_norm
|
623 |
+
self.emb_layers = None
|
624 |
+
self.exchange_temb_dims = False
|
625 |
+
else:
|
626 |
+
self.emb_layers = nn.Sequential(
|
627 |
+
nn.SiLU(),
|
628 |
+
linear(
|
629 |
+
emb_channels,
|
630 |
+
self.emb_out_channels,
|
631 |
+
),
|
632 |
+
)
|
633 |
+
|
634 |
+
self.out_layers = nn.Sequential(
|
635 |
+
normalization(self.out_channels),
|
636 |
+
nn.SiLU(),
|
637 |
+
nn.Dropout(p=dropout),
|
638 |
+
zero_module(
|
639 |
+
conv_nd(
|
640 |
+
dims,
|
641 |
+
self.out_channels,
|
642 |
+
self.out_channels,
|
643 |
+
kernel_size,
|
644 |
+
padding=padding,
|
645 |
+
)
|
646 |
+
),
|
647 |
+
)
|
648 |
+
|
649 |
+
if self.out_channels == channels:
|
650 |
+
self.skip_connection = nn.Identity()
|
651 |
+
elif use_conv:
|
652 |
+
self.skip_connection = conv_nd(
|
653 |
+
dims, channels, self.out_channels, kernel_size, padding=padding
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
657 |
+
|
658 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
659 |
+
if self.updown:
|
660 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
661 |
+
h = in_rest(x)
|
662 |
+
h = self.h_upd(h)
|
663 |
+
x = self.x_upd(x)
|
664 |
+
h = in_conv(h)
|
665 |
+
else:
|
666 |
+
h = self.in_layers(x)
|
667 |
+
|
668 |
+
if self.skip_t_emb:
|
669 |
+
emb_out = torch.zeros_like(h)
|
670 |
+
else:
|
671 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
672 |
+
while len(emb_out.shape) < len(h.shape):
|
673 |
+
emb_out = emb_out[..., None]
|
674 |
+
if self.use_scale_shift_norm:
|
675 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
676 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
677 |
+
h = out_norm(h) * (1 + scale) + shift
|
678 |
+
h = out_rest(h)
|
679 |
+
else:
|
680 |
+
if self.exchange_temb_dims:
|
681 |
+
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
682 |
+
h = h + emb_out
|
683 |
+
h = self.out_layers(h)
|
684 |
+
return self.skip_connection(x) + h
|
685 |
+
|
686 |
+
|
687 |
+
class VideoResBlock(ResnetBlock):
|
688 |
+
def __init__(
|
689 |
+
self,
|
690 |
+
out_channels,
|
691 |
+
*args,
|
692 |
+
dropout=0.0,
|
693 |
+
video_kernel_size=3,
|
694 |
+
alpha=0.0,
|
695 |
+
merge_strategy="learned",
|
696 |
+
**kwargs,
|
697 |
+
):
|
698 |
+
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
699 |
+
if video_kernel_size is None:
|
700 |
+
video_kernel_size = [3, 1, 1]
|
701 |
+
self.time_stack = TimeStackBlock(
|
702 |
+
channels=out_channels,
|
703 |
+
emb_channels=0,
|
704 |
+
dropout=dropout,
|
705 |
+
dims=3,
|
706 |
+
use_scale_shift_norm=False,
|
707 |
+
use_conv=False,
|
708 |
+
up=False,
|
709 |
+
down=False,
|
710 |
+
kernel_size=video_kernel_size,
|
711 |
+
use_checkpoint=True,
|
712 |
+
skip_t_emb=True,
|
713 |
+
)
|
714 |
+
|
715 |
+
self.merge_strategy = merge_strategy
|
716 |
+
if self.merge_strategy == "fixed":
|
717 |
+
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
718 |
+
elif self.merge_strategy == "learned":
|
719 |
+
self.register_parameter(
|
720 |
+
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
724 |
+
|
725 |
+
def get_alpha(self, bs):
|
726 |
+
if self.merge_strategy == "fixed":
|
727 |
+
return self.mix_factor
|
728 |
+
elif self.merge_strategy == "learned":
|
729 |
+
return torch.sigmoid(self.mix_factor)
|
730 |
+
else:
|
731 |
+
raise NotImplementedError()
|
732 |
+
|
733 |
+
def forward(self, x, temb, skip_video=False, timesteps=None):
|
734 |
+
assert isinstance(timesteps, int)
|
735 |
+
|
736 |
+
b, c, h, w = x.shape
|
737 |
+
|
738 |
+
x = super().forward(x, temb)
|
739 |
+
|
740 |
+
if not skip_video:
|
741 |
+
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
742 |
+
|
743 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
744 |
+
|
745 |
+
x = self.time_stack(x, temb)
|
746 |
+
|
747 |
+
alpha = self.get_alpha(bs=b // timesteps)
|
748 |
+
x = alpha * x + (1.0 - alpha) * x_mix
|
749 |
+
|
750 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
751 |
+
return x
|
752 |
+
|
753 |
+
|
754 |
+
class DecoderConv3D(torch.nn.Conv2d):
|
755 |
+
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
756 |
+
super().__init__(in_channels, out_channels, *args, **kwargs)
|
757 |
+
if isinstance(video_kernel_size, list):
|
758 |
+
padding = [int(k // 2) for k in video_kernel_size]
|
759 |
+
else:
|
760 |
+
padding = int(video_kernel_size // 2)
|
761 |
+
|
762 |
+
self.time_mix_conv = torch.nn.Conv3d(
|
763 |
+
in_channels=out_channels,
|
764 |
+
out_channels=out_channels,
|
765 |
+
kernel_size=video_kernel_size,
|
766 |
+
padding=padding,
|
767 |
+
)
|
768 |
+
|
769 |
+
def forward(self, input, timesteps, skip_video=False):
|
770 |
+
x = super().forward(input)
|
771 |
+
if skip_video:
|
772 |
+
return x
|
773 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
774 |
+
x = self.time_mix_conv(x)
|
775 |
+
return rearrange(x, "b c t h w -> (b t) c h w")
|
776 |
+
|
777 |
+
|
778 |
+
class VideoAutoencoderKL(torch.nn.Module, PyTorchModelHubMixin):
|
779 |
+
def __init__(self,
|
780 |
+
double_z=True,
|
781 |
+
z_channels=4,
|
782 |
+
resolution=256,
|
783 |
+
in_channels=3,
|
784 |
+
out_ch=3,
|
785 |
+
ch=128,
|
786 |
+
ch_mult=[],
|
787 |
+
num_res_blocks=2,
|
788 |
+
attn_resolutions=[],
|
789 |
+
dropout=0.0,
|
790 |
+
):
|
791 |
+
super().__init__()
|
792 |
+
self.encoder = Encoder(double_z=double_z, z_channels=z_channels, resolution=resolution, in_channels=in_channels,
|
793 |
+
out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
|
794 |
+
attn_resolutions=attn_resolutions, dropout=dropout)
|
795 |
+
self.decoder = VideoDecoder(double_z=double_z, z_channels=z_channels, resolution=resolution,
|
796 |
+
in_channels=in_channels, out_ch=out_ch, ch=ch, ch_mult=ch_mult,
|
797 |
+
num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout)
|
798 |
+
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
|
799 |
+
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
800 |
+
self.scale_factor = 0.18215
|
801 |
+
|
802 |
+
def encode(self, x, return_hidden_states=False, **kwargs):
|
803 |
+
if return_hidden_states:
|
804 |
+
h, hidden = self.encoder(x, return_hidden_states)
|
805 |
+
moments = self.quant_conv(h)
|
806 |
+
posterior = DiagonalGaussianDistribution(moments)
|
807 |
+
return posterior, hidden
|
808 |
+
else:
|
809 |
+
h = self.encoder(x)
|
810 |
+
moments = self.quant_conv(h)
|
811 |
+
posterior = DiagonalGaussianDistribution(moments)
|
812 |
+
return posterior, None
|
813 |
+
|
814 |
+
def decode(self, z, **kwargs):
|
815 |
+
if len(kwargs) == 0:
|
816 |
+
z = self.post_quant_conv(z)
|
817 |
+
dec = self.decoder(z, **kwargs)
|
818 |
+
return dec
|
819 |
+
|
820 |
+
@property
|
821 |
+
def device(self):
|
822 |
+
return next(self.parameters()).device
|
823 |
+
|
824 |
+
@property
|
825 |
+
def dtype(self):
|
826 |
+
return next(self.parameters()).dtype
|
gradio_app.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
|
4 |
+
result_dir = os.path.join('./', 'results')
|
5 |
+
os.makedirs(result_dir, exist_ok=True)
|
6 |
+
|
7 |
+
|
8 |
+
import functools
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import gradio as gr
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import wd14tagger
|
15 |
+
import memory_management
|
16 |
+
import uuid
|
17 |
+
|
18 |
+
from PIL import Image
|
19 |
+
from diffusers_helper.code_cond import unet_add_coded_conds
|
20 |
+
from diffusers_helper.cat_cond import unet_add_concat_conds
|
21 |
+
from diffusers_helper.k_diffusion import KDiffusionSampler
|
22 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel
|
23 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
24 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
25 |
+
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
|
26 |
+
from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
|
27 |
+
import spaces
|
28 |
+
|
29 |
+
class ModifiedUNet(UNet2DConditionModel):
|
30 |
+
@classmethod
|
31 |
+
def from_config(cls, *args, **kwargs):
|
32 |
+
m = super().from_config(*args, **kwargs)
|
33 |
+
unet_add_concat_conds(unet=m, new_channels=4)
|
34 |
+
unet_add_coded_conds(unet=m, added_number_count=1)
|
35 |
+
return m
|
36 |
+
|
37 |
+
|
38 |
+
model_name = 'lllyasviel/paints_undo_single_frame'
|
39 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
40 |
+
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
|
41 |
+
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae
|
42 |
+
unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
|
43 |
+
|
44 |
+
unet.set_attn_processor(AttnProcessor2_0())
|
45 |
+
vae.set_attn_processor(AttnProcessor2_0())
|
46 |
+
|
47 |
+
video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
|
48 |
+
'lllyasviel/paints_undo_multi_frame',
|
49 |
+
fp16=True
|
50 |
+
)
|
51 |
+
|
52 |
+
memory_management.unload_all_models([
|
53 |
+
video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
|
54 |
+
unet, vae, text_encoder
|
55 |
+
])
|
56 |
+
|
57 |
+
k_sampler = KDiffusionSampler(
|
58 |
+
unet=unet,
|
59 |
+
timesteps=1000,
|
60 |
+
linear_start=0.00085,
|
61 |
+
linear_end=0.020,
|
62 |
+
linear=True
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def find_best_bucket(h, w, options):
|
67 |
+
min_metric = float('inf')
|
68 |
+
best_bucket = None
|
69 |
+
for (bucket_h, bucket_w) in options:
|
70 |
+
metric = abs(h * bucket_w - w * bucket_h)
|
71 |
+
if metric <= min_metric:
|
72 |
+
min_metric = metric
|
73 |
+
best_bucket = (bucket_h, bucket_w)
|
74 |
+
return best_bucket
|
75 |
+
|
76 |
+
|
77 |
+
@torch.inference_mode()
|
78 |
+
def encode_cropped_prompt_77tokens(txt: str):
|
79 |
+
memory_management.load_models_to_gpu(text_encoder)
|
80 |
+
cond_ids = tokenizer(txt,
|
81 |
+
padding="max_length",
|
82 |
+
max_length=tokenizer.model_max_length,
|
83 |
+
truncation=True,
|
84 |
+
return_tensors="pt").input_ids.to(device=text_encoder.device)
|
85 |
+
text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
|
86 |
+
return text_cond
|
87 |
+
|
88 |
+
|
89 |
+
@torch.inference_mode()
|
90 |
+
def pytorch2numpy(imgs):
|
91 |
+
results = []
|
92 |
+
for x in imgs:
|
93 |
+
y = x.movedim(0, -1)
|
94 |
+
y = y * 127.5 + 127.5
|
95 |
+
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
|
96 |
+
results.append(y)
|
97 |
+
return results
|
98 |
+
|
99 |
+
|
100 |
+
@torch.inference_mode()
|
101 |
+
def numpy2pytorch(imgs):
|
102 |
+
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
|
103 |
+
h = h.movedim(-1, 1)
|
104 |
+
return h
|
105 |
+
|
106 |
+
|
107 |
+
def resize_without_crop(image, target_width, target_height):
|
108 |
+
pil_image = Image.fromarray(image)
|
109 |
+
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
110 |
+
return np.array(resized_image)
|
111 |
+
|
112 |
+
|
113 |
+
@torch.inference_mode()
|
114 |
+
@spaces.GPU
|
115 |
+
def interrogator_process(x):
|
116 |
+
return wd14tagger.default_interrogator(x)
|
117 |
+
|
118 |
+
|
119 |
+
@torch.inference_mode()
|
120 |
+
@spaces.GPU
|
121 |
+
def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
|
122 |
+
progress=gr.Progress()):
|
123 |
+
rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
|
124 |
+
|
125 |
+
memory_management.load_models_to_gpu(vae)
|
126 |
+
fg = resize_and_center_crop(input_fg, image_width, image_height)
|
127 |
+
concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
|
128 |
+
concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
|
129 |
+
|
130 |
+
memory_management.load_models_to_gpu(text_encoder)
|
131 |
+
conds = encode_cropped_prompt_77tokens(prompt)
|
132 |
+
unconds = encode_cropped_prompt_77tokens(n_prompt)
|
133 |
+
|
134 |
+
memory_management.load_models_to_gpu(unet)
|
135 |
+
fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
|
136 |
+
initial_latents = torch.zeros_like(concat_conds)
|
137 |
+
concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
|
138 |
+
latents = k_sampler(
|
139 |
+
initial_latent=initial_latents,
|
140 |
+
strength=1.0,
|
141 |
+
num_inference_steps=steps,
|
142 |
+
guidance_scale=cfg,
|
143 |
+
batch_size=len(input_undo_steps),
|
144 |
+
generator=rng,
|
145 |
+
prompt_embeds=conds,
|
146 |
+
negative_prompt_embeds=unconds,
|
147 |
+
cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
|
148 |
+
same_noise_in_batch=True,
|
149 |
+
progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
|
150 |
+
).to(vae.dtype) / vae.config.scaling_factor
|
151 |
+
|
152 |
+
memory_management.load_models_to_gpu(vae)
|
153 |
+
pixels = vae.decode(latents).sample
|
154 |
+
pixels = pytorch2numpy(pixels)
|
155 |
+
pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
|
156 |
+
|
157 |
+
return pixels
|
158 |
+
|
159 |
+
|
160 |
+
@torch.inference_mode()
|
161 |
+
def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
|
162 |
+
random.seed(seed)
|
163 |
+
np.random.seed(seed)
|
164 |
+
torch.manual_seed(seed)
|
165 |
+
torch.cuda.manual_seed_all(seed)
|
166 |
+
|
167 |
+
frames = 16
|
168 |
+
|
169 |
+
target_height, target_width = find_best_bucket(
|
170 |
+
image_1.shape[0], image_1.shape[1],
|
171 |
+
options=[(320, 512), (384, 448), (448, 384), (512, 320)]
|
172 |
+
)
|
173 |
+
|
174 |
+
image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
|
175 |
+
image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
|
176 |
+
input_frames = numpy2pytorch([image_1, image_2])
|
177 |
+
input_frames = input_frames.unsqueeze(0).movedim(1, 2)
|
178 |
+
|
179 |
+
memory_management.load_models_to_gpu(video_pipe.text_encoder)
|
180 |
+
positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
|
181 |
+
negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
|
182 |
+
|
183 |
+
memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
|
184 |
+
input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
|
185 |
+
positive_image_cond = video_pipe.encode_clip_vision(input_frames)
|
186 |
+
positive_image_cond = video_pipe.image_projection(positive_image_cond)
|
187 |
+
negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
|
188 |
+
negative_image_cond = video_pipe.image_projection(negative_image_cond)
|
189 |
+
|
190 |
+
memory_management.load_models_to_gpu([video_pipe.vae])
|
191 |
+
input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
|
192 |
+
input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
|
193 |
+
first_frame = input_frame_latents[:, :, 0]
|
194 |
+
last_frame = input_frame_latents[:, :, 1]
|
195 |
+
concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
|
196 |
+
|
197 |
+
memory_management.load_models_to_gpu([video_pipe.unet])
|
198 |
+
latents = video_pipe(
|
199 |
+
batch_size=1,
|
200 |
+
steps=int(steps),
|
201 |
+
guidance_scale=cfg_scale,
|
202 |
+
positive_text_cond=positive_text_cond,
|
203 |
+
negative_text_cond=negative_text_cond,
|
204 |
+
positive_image_cond=positive_image_cond,
|
205 |
+
negative_image_cond=negative_image_cond,
|
206 |
+
concat_cond=concat_cond,
|
207 |
+
fs=fs,
|
208 |
+
progress_tqdm=progress_tqdm
|
209 |
+
)
|
210 |
+
|
211 |
+
memory_management.load_models_to_gpu([video_pipe.vae])
|
212 |
+
video = video_pipe.decode_latents(latents, vae_hidden_states)
|
213 |
+
return video, image_1, image_2
|
214 |
+
|
215 |
+
|
216 |
+
@torch.inference_mode()
|
217 |
+
@spaces.GPU
|
218 |
+
def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
|
219 |
+
result_frames = []
|
220 |
+
cropped_images = []
|
221 |
+
|
222 |
+
for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
|
223 |
+
im1 = np.array(Image.open(im1[0]))
|
224 |
+
im2 = np.array(Image.open(im2[0]))
|
225 |
+
frames, im1, im2 = process_video_inner(
|
226 |
+
im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
|
227 |
+
progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
|
228 |
+
)
|
229 |
+
result_frames.append(frames[:, :, :-1, :, :])
|
230 |
+
cropped_images.append([im1, im2])
|
231 |
+
|
232 |
+
video = torch.cat(result_frames, dim=2)
|
233 |
+
video = torch.flip(video, dims=[2])
|
234 |
+
|
235 |
+
uuid_name = str(uuid.uuid4())
|
236 |
+
output_filename = os.path.join(result_dir, uuid_name + '.mp4')
|
237 |
+
Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
|
238 |
+
video = save_bcthw_as_mp4(video, output_filename, fps=fps)
|
239 |
+
video = [x.cpu().numpy() for x in video]
|
240 |
+
return output_filename, video
|
241 |
+
|
242 |
+
|
243 |
+
block = gr.Blocks().queue()
|
244 |
+
with block:
|
245 |
+
gr.Markdown('# Paints-Undo')
|
246 |
+
|
247 |
+
with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True):
|
248 |
+
with gr.Row():
|
249 |
+
with gr.Column():
|
250 |
+
input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512)
|
251 |
+
with gr.Column():
|
252 |
+
prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False)
|
253 |
+
prompt = gr.Textbox(label="Output Prompt", interactive=True)
|
254 |
+
|
255 |
+
with gr.Accordion(label='Step 2: Generate Key Frames', open=True):
|
256 |
+
with gr.Row():
|
257 |
+
with gr.Column():
|
258 |
+
input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999],
|
259 |
+
choices=list(range(1000)), multiselect=True)
|
260 |
+
seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345)
|
261 |
+
image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
|
262 |
+
image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
|
263 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
|
264 |
+
cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01)
|
265 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
266 |
+
value='lowres, bad anatomy, bad hands, cropped, worst quality')
|
267 |
+
|
268 |
+
with gr.Column():
|
269 |
+
key_gen_button = gr.Button(value="Generate Key Frames", interactive=False)
|
270 |
+
result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4)
|
271 |
+
|
272 |
+
with gr.Accordion(label='Step 3: Generate All Videos', open=True):
|
273 |
+
with gr.Row():
|
274 |
+
with gr.Column():
|
275 |
+
i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality')
|
276 |
+
i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123)
|
277 |
+
i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5,
|
278 |
+
elem_id="i2v_cfg_scale")
|
279 |
+
i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps",
|
280 |
+
label="Sampling steps", value=50)
|
281 |
+
i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4)
|
282 |
+
with gr.Column():
|
283 |
+
i2v_end_btn = gr.Button("Generate Video", interactive=False)
|
284 |
+
i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
|
285 |
+
show_share_button=True, height=512)
|
286 |
+
with gr.Row():
|
287 |
+
i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8)
|
288 |
+
|
289 |
+
input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)],
|
290 |
+
outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn])
|
291 |
+
|
292 |
+
prompt_gen_button.click(
|
293 |
+
fn=interrogator_process,
|
294 |
+
inputs=[input_fg],
|
295 |
+
outputs=[prompt]
|
296 |
+
).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
|
297 |
+
outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
|
298 |
+
|
299 |
+
key_gen_button.click(
|
300 |
+
fn=process,
|
301 |
+
inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg],
|
302 |
+
outputs=[result_gallery]
|
303 |
+
).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
|
304 |
+
outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
|
305 |
+
|
306 |
+
i2v_end_btn.click(
|
307 |
+
inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed],
|
308 |
+
outputs=[i2v_output_video, i2v_output_images],
|
309 |
+
fn=process_video
|
310 |
+
)
|
311 |
+
|
312 |
+
dbs = [
|
313 |
+
['./imgs/1.jpg', 12345, 123],
|
314 |
+
['./imgs/2.jpg', 37000, 12345],
|
315 |
+
['./imgs/3.jpg', 3000, 3000],
|
316 |
+
]
|
317 |
+
|
318 |
+
gr.Examples(
|
319 |
+
examples=dbs,
|
320 |
+
inputs=[input_fg, seed, i2v_seed],
|
321 |
+
examples_per_page=1024
|
322 |
+
)
|
323 |
+
|
324 |
+
block.queue().launch(server_name='0.0.0.0')
|
imgs/1.jpg
ADDED
imgs/2.jpg
ADDED
imgs/3.jpg
ADDED
memory_management.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from contextlib import contextmanager
|
3 |
+
|
4 |
+
|
5 |
+
high_vram = False
|
6 |
+
gpu = torch.device('cuda')
|
7 |
+
cpu = torch.device('cpu')
|
8 |
+
|
9 |
+
torch.zeros((1, 1)).to(gpu, torch.float32)
|
10 |
+
torch.cuda.empty_cache()
|
11 |
+
|
12 |
+
models_in_gpu = []
|
13 |
+
|
14 |
+
|
15 |
+
@contextmanager
|
16 |
+
def movable_bnb_model(m):
|
17 |
+
if hasattr(m, 'quantization_method'):
|
18 |
+
m.quantization_method_backup = m.quantization_method
|
19 |
+
del m.quantization_method
|
20 |
+
try:
|
21 |
+
yield None
|
22 |
+
finally:
|
23 |
+
if hasattr(m, 'quantization_method_backup'):
|
24 |
+
m.quantization_method = m.quantization_method_backup
|
25 |
+
del m.quantization_method_backup
|
26 |
+
return
|
27 |
+
|
28 |
+
|
29 |
+
def load_models_to_gpu(models):
|
30 |
+
global models_in_gpu
|
31 |
+
|
32 |
+
if not isinstance(models, (tuple, list)):
|
33 |
+
models = [models]
|
34 |
+
|
35 |
+
models_to_remain = [m for m in set(models) if m in models_in_gpu]
|
36 |
+
models_to_load = [m for m in set(models) if m not in models_in_gpu]
|
37 |
+
models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
|
38 |
+
|
39 |
+
if not high_vram:
|
40 |
+
for m in models_to_unload:
|
41 |
+
with movable_bnb_model(m):
|
42 |
+
m.to(cpu)
|
43 |
+
print('Unload to CPU:', m.__class__.__name__)
|
44 |
+
models_in_gpu = models_to_remain
|
45 |
+
|
46 |
+
for m in models_to_load:
|
47 |
+
with movable_bnb_model(m):
|
48 |
+
m.to(gpu)
|
49 |
+
print('Load to GPU:', m.__class__.__name__)
|
50 |
+
|
51 |
+
models_in_gpu = list(set(models_in_gpu + models))
|
52 |
+
torch.cuda.empty_cache()
|
53 |
+
return
|
54 |
+
|
55 |
+
|
56 |
+
def unload_all_models(extra_models=None):
|
57 |
+
global models_in_gpu
|
58 |
+
|
59 |
+
if extra_models is None:
|
60 |
+
extra_models = []
|
61 |
+
|
62 |
+
if not isinstance(extra_models, (tuple, list)):
|
63 |
+
extra_models = [extra_models]
|
64 |
+
|
65 |
+
models_in_gpu = list(set(models_in_gpu + extra_models))
|
66 |
+
|
67 |
+
return load_models_to_gpu([])
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.28.0
|
2 |
+
transformers==4.41.1
|
3 |
+
gradio==4.31.5
|
4 |
+
bitsandbytes==0.43.1
|
5 |
+
accelerate==0.30.1
|
6 |
+
protobuf==3.20
|
7 |
+
opencv-python
|
8 |
+
tensorboardX
|
9 |
+
safetensors
|
10 |
+
pillow
|
11 |
+
einops
|
12 |
+
torch
|
13 |
+
peft
|
14 |
+
xformers
|
15 |
+
onnxruntime
|
16 |
+
av
|
17 |
+
torchvision
|
18 |
+
xformers
|
19 |
+
spaces
|
wd14tagger.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
import csv
|
6 |
+
import numpy as np
|
7 |
+
import onnxruntime as ort
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from onnxruntime import InferenceSession
|
11 |
+
from torch.hub import download_url_to_file
|
12 |
+
|
13 |
+
|
14 |
+
global_model = None
|
15 |
+
global_csv = None
|
16 |
+
|
17 |
+
|
18 |
+
def download_model(url, local_path):
|
19 |
+
if os.path.exists(local_path):
|
20 |
+
return local_path
|
21 |
+
|
22 |
+
temp_path = local_path + '.tmp'
|
23 |
+
download_url_to_file(url=url, dst=temp_path)
|
24 |
+
os.rename(temp_path, local_path)
|
25 |
+
return local_path
|
26 |
+
|
27 |
+
|
28 |
+
def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""):
|
29 |
+
global global_model, global_csv
|
30 |
+
|
31 |
+
model_name = "wd-v1-4-moat-tagger-v2"
|
32 |
+
|
33 |
+
model_onnx_filename = download_model(
|
34 |
+
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
|
35 |
+
local_path=f'./{model_name}.onnx',
|
36 |
+
)
|
37 |
+
|
38 |
+
model_csv_filename = download_model(
|
39 |
+
url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
|
40 |
+
local_path=f'./{model_name}.csv',
|
41 |
+
)
|
42 |
+
|
43 |
+
if global_model is not None:
|
44 |
+
model = global_model
|
45 |
+
else:
|
46 |
+
# assert 'CUDAExecutionProvider' in ort.get_available_providers(), 'CUDA Install Failed!'
|
47 |
+
# model = InferenceSession(model_onnx_filename, providers=['CUDAExecutionProvider'])
|
48 |
+
model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider'])
|
49 |
+
global_model = model
|
50 |
+
|
51 |
+
input = model.get_inputs()[0]
|
52 |
+
height = input.shape[1]
|
53 |
+
|
54 |
+
if isinstance(image, str):
|
55 |
+
image = Image.open(image) # RGB
|
56 |
+
elif isinstance(image, np.ndarray):
|
57 |
+
image = Image.fromarray(image)
|
58 |
+
else:
|
59 |
+
image = image
|
60 |
+
|
61 |
+
ratio = float(height) / max(image.size)
|
62 |
+
new_size = tuple([int(x*ratio) for x in image.size])
|
63 |
+
image = image.resize(new_size, Image.LANCZOS)
|
64 |
+
square = Image.new("RGB", (height, height), (255, 255, 255))
|
65 |
+
square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
|
66 |
+
|
67 |
+
image = np.array(square).astype(np.float32)
|
68 |
+
image = image[:, :, ::-1] # RGB -> BGR
|
69 |
+
image = np.expand_dims(image, 0)
|
70 |
+
|
71 |
+
if global_csv is not None:
|
72 |
+
csv_lines = global_csv
|
73 |
+
else:
|
74 |
+
csv_lines = []
|
75 |
+
with open(model_csv_filename) as f:
|
76 |
+
reader = csv.reader(f)
|
77 |
+
next(reader)
|
78 |
+
for row in reader:
|
79 |
+
csv_lines.append(row)
|
80 |
+
global_csv = csv_lines
|
81 |
+
|
82 |
+
tags = []
|
83 |
+
general_index = None
|
84 |
+
character_index = None
|
85 |
+
for line_num, row in enumerate(csv_lines):
|
86 |
+
if general_index is None and row[2] == "0":
|
87 |
+
general_index = line_num
|
88 |
+
elif character_index is None and row[2] == "4":
|
89 |
+
character_index = line_num
|
90 |
+
tags.append(row[1])
|
91 |
+
|
92 |
+
label_name = model.get_outputs()[0].name
|
93 |
+
probs = model.run([label_name], {input.name: image})[0]
|
94 |
+
|
95 |
+
result = list(zip(tags, probs[0]))
|
96 |
+
|
97 |
+
general = [item for item in result[general_index:character_index] if item[1] > threshold]
|
98 |
+
character = [item for item in result[character_index:] if item[1] > character_threshold]
|
99 |
+
|
100 |
+
all = character + general
|
101 |
+
remove = [s.strip() for s in exclude_tags.lower().split(",")]
|
102 |
+
all = [tag for tag in all if tag[0] not in remove]
|
103 |
+
|
104 |
+
res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ')
|
105 |
+
return res
|