Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- .github/ISSUE_TEMPLATE/bug_report.yaml +51 -0
- .github/ISSUE_TEMPLATE/feature-request.yaml +34 -0
- .github/PULL_REQUEST_TEMPLATE/pr_template.md +34 -0
- .gitignore +9 -0
- LICENSE +201 -0
- Model_License +71 -0
- README.md +157 -6
- README_zh.md +149 -0
- gradio_demo.py +254 -0
- inference/cli_demo.py +127 -0
- inference/cli_vae_demo.py +103 -0
- inference/convert_demo.py +92 -0
- inference/web_demo.py +214 -0
- pyproject.toml +27 -0
- requirements.txt +11 -0
- resources/CogVideoX.pdf +3 -0
- resources/WECHAT.md +7 -0
- resources/contribute.md +50 -0
- resources/contribute_zh.md +45 -0
- resources/logo.svg +298 -0
- resources/videos/1.mp4 +0 -0
- resources/videos/2.mp4 +3 -0
- resources/videos/3.mp4 +0 -0
- resources/videos/4.mp4 +0 -0
- resources/web_demo.png +3 -0
- resources/wechat.jpg +0 -0
- sat/README.md +182 -0
- sat/README_zh.md +180 -0
- sat/arguments.py +281 -0
- sat/configs/cogvideox_2b_infer.yaml +166 -0
- sat/configs/cogvideox_2b_sft.yaml +225 -0
- sat/configs/test.txt +3 -0
- sat/data_video.py +451 -0
- sat/diffusion_video.py +318 -0
- sat/dit_video_concat.py +858 -0
- sat/finetune.sh +12 -0
- sat/inference.sh +12 -0
- sat/requirements.txt +17 -0
- sat/sample_video.py +236 -0
- sat/sgm/__init__.py +4 -0
- sat/sgm/lr_scheduler.py +110 -0
- sat/sgm/models/__init__.py +1 -0
- sat/sgm/models/autoencoder.py +630 -0
- sat/sgm/modules/__init__.py +6 -0
- sat/sgm/modules/attention.py +572 -0
- sat/sgm/modules/autoencoding/__init__.py +0 -0
- sat/sgm/modules/autoencoding/losses/__init__.py +8 -0
- sat/sgm/modules/autoencoding/losses/discriminator_loss.py +301 -0
- sat/sgm/modules/autoencoding/losses/lpips.py +64 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
resources/CogVideoX.pdf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
resources/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
resources/web_demo.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
tools/caption/assests/cogvlm2-video-example.png filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/bug_report.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F41B Bug Report"
|
2 |
+
description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
|
3 |
+
body:
|
4 |
+
- type: textarea
|
5 |
+
id: system-info
|
6 |
+
attributes:
|
7 |
+
label: System Info / 系統信息
|
8 |
+
description: Your operating environment / 您的运行环境信息
|
9 |
+
placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
|
10 |
+
validations:
|
11 |
+
required: true
|
12 |
+
|
13 |
+
- type: checkboxes
|
14 |
+
id: information-scripts-examples
|
15 |
+
attributes:
|
16 |
+
label: Information / 问题信息
|
17 |
+
description: 'The problem arises when using: / 问题出现在'
|
18 |
+
options:
|
19 |
+
- label: "The official example scripts / 官方的示例脚本"
|
20 |
+
- label: "My own modified scripts / 我自己修改的脚本和任务"
|
21 |
+
|
22 |
+
- type: textarea
|
23 |
+
id: reproduction
|
24 |
+
validations:
|
25 |
+
required: true
|
26 |
+
attributes:
|
27 |
+
label: Reproduction / 复现过程
|
28 |
+
description: |
|
29 |
+
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
|
30 |
+
If you have code snippets, error messages, stack traces, please provide them here as well.
|
31 |
+
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
32 |
+
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
|
33 |
+
|
34 |
+
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
|
35 |
+
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
|
36 |
+
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
37 |
+
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
|
38 |
+
placeholder: |
|
39 |
+
Steps to reproduce the behavior/复现Bug的步骤:
|
40 |
+
|
41 |
+
1.
|
42 |
+
2.
|
43 |
+
3.
|
44 |
+
|
45 |
+
- type: textarea
|
46 |
+
id: expected-behavior
|
47 |
+
validations:
|
48 |
+
required: true
|
49 |
+
attributes:
|
50 |
+
label: Expected behavior / 期待表现
|
51 |
+
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
|
.github/ISSUE_TEMPLATE/feature-request.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "\U0001F680 Feature request"
|
2 |
+
description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
|
3 |
+
labels: [ "feature" ]
|
4 |
+
body:
|
5 |
+
- type: textarea
|
6 |
+
id: feature-request
|
7 |
+
validations:
|
8 |
+
required: true
|
9 |
+
attributes:
|
10 |
+
label: Feature request / 功能建议
|
11 |
+
description: |
|
12 |
+
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
|
13 |
+
对功能建议的简述。最好提供对应的论文和代码链接。
|
14 |
+
|
15 |
+
- type: textarea
|
16 |
+
id: motivation
|
17 |
+
validations:
|
18 |
+
required: true
|
19 |
+
attributes:
|
20 |
+
label: Motivation / 动机
|
21 |
+
description: |
|
22 |
+
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
|
23 |
+
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
|
24 |
+
|
25 |
+
- type: textarea
|
26 |
+
id: contribution
|
27 |
+
validations:
|
28 |
+
required: true
|
29 |
+
attributes:
|
30 |
+
label: Your contribution / 您的贡献
|
31 |
+
description: |
|
32 |
+
|
33 |
+
Your PR link or any other link you can help with.
|
34 |
+
您的PR链接或者其他您能提供帮助的链接。
|
.github/PULL_REQUEST_TEMPLATE/pr_template.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Raise valuable PR / 提出有价值的PR
|
2 |
+
|
3 |
+
## Caution / 注意事项:
|
4 |
+
Users should keep the following points in mind when submitting PRs:
|
5 |
+
|
6 |
+
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
|
7 |
+
2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs.
|
8 |
+
|
9 |
+
用户在提交PR时候应该注意以下几点:
|
10 |
+
|
11 |
+
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
|
12 |
+
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
|
13 |
+
|
14 |
+
## 不应该提出的PR / PRs that should not be proposed
|
15 |
+
|
16 |
+
If a developer proposes a PR about any of the following, it may be closed or Rejected.
|
17 |
+
|
18 |
+
1. those that don't describe improvement options.
|
19 |
+
2. multiple issues of different types combined in one PR.
|
20 |
+
3. The proposed PR is highly duplicative of already existing PRs.
|
21 |
+
|
22 |
+
如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
|
23 |
+
|
24 |
+
1. 没有说明改进方案的。
|
25 |
+
2. 多个不同类型的问题合并在一个PR中的。
|
26 |
+
3. 提出的PR与已经存在的PR高度重复的。
|
27 |
+
|
28 |
+
|
29 |
+
# 检查您的PR
|
30 |
+
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
|
31 |
+
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
|
32 |
+
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
|
33 |
+
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
|
34 |
+
- [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题
|
.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output/
|
2 |
+
*__pycache__/
|
3 |
+
samples*/
|
4 |
+
runs/
|
5 |
+
checkpoints/
|
6 |
+
master_ip
|
7 |
+
logs/
|
8 |
+
*.DS_Store
|
9 |
+
.idea
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2024 CogVideo Model Team @ Zhipu AI
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
Model_License
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The CogVideoX License
|
2 |
+
|
3 |
+
1. Definitions
|
4 |
+
|
5 |
+
“Licensor” means the CogVideoX Model Team that distributes its Software.
|
6 |
+
|
7 |
+
“Software” means the CogVideoX model parameters made available under this license.
|
8 |
+
|
9 |
+
2. License Grant
|
10 |
+
|
11 |
+
Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
|
12 |
+
This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form .
|
13 |
+
Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
|
14 |
+
If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
|
15 |
+
The above copyright statement and this license statement should be included in all copies or significant portions of this software.
|
16 |
+
|
17 |
+
3. Restriction
|
18 |
+
|
19 |
+
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
|
20 |
+
|
21 |
+
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
|
22 |
+
|
23 |
+
4. Disclaimer
|
24 |
+
|
25 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
26 |
+
|
27 |
+
5. Limitation of Liability
|
28 |
+
|
29 |
+
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
30 |
+
|
31 |
+
6. Dispute Resolution
|
32 |
+
|
33 |
+
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
|
34 |
+
|
35 |
+
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
|
36 |
+
|
37 |
+
1. 定义
|
38 |
+
|
39 |
+
“许可方”是指分发其软件的 CogVideoX 模型团队。
|
40 |
+
|
41 |
+
“软件”是指根据本许可提供的 CogVideoX 模型参数。
|
42 |
+
|
43 |
+
2. 许可授予
|
44 |
+
|
45 |
+
根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。
|
46 |
+
本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。
|
47 |
+
|
48 |
+
经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
|
49 |
+
在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。
|
50 |
+
上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
|
51 |
+
|
52 |
+
3.限制
|
53 |
+
|
54 |
+
您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
|
55 |
+
|
56 |
+
您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
|
57 |
+
|
58 |
+
4.免责声明
|
59 |
+
|
60 |
+
本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
|
61 |
+
在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
|
62 |
+
|
63 |
+
5. 责任限制
|
64 |
+
|
65 |
+
除适用��律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
|
66 |
+
|
67 |
+
6.争议解决
|
68 |
+
|
69 |
+
本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
|
70 |
+
|
71 |
+
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
|
README.md
CHANGED
@@ -1,12 +1,163 @@
|
|
1 |
---
|
2 |
title: CogVideo
|
3 |
-
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.41.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: CogVideo
|
3 |
+
app_file: gradio_demo.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 4.41.0
|
|
|
|
|
6 |
---
|
7 |
+
# CogVideo && CogVideoX
|
8 |
|
9 |
+
[中文阅读](./README_zh.md)
|
10 |
+
|
11 |
+
<div align="center">
|
12 |
+
<img src=resources/logo.svg width="50%"/>
|
13 |
+
</div>
|
14 |
+
<p align="center">
|
15 |
+
🤗 Experience on <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a>
|
16 |
+
</p>
|
17 |
+
<p align="center">
|
18 |
+
📚 Check here to view <a href="resources/CogVideoX.pdf" target="_blank">Paper</a>
|
19 |
+
</p>
|
20 |
+
<p align="center">
|
21 |
+
👋 Join our <a href="resources/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
|
22 |
+
</p>
|
23 |
+
<p align="center">
|
24 |
+
📍 Visit <a href="https://chatglm.cn/video?fr=osm_cogvideox">清影</a> and <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">API Platform</a> to experience larger-scale commercial video generation models.
|
25 |
+
</p>
|
26 |
+
|
27 |
+
## Update and News
|
28 |
+
|
29 |
+
- 🔥 **News**: ``2024/8/6``: We have also open-sourced **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct
|
30 |
+
the video almost losslessly.
|
31 |
+
- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**,the first model in the CogVideoX series of video
|
32 |
+
generation models.
|
33 |
+
- 🌱 **Source**: ```2022/5/19```: We have open-sourced CogVideo (now you can see in `CogVideo` branch),the **first** open-sourced pretrained text-to-video model, and you can check [ICLR'23 CogVideo Paper](https://arxiv.org/abs/2205.15868) for technical details.
|
34 |
+
|
35 |
+
**More powerful models with larger parameter sizes are on the way~ Stay tuned!**
|
36 |
+
|
37 |
+
## CogVideoX-2B Gallery
|
38 |
+
|
39 |
+
<div align="center">
|
40 |
+
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
|
41 |
+
<p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
|
42 |
+
</div>
|
43 |
+
|
44 |
+
<div align="center">
|
45 |
+
<video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
|
46 |
+
<p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
|
47 |
+
</div>
|
48 |
+
|
49 |
+
<div align="center">
|
50 |
+
<video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
|
51 |
+
<p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
|
52 |
+
</div>
|
53 |
+
|
54 |
+
<div align="center">
|
55 |
+
<video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
|
56 |
+
<p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
|
57 |
+
</div>
|
58 |
+
|
59 |
+
## Model Introduction
|
60 |
+
|
61 |
+
CogVideoX is an open-source version of the video generation model, which is homologous
|
62 |
+
to [清影](https://chatglm.cn/video?fr=osm_cogvideox).
|
63 |
+
|
64 |
+
The table below shows the list of video generation models we currently provide,
|
65 |
+
along with related basic information:
|
66 |
+
|
67 |
+
| Model Name | CogVideoX-2B |
|
68 |
+
|-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
69 |
+
| Prompt Language | English |
|
70 |
+
| GPU Memory Required for Inference (FP16) | 18GB if using [SAT](https://github.com/THUDM/SwissArmyTransformer); 36GB if using diffusers (will be optimized before the PR is merged) |
|
71 |
+
| GPU Memory Required for Fine-tuning(bs=1) | 40GB |
|
72 |
+
| Prompt Max Length | 226 Tokens |
|
73 |
+
| Video Length | 6 seconds |
|
74 |
+
| Frames Per Second | 8 frames |
|
75 |
+
| Resolution | 720 * 480 |
|
76 |
+
| Quantized Inference | Not Supported |
|
77 |
+
| Multi-card Inference | Not Supported |
|
78 |
+
| Download Link (HF diffusers Model) | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) [💫 WiseModel](https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b) |
|
79 |
+
| Download Link (SAT Model) | [SAT](./sat/README.md) |
|
80 |
+
|
81 |
+
## Project Structure
|
82 |
+
|
83 |
+
This open-source repository will guide developers to quickly get started with the basic usage and fine-tuning examples
|
84 |
+
of the **CogVideoX** open-source model.
|
85 |
+
|
86 |
+
### Inference
|
87 |
+
|
88 |
+
+ [cli_demo](inference/cli_demo.py): A more detailed explanation of the inference code, mentioning the significance of common parameters.
|
89 |
+
+ [cli_vae_demo](inference/cli_vae_demo.py): Executing the VAE inference code alone currently requires 71GB of memory, but it will be optimized in the future.
|
90 |
+
+ [convert_demo](inference/convert_demo.py): How to convert user input into a format suitable for CogVideoX. Because CogVideoX is trained on long caption, we need to convert the input text to be consistent with the training distribution using a LLM. By default, the script uses GLM4, but it can also be replaced with any other LLM such as GPT, Gemini, etc.
|
91 |
+
+ [web_demo](inference/web_demo.py): A simple streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.
|
92 |
+
|
93 |
+
<div style="text-align: center;">
|
94 |
+
<img src="resources/web_demo.png" style="width: 100%; height: auto;" />
|
95 |
+
</div>
|
96 |
+
|
97 |
+
### sat
|
98 |
+
|
99 |
+
+ [sat_demo](sat/README.md): Contains the inference code and fine-tuning code of SAT weights. It is
|
100 |
+
recommended to improve based on the CogVideoX model structure. Innovative researchers use this code to better perform
|
101 |
+
rapid stacking and development.
|
102 |
+
|
103 |
+
### Tools
|
104 |
+
|
105 |
+
This folder contains some tools for model conversion / caption generation, etc.
|
106 |
+
|
107 |
+
+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Convert SAT model weights to Huggingface model weights.
|
108 |
+
+ [caption_demo](tools/caption): Caption tool, a model that understands videos and outputs them in text.
|
109 |
+
|
110 |
+
## Project Plan
|
111 |
+
|
112 |
+
- [x] Open source CogVideoX model
|
113 |
+
- [x] Open source 3D Causal VAE used in CogVideoX.
|
114 |
+
- [x] CogVideoX model inference example (CLI / Web Demo)
|
115 |
+
- [x] CogVideoX online experience demo (Huggingface Space)
|
116 |
+
- [x] CogVideoX open source model API interface example (Huggingface)
|
117 |
+
- [x] CogVideoX model fine-tuning example (SAT)
|
118 |
+
- [ ] CogVideoX model fine-tuning example (Huggingface / SAT)
|
119 |
+
- [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite)
|
120 |
+
- [x] Release CogVideoX technical report
|
121 |
+
|
122 |
+
We welcome your contributions. You can click [here](resources/contribute.md) for more information.
|
123 |
+
|
124 |
+
## Model License
|
125 |
+
|
126 |
+
The code in this repository is released under the [Apache 2.0 License](LICENSE).
|
127 |
+
|
128 |
+
The model weights and implementation code are released under the [CogVideoX LICENSE](MODEL_LICENSE).
|
129 |
+
|
130 |
+
## CogVideo(ICLR'23)
|
131 |
+
The official repo for the paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) is on the [CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo)
|
132 |
+
|
133 |
+
**CogVideo is able to generate relatively high-frame-rate videos.**
|
134 |
+
A 4-second clip of 32 frames is shown below.
|
135 |
+
|
136 |
+
![High-frame-rate sample](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/appendix-sample-highframerate.png)
|
137 |
+
|
138 |
+
![Intro images](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/intro-image.png)
|
139 |
+
<div align="center">
|
140 |
+
<video src="https://github.com/user-attachments/assets/2fa19651-e925-4a2a-b8d6-b3f216d490ba" width="80%" controls autoplay></video>
|
141 |
+
</div>
|
142 |
+
|
143 |
+
|
144 |
+
The demo for CogVideo is at [https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/), where you can get hands-on practice on text-to-video generation. *The original input is in Chinese.*
|
145 |
+
|
146 |
+
|
147 |
+
## Citation
|
148 |
+
|
149 |
+
🌟 If you find our work helpful, please leave us a star and cite our paper.
|
150 |
+
|
151 |
+
```
|
152 |
+
@article{yang2024cogvideox,
|
153 |
+
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
|
154 |
+
author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
|
155 |
+
year={2024},
|
156 |
+
}
|
157 |
+
@article{hong2022cogvideo,
|
158 |
+
title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
|
159 |
+
author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
|
160 |
+
journal={arXiv preprint arXiv:2205.15868},
|
161 |
+
year={2022}
|
162 |
+
}
|
163 |
+
```
|
README_zh.md
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CogVideo && CogVideoX
|
2 |
+
|
3 |
+
[Read this in English.](./README_zh)
|
4 |
+
|
5 |
+
|
6 |
+
<div align="center">
|
7 |
+
<img src=resources/logo.svg width="50%"/>
|
8 |
+
</div>
|
9 |
+
<p align="center">
|
10 |
+
🤗 在 <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a> 体验视频生成模型
|
11 |
+
</p>
|
12 |
+
<p align="center">
|
13 |
+
📚 查看 <a href="resources/CogVideoX.pdf" target="_blank">论文</a>
|
14 |
+
</p>
|
15 |
+
<p align="center">
|
16 |
+
👋 加入我们的 <a href="resources/WECHAT.md" target="_blank">微信</a> 和 <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
|
17 |
+
</p>
|
18 |
+
<p align="center">
|
19 |
+
📍 前往<a href="https://chatglm.cn/video?fr=osm_cogvideox"> 清影</a> 和 <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9"> API平台</a> 体验更大规模的商业版视频生成模型。
|
20 |
+
</p>
|
21 |
+
|
22 |
+
## 项目更新
|
23 |
+
|
24 |
+
- 🔥 **News**: ``2024/8/6``: 我们开源 **3D Causal VAE**,用于 **CogVideoX-2B**,可以几乎无损地重构视频。
|
25 |
+
- 🔥 **News**: ``2024/8/6``: 我们开源 CogVideoX 系列视频生成模型的第一个模型, **CogVideoX-2B**。
|
26 |
+
- 🌱 **Source**: ```2022/5/19```: 我们开源了 CogVideo 视频生成模型(现在你可以在 `CogVideo` 分支中看到),这是首个开源的基于 Transformer 的大型文本生成视频模型,您可以访问 [ICLR'23 论文](https://arxiv.org/abs/2205.15868) 查看技术细节。
|
27 |
+
**性能更强,参数量更大的模型正在到来的路上~,欢迎关注**
|
28 |
+
|
29 |
+
## CogVideoX-2B 视频作品
|
30 |
+
|
31 |
+
<div align="center">
|
32 |
+
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
|
33 |
+
<p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
|
34 |
+
</div>
|
35 |
+
|
36 |
+
<div align="center">
|
37 |
+
<video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
|
38 |
+
<p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
|
39 |
+
</div>
|
40 |
+
|
41 |
+
<div align="center">
|
42 |
+
<video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
|
43 |
+
<p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
|
44 |
+
</div>
|
45 |
+
|
46 |
+
<div align="center">
|
47 |
+
<video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
|
48 |
+
<p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
|
49 |
+
</div>
|
50 |
+
|
51 |
+
## 模型介绍
|
52 |
+
|
53 |
+
CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源版本视频生成模型。
|
54 |
+
|
55 |
+
下表战展示目前我们提供的视频生成模型列表,以及相关基础信息:
|
56 |
+
|
57 |
+
| 模型名字 | CogVideoX-2B |
|
58 |
+
|---------------------|--------------------------------------------------------------------------------------------------------------------------------------|
|
59 |
+
| 提示词语言 | English |
|
60 |
+
| 推理显存消耗 (FP-16) | 36GB using diffusers (will be optimized before the PR is merged) and 18GB using [SAT](https://github.com/THUDM/SwissArmyTransformer) |
|
61 |
+
| 微调显存消耗 (bs=1) | 42GB |
|
62 |
+
| 提示词长度上限 | 226 Tokens |
|
63 |
+
| 视频长度 | 6 seconds |
|
64 |
+
| 帧率(每秒) | 8 frames |
|
65 |
+
| 视频分辨率 | 720 * 480 |
|
66 |
+
| 量化推理 | 不支持 |
|
67 |
+
| 多卡推理 | 不支持 |
|
68 |
+
| 下载地址 (Diffusers 模型) | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) |
|
69 |
+
| 下载地址 (SAT 模型) | [SAT](./sat/README_zh.md) |
|
70 |
+
|
71 |
+
## 项目结构
|
72 |
+
|
73 |
+
本开源仓库将带领开发者快速上手 **CogVideoX** 开源模型的基础调用方式、微调示例。
|
74 |
+
|
75 |
+
### inference
|
76 |
+
|
77 |
+
+ [cli_demo](inference/cli_demo.py): 更详细的推理代码讲解,常见参数的意义,在这里都会提及。
|
78 |
+
+ [cli_vae_demo](inference/cli_vae_demo.py): 单独执行VAE的推理代码,目前需要71GB显存,将来会优化。
|
79 |
+
+ [convert_demo](inference/convert_demo.py): 如何将用户的输入转换成适合 CogVideoX的长输入。因为CogVideoX是在长文本上训练的,所以我们需要把输入文本的分布通过LLM转换为和训练一致的长文本。脚本中默认使用GLM4,也可以替换为GPT、Gemini等任意大语言模型。
|
80 |
+
+ [web_demo](inference/web_demo.py): 一个简单的streamlit网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
|
81 |
+
|
82 |
+
<div style="text-align: center;">
|
83 |
+
<img src="resources/web_demo.png" style="width: 100%; height: auto;" />
|
84 |
+
</div>
|
85 |
+
|
86 |
+
### sat
|
87 |
+
|
88 |
+
+ [sat_demo](sat/README_zh.md): 包含了 SAT 权重的推理代码和微调代码,推荐基于 CogVideoX
|
89 |
+
模型结构进行改进,创新的研究者使用改代码以更好的进行快速的堆叠和开发。
|
90 |
+
|
91 |
+
### tools
|
92 |
+
|
93 |
+
本文件夹包含了一些工具,用于模型的转换 / Caption 等工作。
|
94 |
+
|
95 |
+
+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): 将 SAT 模型权重转换为 Huggingface 模型权重。
|
96 |
+
+ [caption_demo](tools/caption/README_zh.md): Caption 工具,对视频理解并用文字输出的模型。
|
97 |
+
|
98 |
+
## 项目规划
|
99 |
+
|
100 |
+
- [x] CogVideoX 模型开源
|
101 |
+
- [x] CogVideoX 模型推理示例 (CLI / Web Demo)
|
102 |
+
- [x] CogVideoX 在线体验示例 (Huggingface Space)
|
103 |
+
- [x] CogVideoX 开源模型API接口示例 (Huggingface)
|
104 |
+
- [x] CogVideoX 模型微调示例 (SAT)
|
105 |
+
- [ ] CogVideoX 模型微调示例 (Huggingface / SAT)
|
106 |
+
- [ ] CogVideoX-Pro 开源(适配 CogVideoX-2B 套件)
|
107 |
+
- [ ] CogVideoX 技术报告公开
|
108 |
+
|
109 |
+
我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
|
110 |
+
|
111 |
+
## 模型协议
|
112 |
+
|
113 |
+
本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。
|
114 |
+
|
115 |
+
本模型权重和模型实现代码根据 [CogVideoX LICENSE](MODEL_LICENSE) 许可证发布。
|
116 |
+
|
117 |
+
## CogVideo(ICLR'23)
|
118 |
+
[CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) 的官方repo位于[CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo)。
|
119 |
+
|
120 |
+
**CogVideo可以生成高帧率视频,下面展示了一个32帧的4秒视频。**
|
121 |
+
|
122 |
+
![High-frame-rate sample](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/appendix-sample-highframerate.png)
|
123 |
+
|
124 |
+
![Intro images](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/intro-image.png)
|
125 |
+
|
126 |
+
|
127 |
+
<div align="center">
|
128 |
+
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
|
129 |
+
</div>
|
130 |
+
|
131 |
+
CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/)。您可以在这里体验文本到视频生成。*原始输入为中文。*
|
132 |
+
|
133 |
+
## 引用
|
134 |
+
|
135 |
+
🌟 如果您发现我们的工作有所帮助,欢迎引用我们的文章,留下宝贵的stars
|
136 |
+
|
137 |
+
```
|
138 |
+
@article{yang2024cogvideox,
|
139 |
+
title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
|
140 |
+
author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
|
141 |
+
year={2024},
|
142 |
+
}
|
143 |
+
@article{hong2022cogvideo,
|
144 |
+
title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
|
145 |
+
author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
|
146 |
+
journal={arXiv preprint arXiv:2205.15868},
|
147 |
+
year={2022}
|
148 |
+
}
|
149 |
+
```
|
gradio_demo.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from diffusers import CogVideoXPipeline
|
10 |
+
from datetime import datetime, timedelta
|
11 |
+
from openai import OpenAI
|
12 |
+
import spaces
|
13 |
+
import imageio
|
14 |
+
import moviepy.editor as mp
|
15 |
+
from typing import List, Union
|
16 |
+
import PIL
|
17 |
+
|
18 |
+
dtype = torch.bfloat16
|
19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
|
21 |
+
|
22 |
+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
23 |
+
|
24 |
+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
25 |
+
There are a few rules to follow:
|
26 |
+
|
27 |
+
You will only ever output a single video description per user request.
|
28 |
+
|
29 |
+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
30 |
+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
31 |
+
|
32 |
+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
33 |
+
"""
|
34 |
+
|
35 |
+
|
36 |
+
def export_to_video_imageio(
|
37 |
+
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
38 |
+
) -> str:
|
39 |
+
"""
|
40 |
+
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
|
41 |
+
"""
|
42 |
+
if output_video_path is None:
|
43 |
+
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
44 |
+
|
45 |
+
if isinstance(video_frames[0], PIL.Image.Image):
|
46 |
+
video_frames = [np.array(frame) for frame in video_frames]
|
47 |
+
|
48 |
+
with imageio.get_writer(output_video_path, fps=fps) as writer:
|
49 |
+
for frame in video_frames:
|
50 |
+
writer.append_data(frame)
|
51 |
+
|
52 |
+
return output_video_path
|
53 |
+
|
54 |
+
|
55 |
+
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
56 |
+
if not os.environ.get("OPENAI_API_KEY"):
|
57 |
+
return prompt
|
58 |
+
client = OpenAI()
|
59 |
+
text = prompt.strip()
|
60 |
+
|
61 |
+
for i in range(retry_times):
|
62 |
+
response = client.chat.completions.create(
|
63 |
+
messages=[
|
64 |
+
{"role": "system", "content": sys_prompt},
|
65 |
+
{"role": "user",
|
66 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
|
67 |
+
{"role": "assistant",
|
68 |
+
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
|
69 |
+
{"role": "user",
|
70 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
|
71 |
+
{"role": "assistant",
|
72 |
+
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
|
73 |
+
{"role": "user",
|
74 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
|
75 |
+
{"role": "assistant",
|
76 |
+
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
|
77 |
+
{"role": "user",
|
78 |
+
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
|
79 |
+
],
|
80 |
+
model="glm-4-0520",
|
81 |
+
temperature=0.01,
|
82 |
+
top_p=0.7,
|
83 |
+
stream=False,
|
84 |
+
max_tokens=250,
|
85 |
+
)
|
86 |
+
if response.choices:
|
87 |
+
return response.choices[0].message.content
|
88 |
+
return prompt
|
89 |
+
|
90 |
+
|
91 |
+
@spaces.GPU(duration=240)
|
92 |
+
def infer(
|
93 |
+
prompt: str,
|
94 |
+
num_inference_steps: int,
|
95 |
+
guidance_scale: float,
|
96 |
+
progress=gr.Progress(track_tqdm=True)
|
97 |
+
):
|
98 |
+
torch.cuda.empty_cache()
|
99 |
+
|
100 |
+
prompt_embeds, _ = pipe.encode_prompt(
|
101 |
+
prompt=prompt,
|
102 |
+
negative_prompt=None,
|
103 |
+
do_classifier_free_guidance=True,
|
104 |
+
num_videos_per_prompt=1,
|
105 |
+
max_sequence_length=226,
|
106 |
+
device=device,
|
107 |
+
dtype=dtype,
|
108 |
+
)
|
109 |
+
|
110 |
+
video = pipe(
|
111 |
+
num_inference_steps=num_inference_steps,
|
112 |
+
guidance_scale=guidance_scale,
|
113 |
+
prompt_embeds=prompt_embeds,
|
114 |
+
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
|
115 |
+
).frames[0]
|
116 |
+
|
117 |
+
|
118 |
+
return video
|
119 |
+
|
120 |
+
|
121 |
+
def save_video(tensor):
|
122 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
123 |
+
video_path = f"./output/{timestamp}.mp4"
|
124 |
+
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
125 |
+
export_to_video_imageio(tensor[1:], video_path)
|
126 |
+
return video_path
|
127 |
+
|
128 |
+
def convert_to_gif(video_path):
|
129 |
+
clip = mp.VideoFileClip(video_path)
|
130 |
+
clip = clip.set_fps(8)
|
131 |
+
clip = clip.resize(height=240)
|
132 |
+
gif_path = video_path.replace('.mp4', '.gif')
|
133 |
+
clip.write_gif(gif_path, fps=8)
|
134 |
+
return gif_path
|
135 |
+
|
136 |
+
|
137 |
+
def delete_old_files():
|
138 |
+
while True:
|
139 |
+
now = datetime.now()
|
140 |
+
cutoff = now - timedelta(minutes=10)
|
141 |
+
output_dir = './output'
|
142 |
+
for filename in os.listdir(output_dir):
|
143 |
+
file_path = os.path.join(output_dir, filename)
|
144 |
+
if os.path.isfile(file_path):
|
145 |
+
file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
|
146 |
+
if file_mtime < cutoff:
|
147 |
+
os.remove(file_path)
|
148 |
+
time.sleep(600) # Sleep for 10 minutes
|
149 |
+
|
150 |
+
|
151 |
+
threading.Thread(target=delete_old_files, daemon=True).start()
|
152 |
+
|
153 |
+
with gr.Blocks() as demo:
|
154 |
+
gr.Markdown("""
|
155 |
+
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
156 |
+
CogVideoX-2B Huggingface Space🤗
|
157 |
+
</div>
|
158 |
+
<div style="text-align: center;">
|
159 |
+
<a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 Model Hub</a> |
|
160 |
+
<a href="https://github.com/THUDM/CogVideo">🌐 Github</a>
|
161 |
+
</div>
|
162 |
+
|
163 |
+
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
|
164 |
+
⚠️ This demo is for academic research and experiential use only.
|
165 |
+
Users should strictly adhere to local laws and ethics.
|
166 |
+
</div>
|
167 |
+
""")
|
168 |
+
with gr.Row():
|
169 |
+
with gr.Column():
|
170 |
+
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
171 |
+
with gr.Row():
|
172 |
+
gr.Markdown(
|
173 |
+
"✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
|
174 |
+
enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
|
175 |
+
|
176 |
+
with gr.Column():
|
177 |
+
gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
|
178 |
+
"Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
|
179 |
+
"50 steps are recommended for most cases. will cause 120 seconds for inference.<br>")
|
180 |
+
with gr.Row():
|
181 |
+
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
182 |
+
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
183 |
+
generate_button = gr.Button("🎬 Generate Video")
|
184 |
+
|
185 |
+
with gr.Column():
|
186 |
+
video_output = gr.Video(label="CogVideoX Generate Video", width=720, height=480)
|
187 |
+
with gr.Row():
|
188 |
+
download_video_button = gr.File(label="📥 Download Video", visible=False)
|
189 |
+
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
|
190 |
+
|
191 |
+
gr.Markdown("""
|
192 |
+
<table border="1" style="width: 100%; text-align: left; margin-top: 20px;">
|
193 |
+
<tr>
|
194 |
+
<th>Prompt</th>
|
195 |
+
<th>Video URL</th>
|
196 |
+
<th>Inference Steps</th>
|
197 |
+
<th>Guidance Scale</th>
|
198 |
+
</tr>
|
199 |
+
<tr>
|
200 |
+
<td>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</td>
|
201 |
+
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/1.mp4">Video 1</a></td>
|
202 |
+
<td>50</td>
|
203 |
+
<td>6</td>
|
204 |
+
</tr>
|
205 |
+
<tr>
|
206 |
+
<td>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</td>
|
207 |
+
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/2.mp4">Video 2</a></td>
|
208 |
+
<td>50</td>
|
209 |
+
<td>6</td>
|
210 |
+
</tr>
|
211 |
+
<tr>
|
212 |
+
<td>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</td>
|
213 |
+
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/3.mp4">Video 3</a></td>
|
214 |
+
<td>50</td>
|
215 |
+
<td>6</td>
|
216 |
+
</tr>
|
217 |
+
<tr>
|
218 |
+
<td>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</td>
|
219 |
+
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/4.mp4">Video 4</a></td>
|
220 |
+
<td>50</td>
|
221 |
+
<td>6</td>
|
222 |
+
</tr>
|
223 |
+
</table>
|
224 |
+
""")
|
225 |
+
|
226 |
+
|
227 |
+
def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
228 |
+
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
229 |
+
video_path = save_video(tensor)
|
230 |
+
video_update = gr.update(visible=True, value=video_path)
|
231 |
+
gif_path = convert_to_gif(video_path)
|
232 |
+
gif_update = gr.update(visible=True, value=gif_path)
|
233 |
+
|
234 |
+
return video_path, video_update, gif_update
|
235 |
+
|
236 |
+
|
237 |
+
def enhance_prompt_func(prompt):
|
238 |
+
return convert_prompt(prompt, retry_times=1)
|
239 |
+
|
240 |
+
|
241 |
+
generate_button.click(
|
242 |
+
generate,
|
243 |
+
inputs=[prompt, num_inference_steps, guidance_scale],
|
244 |
+
outputs=[video_output, download_video_button, download_gif_button]
|
245 |
+
)
|
246 |
+
|
247 |
+
enhance_button.click(
|
248 |
+
enhance_prompt_func,
|
249 |
+
inputs=[prompt],
|
250 |
+
outputs=[prompt]
|
251 |
+
)
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
|
inference/cli_demo.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
3 |
+
|
4 |
+
Note:
|
5 |
+
This script requires the `diffusers>=0.30.0` library to be installed.
|
6 |
+
If the video exported using OpenCV appears “completely green” and cannot be viewed, lease switch to a different player to watch it. This is a normal phenomenon.
|
7 |
+
|
8 |
+
Run the script:
|
9 |
+
$ python cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
import argparse
|
14 |
+
import tempfile
|
15 |
+
from typing import Union, List
|
16 |
+
|
17 |
+
import PIL
|
18 |
+
import imageio
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from diffusers import CogVideoXPipeline
|
22 |
+
|
23 |
+
|
24 |
+
def export_to_video_imageio(
|
25 |
+
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
26 |
+
) -> str:
|
27 |
+
"""
|
28 |
+
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
|
29 |
+
"""
|
30 |
+
if output_video_path is None:
|
31 |
+
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
32 |
+
if isinstance(video_frames[0], PIL.Image.Image):
|
33 |
+
video_frames = [np.array(frame) for frame in video_frames]
|
34 |
+
with imageio.get_writer(output_video_path, fps=fps) as writer:
|
35 |
+
for frame in video_frames:
|
36 |
+
writer.append_data(frame)
|
37 |
+
return output_video_path
|
38 |
+
|
39 |
+
|
40 |
+
def generate_video(
|
41 |
+
prompt: str,
|
42 |
+
model_path: str,
|
43 |
+
output_path: str = "./output.mp4",
|
44 |
+
num_inference_steps: int = 50,
|
45 |
+
guidance_scale: float = 6.0,
|
46 |
+
num_videos_per_prompt: int = 1,
|
47 |
+
device: str = "cuda",
|
48 |
+
dtype: torch.dtype = torch.float16,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Generates a video based on the given prompt and saves it to the specified path.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
- prompt (str): The description of the video to be generated.
|
55 |
+
- model_path (str): The path of the pre-trained model to be used.
|
56 |
+
- output_path (str): The path where the generated video will be saved.
|
57 |
+
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
58 |
+
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
59 |
+
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
60 |
+
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
|
61 |
+
- dtype (torch.dtype): The data type for computation (default is torch.float16).
|
62 |
+
"""
|
63 |
+
|
64 |
+
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
|
65 |
+
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
66 |
+
|
67 |
+
# Encode the prompt to get the prompt embeddings
|
68 |
+
prompt_embeds, _ = pipe.encode_prompt(
|
69 |
+
prompt=prompt, # The textual description for video generation
|
70 |
+
negative_prompt=None, # The negative prompt to guide the video generation
|
71 |
+
do_classifier_free_guidance=True, # Whether to use classifier-free guidance
|
72 |
+
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
73 |
+
max_sequence_length=226, # Maximum length of the sequence, must be 226
|
74 |
+
device=device, # Device to use for computation
|
75 |
+
dtype=dtype, # Data type for computation
|
76 |
+
)
|
77 |
+
|
78 |
+
# Generate the video frames using the pipeline
|
79 |
+
video = pipe(
|
80 |
+
num_inference_steps=num_inference_steps, # Number of inference steps
|
81 |
+
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
|
82 |
+
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
|
83 |
+
negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Not Supported negative prompt
|
84 |
+
).frames[0]
|
85 |
+
|
86 |
+
# Export the generated frames to a video file. fps must be 8
|
87 |
+
export_to_video_imageio(video, output_path, fps=8)
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
92 |
+
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
93 |
+
parser.add_argument(
|
94 |
+
"--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
|
101 |
+
)
|
102 |
+
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
103 |
+
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
104 |
+
parser.add_argument(
|
105 |
+
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
106 |
+
)
|
107 |
+
|
108 |
+
parser.add_argument(
|
109 |
+
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
110 |
+
)
|
111 |
+
|
112 |
+
args = parser.parse_args()
|
113 |
+
|
114 |
+
# Convert dtype argument to torch.dtype, NOT suggest BF16.
|
115 |
+
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
116 |
+
|
117 |
+
# main function to generate video.
|
118 |
+
generate_video(
|
119 |
+
prompt=args.prompt,
|
120 |
+
model_path=args.model_path,
|
121 |
+
output_path=args.output_path,
|
122 |
+
num_inference_steps=args.num_inference_steps,
|
123 |
+
guidance_scale=args.guidance_scale,
|
124 |
+
num_videos_per_prompt=args.num_videos_per_prompt,
|
125 |
+
device=args.device,
|
126 |
+
dtype=dtype,
|
127 |
+
)
|
inference/cli_vae_demo.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
|
3 |
+
|
4 |
+
Note:
|
5 |
+
This script requires the `diffusers>=0.30.0` library to be installed.
|
6 |
+
If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
|
7 |
+
Cost 71GB of GPU memory for encoding a 6s video at 720p resolution.
|
8 |
+
|
9 |
+
Run the script:
|
10 |
+
$ python cli_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
|
11 |
+
|
12 |
+
"""
|
13 |
+
|
14 |
+
import argparse
|
15 |
+
import torch
|
16 |
+
import imageio
|
17 |
+
import numpy as np
|
18 |
+
from diffusers import AutoencoderKLCogVideoX
|
19 |
+
from torchvision import transforms
|
20 |
+
|
21 |
+
|
22 |
+
def vae_demo(model_path, video_path, dtype, device):
|
23 |
+
"""
|
24 |
+
Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
- model_path (str): The path to the pre-trained model.
|
28 |
+
- video_path (str): The path to the video file.
|
29 |
+
- dtype (torch.dtype): The data type for computation.
|
30 |
+
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
- torch.Tensor: The encoded video frames.
|
34 |
+
"""
|
35 |
+
# Load the pre-trained model
|
36 |
+
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
37 |
+
|
38 |
+
# Load video frames
|
39 |
+
video_reader = imageio.get_reader(video_path, "ffmpeg")
|
40 |
+
frames = []
|
41 |
+
for frame in video_reader:
|
42 |
+
frames.append(frame)
|
43 |
+
video_reader.close()
|
44 |
+
|
45 |
+
# Transform frames to Tensor
|
46 |
+
transform = transforms.Compose(
|
47 |
+
[
|
48 |
+
transforms.ToTensor(),
|
49 |
+
]
|
50 |
+
)
|
51 |
+
frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
|
52 |
+
|
53 |
+
# Add batch dimension and reshape to [1, 3, 49, 480, 720]
|
54 |
+
frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
|
55 |
+
|
56 |
+
# Run the model with Encoder and Decoder
|
57 |
+
with torch.no_grad():
|
58 |
+
output = model(frames_tensor)
|
59 |
+
|
60 |
+
return output
|
61 |
+
|
62 |
+
|
63 |
+
def save_video(tensor, output_path):
|
64 |
+
"""
|
65 |
+
Saves the encoded video frames to a video file.
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
- tensor (torch.Tensor): The encoded video frames.
|
69 |
+
- output_path (str): The path to save the output video.
|
70 |
+
"""
|
71 |
+
# Remove batch dimension and permute back to [49, 480, 720, 3]
|
72 |
+
frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
|
73 |
+
|
74 |
+
# Clip values to [0, 1] and convert to uint8
|
75 |
+
frames = np.clip(frames, 0, 1)
|
76 |
+
frames = (frames * 255).astype(np.uint8)
|
77 |
+
|
78 |
+
# Save frames to video
|
79 |
+
writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
|
80 |
+
for frame in frames:
|
81 |
+
writer.append_data(frame)
|
82 |
+
writer.close()
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
|
87 |
+
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
|
88 |
+
parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
|
89 |
+
parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
|
90 |
+
parser.add_argument(
|
91 |
+
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
95 |
+
)
|
96 |
+
args = parser.parse_args()
|
97 |
+
|
98 |
+
# Set device and dtype
|
99 |
+
device = torch.device(args.device)
|
100 |
+
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
101 |
+
|
102 |
+
output = vae_demo(args.model_path, args.video_path, dtype, device)
|
103 |
+
save_video(output, args.output_path)
|
inference/convert_demo.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
|
3 |
+
The CogVideoX model is pre-trained and fine-tuned using longer and more detailed prompts.Therefore, it requires highly granular and detailed prompts as input.This script aims to transform user inputs into executable inputs for CogVideoX, enabling superior video generation.
|
4 |
+
|
5 |
+
This step is not mandatory; the model will still function correctly and without errors even if the prompts are not refined using this script. However, we strongly recommend using it to ensure the generation of high-quality videos.
|
6 |
+
|
7 |
+
Note:
|
8 |
+
Please set the OPENAI_API_KEY and OPENAI_BASE_URL(if needed) environment variable to your OpenAI API key before running this script.
|
9 |
+
|
10 |
+
Run the script:
|
11 |
+
$ python convert_demo.py --prompt "A girl ridding a bike." # Using with OpenAI's API
|
12 |
+
"""
|
13 |
+
|
14 |
+
import argparse
|
15 |
+
|
16 |
+
from openai import OpenAI
|
17 |
+
|
18 |
+
|
19 |
+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
20 |
+
|
21 |
+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
22 |
+
There are a few rules to follow:
|
23 |
+
|
24 |
+
You will only ever output a single video description per user request.
|
25 |
+
|
26 |
+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
27 |
+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
28 |
+
|
29 |
+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
30 |
+
"""
|
31 |
+
|
32 |
+
|
33 |
+
def convert_prompt(prompt: str, retry_times: int = 3):
|
34 |
+
"""
|
35 |
+
Convert a prompt to a format that can be used by the model for inference
|
36 |
+
"""
|
37 |
+
|
38 |
+
client = OpenAI()
|
39 |
+
text = prompt.strip()
|
40 |
+
|
41 |
+
for i in range(retry_times):
|
42 |
+
response = client.chat.completions.create(
|
43 |
+
messages=[
|
44 |
+
{"role": "system", "content": f"{sys_prompt}"},
|
45 |
+
{
|
46 |
+
"role": "user",
|
47 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " a girl is on the beach"',
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"role": "assistant",
|
51 |
+
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"role": "user",
|
55 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A man jogging on a football field"',
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"role": "assistant",
|
59 |
+
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"role": "user",
|
63 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"role": "assistant",
|
67 |
+
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"role": "user",
|
71 |
+
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
|
72 |
+
},
|
73 |
+
],
|
74 |
+
model="glm-4-0520", # glm-4-0520 and gpt-4o have be tested
|
75 |
+
temperature=0.01,
|
76 |
+
top_p=0.7,
|
77 |
+
stream=False,
|
78 |
+
max_tokens=250,
|
79 |
+
)
|
80 |
+
if response.choices:
|
81 |
+
return response.choices[0].message.content
|
82 |
+
return prompt
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
parser = argparse.ArgumentParser()
|
87 |
+
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
|
88 |
+
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
|
89 |
+
args = parser.parse_args()
|
90 |
+
|
91 |
+
converted_prompt = convert_prompt(args.prompt, args.retry_times)
|
92 |
+
print(converted_prompt)
|
inference/web_demo.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script is used to create a Streamlit web application for generating videos using the CogVideoX model.
|
3 |
+
|
4 |
+
Run the script using Streamlit:
|
5 |
+
$ export OPENAI_API_KEY=your OpenAI Key or ZhiupAI Key
|
6 |
+
$ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/ # using with ZhipuAI, Not using this when using OpenAI
|
7 |
+
$ streamlit run web_demo.py
|
8 |
+
"""
|
9 |
+
|
10 |
+
import base64
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
import time
|
14 |
+
from datetime import datetime
|
15 |
+
from typing import List
|
16 |
+
|
17 |
+
import imageio
|
18 |
+
import numpy as np
|
19 |
+
import streamlit as st
|
20 |
+
import torch
|
21 |
+
from convert_demo import convert_prompt
|
22 |
+
from diffusers import CogVideoXPipeline
|
23 |
+
|
24 |
+
|
25 |
+
model_path: str = "THUDM/CogVideoX-2b"
|
26 |
+
|
27 |
+
|
28 |
+
# Load the model at the start
|
29 |
+
@st.cache_resource
|
30 |
+
def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
|
31 |
+
"""
|
32 |
+
Load the CogVideoX model.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
- model_path (str): Path to the model.
|
36 |
+
- dtype (torch.dtype): Data type for model.
|
37 |
+
- device (str): Device to load the model on.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
- CogVideoXPipeline: Loaded model pipeline.
|
41 |
+
"""
|
42 |
+
return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
43 |
+
|
44 |
+
|
45 |
+
# Define a function to generate video based on the provided prompt and model path
|
46 |
+
def generate_video(
|
47 |
+
pipe: CogVideoXPipeline,
|
48 |
+
prompt: str,
|
49 |
+
num_inference_steps: int = 50,
|
50 |
+
guidance_scale: float = 6.0,
|
51 |
+
num_videos_per_prompt: int = 1,
|
52 |
+
device: str = "cuda",
|
53 |
+
dtype: torch.dtype = torch.float16,
|
54 |
+
) -> List[np.ndarray]:
|
55 |
+
"""
|
56 |
+
Generate a video based on the provided prompt and model path.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
- pipe (CogVideoXPipeline): The pipeline for generating videos.
|
60 |
+
- prompt (str): Text prompt for video generation.
|
61 |
+
- num_inference_steps (int): Number of inference steps.
|
62 |
+
- guidance_scale (float): Guidance scale for generation.
|
63 |
+
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
64 |
+
- device (str): Device to run the generation on.
|
65 |
+
- dtype (torch.dtype): Data type for the model.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
- List[np.ndarray]: Generated video frames.
|
69 |
+
"""
|
70 |
+
prompt_embeds, _ = pipe.encode_prompt(
|
71 |
+
prompt=prompt,
|
72 |
+
negative_prompt=None,
|
73 |
+
do_classifier_free_guidance=True,
|
74 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
75 |
+
max_sequence_length=226,
|
76 |
+
device=device,
|
77 |
+
dtype=dtype,
|
78 |
+
)
|
79 |
+
|
80 |
+
# Generate video
|
81 |
+
video = pipe(
|
82 |
+
num_inference_steps=num_inference_steps,
|
83 |
+
guidance_scale=guidance_scale,
|
84 |
+
prompt_embeds=prompt_embeds,
|
85 |
+
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
|
86 |
+
).frames[0]
|
87 |
+
return video
|
88 |
+
|
89 |
+
|
90 |
+
def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
|
91 |
+
"""
|
92 |
+
Save the generated video to a file.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
- video (List[np.ndarray]): Video frames.
|
96 |
+
- path (str): Path to save the video.
|
97 |
+
- fps (int): Frames per second for the video.
|
98 |
+
"""
|
99 |
+
# Remove the first frame
|
100 |
+
video = video[1:]
|
101 |
+
|
102 |
+
writer = imageio.get_writer(path, fps=fps, codec="libx264")
|
103 |
+
for frame in video:
|
104 |
+
np_frame = np.array(frame)
|
105 |
+
writer.append_data(np_frame)
|
106 |
+
|
107 |
+
writer.close()
|
108 |
+
|
109 |
+
|
110 |
+
def save_metadata(
|
111 |
+
prompt: str,
|
112 |
+
converted_prompt: str,
|
113 |
+
num_inference_steps: int,
|
114 |
+
guidance_scale: float,
|
115 |
+
num_videos_per_prompt: int,
|
116 |
+
path: str,
|
117 |
+
) -> None:
|
118 |
+
"""
|
119 |
+
Save metadata to a JSON file.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
- prompt (str): Original prompt.
|
123 |
+
- converted_prompt (str): Converted prompt.
|
124 |
+
- num_inference_steps (int): Number of inference steps.
|
125 |
+
- guidance_scale (float): Guidance scale.
|
126 |
+
- num_videos_per_prompt (int): Number of videos per prompt.
|
127 |
+
- path (str): Path to save the metadata.
|
128 |
+
"""
|
129 |
+
metadata = {
|
130 |
+
"prompt": prompt,
|
131 |
+
"converted_prompt": converted_prompt,
|
132 |
+
"num_inference_steps": num_inference_steps,
|
133 |
+
"guidance_scale": guidance_scale,
|
134 |
+
"num_videos_per_prompt": num_videos_per_prompt,
|
135 |
+
}
|
136 |
+
with open(path, "w") as f:
|
137 |
+
json.dump(metadata, f, indent=4)
|
138 |
+
|
139 |
+
|
140 |
+
def main() -> None:
|
141 |
+
"""
|
142 |
+
Main function to run the Streamlit web application.
|
143 |
+
"""
|
144 |
+
st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
|
145 |
+
st.write("# CogVideoX 🎥")
|
146 |
+
dtype: torch.dtype = torch.float16
|
147 |
+
device: str = "cuda"
|
148 |
+
|
149 |
+
global pipe
|
150 |
+
pipe = load_model(model_path, dtype, device)
|
151 |
+
|
152 |
+
with st.sidebar:
|
153 |
+
st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️")
|
154 |
+
num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
|
155 |
+
guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
|
156 |
+
num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
|
157 |
+
|
158 |
+
share_links_container = st.empty()
|
159 |
+
|
160 |
+
prompt: str = st.chat_input("Prompt")
|
161 |
+
|
162 |
+
if prompt:
|
163 |
+
# Not Necessary, Suggestions
|
164 |
+
with st.spinner("Refining prompts..."):
|
165 |
+
converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
|
166 |
+
if converted_prompt is None:
|
167 |
+
st.error("Failed to Refining the prompt, Using origin one.")
|
168 |
+
|
169 |
+
st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}")
|
170 |
+
torch.cuda.empty_cache()
|
171 |
+
|
172 |
+
with st.spinner("Generating Video..."):
|
173 |
+
start_time = time.time()
|
174 |
+
video_paths = []
|
175 |
+
|
176 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
177 |
+
output_dir = f"./output/{timestamp}"
|
178 |
+
os.makedirs(output_dir, exist_ok=True)
|
179 |
+
|
180 |
+
metadata_path = os.path.join(output_dir, "config.json")
|
181 |
+
save_metadata(
|
182 |
+
prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
|
183 |
+
)
|
184 |
+
|
185 |
+
for i in range(num_videos_per_prompt):
|
186 |
+
video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
|
187 |
+
|
188 |
+
video = generate_video(
|
189 |
+
pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
|
190 |
+
)
|
191 |
+
save_video(video, video_path, fps=8)
|
192 |
+
video_paths.append(video_path)
|
193 |
+
with open(video_path, "rb") as video_file:
|
194 |
+
video_bytes: bytes = video_file.read()
|
195 |
+
st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
|
196 |
+
torch.cuda.empty_cache()
|
197 |
+
|
198 |
+
used_time: float = time.time() - start_time
|
199 |
+
st.success(f"Videos generated in {used_time:.2f} seconds.")
|
200 |
+
|
201 |
+
# Create download links in the sidebar
|
202 |
+
with share_links_container:
|
203 |
+
st.sidebar.write("### Download Links:")
|
204 |
+
for video_path in video_paths:
|
205 |
+
video_name = os.path.basename(video_path)
|
206 |
+
with open(video_path, "rb") as f:
|
207 |
+
video_bytes: bytes = f.read()
|
208 |
+
b64_video = base64.b64encode(video_bytes).decode()
|
209 |
+
href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
|
210 |
+
st.sidebar.markdown(href, unsafe_allow_html=True)
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
main()
|
pyproject.toml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.ruff]
|
2 |
+
line-length = 119
|
3 |
+
|
4 |
+
[tool.ruff.lint]
|
5 |
+
# Never enforce `E501` (line length violations).
|
6 |
+
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
7 |
+
select = ["C", "E", "F", "I", "W"]
|
8 |
+
|
9 |
+
# Ignore import violations in all `__init__.py` files.
|
10 |
+
[tool.ruff.lint.per-file-ignores]
|
11 |
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
12 |
+
|
13 |
+
[tool.ruff.lint.isort]
|
14 |
+
lines-after-imports = 2
|
15 |
+
|
16 |
+
[tool.ruff.format]
|
17 |
+
# Like Black, use double quotes for strings.
|
18 |
+
quote-style = "double"
|
19 |
+
|
20 |
+
# Like Black, indent with spaces, rather than tabs.
|
21 |
+
indent-style = "space"
|
22 |
+
|
23 |
+
# Like Black, respect magic trailing commas.
|
24 |
+
skip-magic-trailing-comma = false
|
25 |
+
|
26 |
+
# Like Black, automatically detect the appropriate line ending.
|
27 |
+
line-ending = "auto"
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/diffusers.git@d1c575ad7ee0390c2735f50cc59a79aae666567a#egg=diffusers
|
2 |
+
torch==2.4.0
|
3 |
+
torchvision==0.19.0
|
4 |
+
streamlit==1.37.0
|
5 |
+
opencv-python
|
6 |
+
imageio-ffmpeg==0.5.1
|
7 |
+
openai==1.38.0
|
8 |
+
transformers==4.43.4
|
9 |
+
accelerate==0.33.0
|
10 |
+
sentencepiece==0.2.0
|
11 |
+
pillow==9.5.0
|
resources/CogVideoX.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25ba30aafcd9604178c6d7adbd17f2bf1b251f3d29d1d29498e576075cb67c4e
|
3 |
+
size 31028426
|
resources/WECHAT.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src=wechat.jpg width="60%"/>
|
3 |
+
|
4 |
+
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
|
5 |
+
<p> Scan the QR code to follow the official account and join the "CogVLM Discussion Group" </p>
|
6 |
+
</div>
|
7 |
+
|
resources/contribute.md
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contribution Guide
|
2 |
+
|
3 |
+
There may still be many incomplete aspects in this project.
|
4 |
+
|
5 |
+
We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
|
6 |
+
and are willing to submit a PR and share it with the community, upon review, we
|
7 |
+
will acknowledge your contribution on the project homepage.
|
8 |
+
|
9 |
+
## Model Algorithms
|
10 |
+
|
11 |
+
- Support for model quantization inference (Int4, Int8, etc. quantization engineering)
|
12 |
+
- Support for multi-card inference / model inference concurrency engineering
|
13 |
+
- Support for non-CUDA architecture inference devices
|
14 |
+
|
15 |
+
## Model Engineering / Secondary Development
|
16 |
+
|
17 |
+
- Model fine-tuning examples / best prompt practices
|
18 |
+
- Video super-resolution/frame interpolation for enhancing video generation quality.
|
19 |
+
- Any peripheral tools for the model
|
20 |
+
- Any minimal complete open-source projects using the CogVideoX open-source model
|
21 |
+
|
22 |
+
## Code Standards
|
23 |
+
|
24 |
+
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
|
25 |
+
style. You can organize the code according to the following specifications:
|
26 |
+
|
27 |
+
1. Install the `ruff` tool
|
28 |
+
|
29 |
+
```shell
|
30 |
+
pip install ruff
|
31 |
+
```
|
32 |
+
|
33 |
+
Then, run the `ruff` tool
|
34 |
+
|
35 |
+
```shell
|
36 |
+
ruff check tools sat inference
|
37 |
+
```
|
38 |
+
|
39 |
+
Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
|
40 |
+
|
41 |
+
```shell
|
42 |
+
ruff format tools sat inference
|
43 |
+
```
|
44 |
+
|
45 |
+
Once your code meets the standard, there should be no errors.
|
46 |
+
|
47 |
+
## Naming Conventions
|
48 |
+
1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
|
49 |
+
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
|
50 |
+
|
resources/contribute_zh.md
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 贡献指南
|
2 |
+
|
3 |
+
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
|
4 |
+
|
5 |
+
## 模型算法
|
6 |
+
|
7 |
+
- 模型量化推理支持 (Int4,Int8等量化工程)
|
8 |
+
- 模型多卡推理支持 / 模型推理并发工程
|
9 |
+
- 非 CUDA 架构 推理设备支持
|
10 |
+
|
11 |
+
## 模型工程 / 模型二次开发
|
12 |
+
|
13 |
+
- 模型微调示例 / 最佳提示词实践
|
14 |
+
- 视频超分/插帧,用于美化视频生成效果。
|
15 |
+
- 任何模型周边工具
|
16 |
+
- 任何使用CogVideoX开源模型制作的最小完整开源项目
|
17 |
+
|
18 |
+
## 代码规范
|
19 |
+
|
20 |
+
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
|
21 |
+
|
22 |
+
1. 安装`ruff`工具
|
23 |
+
|
24 |
+
```shell
|
25 |
+
pip install ruff
|
26 |
+
```
|
27 |
+
|
28 |
+
接着,运行`ruff`工具
|
29 |
+
|
30 |
+
```shell
|
31 |
+
ruff check tools sat inference
|
32 |
+
```
|
33 |
+
|
34 |
+
检查代码风格,如果有问题,您可以通过`ruff formate`命令自动修复。
|
35 |
+
|
36 |
+
```shell
|
37 |
+
ruff formate tools sat inference
|
38 |
+
```
|
39 |
+
|
40 |
+
如果您的代码符合规范,应该不会出现任何的错误。
|
41 |
+
|
42 |
+
## 命名规范
|
43 |
+
|
44 |
+
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
|
45 |
+
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
|
resources/logo.svg
ADDED
resources/videos/1.mp4
ADDED
Binary file (636 kB). View file
|
|
resources/videos/2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e738926f262a28b3b9af6573987905457ea82cdcadb0ec04ad9ab134324f5cc
|
3 |
+
size 1683616
|
resources/videos/3.mp4
ADDED
Binary file (746 kB). View file
|
|
resources/videos/4.mp4
ADDED
Binary file (233 kB). View file
|
|
resources/web_demo.png
ADDED
Git LFS Details
|
resources/wechat.jpg
ADDED
sat/README.md
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SAT CogVideoX-2B
|
2 |
+
|
3 |
+
This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
|
4 |
+
fine-tuning code for SAT weights.
|
5 |
+
|
6 |
+
This code is the framework used by the team to train the model. It has few comments and requires careful study.
|
7 |
+
|
8 |
+
## Inference Model
|
9 |
+
|
10 |
+
1. Ensure that you have correctly installed the dependencies required by this folder.
|
11 |
+
|
12 |
+
```shell
|
13 |
+
pip install -r requirements.txt
|
14 |
+
```
|
15 |
+
|
16 |
+
2. Download the model weights
|
17 |
+
|
18 |
+
First, go to the SAT mirror to download the dependencies.
|
19 |
+
|
20 |
+
```shell
|
21 |
+
mkdir CogVideoX-2b-sat
|
22 |
+
cd CogVideoX-2b-sat
|
23 |
+
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
24 |
+
mv 'index.html?dl=1' vae.zip
|
25 |
+
unzip vae.zip
|
26 |
+
wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
|
27 |
+
mv 'index.html?dl=1' transformer.zip
|
28 |
+
unzip transformer.zip
|
29 |
+
```
|
30 |
+
|
31 |
+
Then unzip, the model structure should look like this:
|
32 |
+
|
33 |
+
```
|
34 |
+
.
|
35 |
+
├── transformer
|
36 |
+
│ ├── 1000
|
37 |
+
│ │ └── mp_rank_00_model_states.pt
|
38 |
+
│ └── latest
|
39 |
+
└── vae
|
40 |
+
└── 3d-vae.pt
|
41 |
+
```
|
42 |
+
|
43 |
+
Next, clone the T5 model, which is not used for training and fine-tuning, but must be used.
|
44 |
+
|
45 |
+
```shell
|
46 |
+
git lfs install
|
47 |
+
git clone https://huggingface.co/google/t5-v1_1-xxl.git
|
48 |
+
```
|
49 |
+
|
50 |
+
**We don't need the tf_model.h5** file. This file can be deleted.
|
51 |
+
|
52 |
+
3. Modify the file `configs/cogvideox_2b_infer.yaml`.
|
53 |
+
|
54 |
+
```yaml
|
55 |
+
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
|
56 |
+
|
57 |
+
conditioner_config:
|
58 |
+
target: sgm.modules.GeneralConditioner
|
59 |
+
params:
|
60 |
+
emb_models:
|
61 |
+
- is_trainable: false
|
62 |
+
input_key: txt
|
63 |
+
ucg_rate: 0.1
|
64 |
+
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
65 |
+
params:
|
66 |
+
model_dir: "google/t5-v1_1-xxl" ## T5 model path
|
67 |
+
max_length: 226
|
68 |
+
|
69 |
+
first_stage_config:
|
70 |
+
target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
|
71 |
+
params:
|
72 |
+
cp_size: 1
|
73 |
+
ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
|
74 |
+
```
|
75 |
+
|
76 |
+
+ If using txt to save multiple prompts, please refer to `configs/test.txt` for modification. One prompt per line. If
|
77 |
+
you don't know how to write prompts, you can first use [this code](../inference/convert_demo.py) to call LLM for
|
78 |
+
refinement.
|
79 |
+
+ If using the command line as input, modify
|
80 |
+
|
81 |
+
```yaml
|
82 |
+
input_type: cli
|
83 |
+
```
|
84 |
+
|
85 |
+
so that prompts can be entered from the command line.
|
86 |
+
|
87 |
+
If you want to change the output video directory, you can modify:
|
88 |
+
|
89 |
+
```yaml
|
90 |
+
output_dir: outputs/
|
91 |
+
```
|
92 |
+
|
93 |
+
The default is saved in the `.outputs/` folder.
|
94 |
+
|
95 |
+
4. Run the inference code to start inference
|
96 |
+
|
97 |
+
```shell
|
98 |
+
bash inference.sh
|
99 |
+
```
|
100 |
+
|
101 |
+
## Fine-Tuning the Model
|
102 |
+
|
103 |
+
### Preparing the Dataset
|
104 |
+
|
105 |
+
The dataset format should be as follows:
|
106 |
+
|
107 |
+
```
|
108 |
+
.
|
109 |
+
├── labels
|
110 |
+
│ ├── 1.txt
|
111 |
+
│ ├── 2.txt
|
112 |
+
│ ├── ...
|
113 |
+
└── videos
|
114 |
+
├── 1.mp4
|
115 |
+
├── 2.mp4
|
116 |
+
├── ...
|
117 |
+
```
|
118 |
+
|
119 |
+
Each txt file should have the same name as its corresponding video file and contain the labels for that video. Each
|
120 |
+
video should have a one-to-one correspondence with a label. Typically, a video should not have multiple labels.
|
121 |
+
|
122 |
+
For style fine-tuning, please prepare at least 50 videos and labels with similar styles to facilitate fitting.
|
123 |
+
|
124 |
+
### Modifying the Configuration File
|
125 |
+
|
126 |
+
We support both `Lora` and `full-parameter fine-tuning` methods. Please note that both fine-tuning methods only apply to the `transformer` part. The `VAE part` is not modified. `T5` is only used as an Encoder.
|
127 |
+
|
128 |
+
the `configs/cogvideox_2b_sft.yaml` (for full fine-tuning) as follows.
|
129 |
+
|
130 |
+
```yaml
|
131 |
+
# checkpoint_activations: True ## using gradient checkpointing (both checkpoint_activations in the configuration file need to be set to True)
|
132 |
+
model_parallel_size: 1 # Model parallel size
|
133 |
+
experiment_name: lora-disney # Experiment name (do not change)
|
134 |
+
mode: finetune # Mode (do not change)
|
135 |
+
load: "{your_CogVideoX-2b-sat_path}/transformer" # Transformer model path
|
136 |
+
no_load_rng: True # Whether to load the random seed
|
137 |
+
train_iters: 1000 # Number of training iterations
|
138 |
+
eval_iters: 1 # Number of evaluation iterations
|
139 |
+
eval_interval: 100 # Evaluation interval
|
140 |
+
eval_batch_size: 1 # Batch size for evaluation
|
141 |
+
save: ckpts # Model save path
|
142 |
+
save_interval: 100 # Model save interval
|
143 |
+
log_interval: 20 # Log output interval
|
144 |
+
train_data: [ "your train data path" ]
|
145 |
+
valid_data: [ "your val data path" ] # Training and validation sets can be the same
|
146 |
+
split: 1,0,0 # Ratio of training, validation, and test sets
|
147 |
+
num_workers: 8 # Number of worker threads for data loading
|
148 |
+
```
|
149 |
+
|
150 |
+
If you wish to use Lora fine-tuning, you also need to modify:
|
151 |
+
|
152 |
+
```yaml
|
153 |
+
model:
|
154 |
+
scale_factor: 1.15258426
|
155 |
+
disable_first_stage_autocast: true
|
156 |
+
not_trainable_prefixes: [ 'all' ] ## Uncomment
|
157 |
+
log_keys:
|
158 |
+
- txt'
|
159 |
+
|
160 |
+
lora_config: ## Uncomment
|
161 |
+
target: sat.model.finetune.lora2.LoraMixin
|
162 |
+
params:
|
163 |
+
r: 256
|
164 |
+
```
|
165 |
+
|
166 |
+
### Fine-Tuning and Validation
|
167 |
+
|
168 |
+
1. Run the inference code to start fine-tuning.
|
169 |
+
|
170 |
+
```shell
|
171 |
+
bash finetune.sh
|
172 |
+
```
|
173 |
+
|
174 |
+
### Converting to Huggingface Diffusers Supported Weights
|
175 |
+
|
176 |
+
The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
|
177 |
+
|
178 |
+
```shell
|
179 |
+
python ../tools/convert_weight_sat2hf.py
|
180 |
+
```
|
181 |
+
|
182 |
+
**Note**: This content has not yet been tested with LORA fine-tuning models.
|
sat/README_zh.md
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SAT CogVideoX-2B
|
2 |
+
|
3 |
+
本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。
|
4 |
+
|
5 |
+
该代码是团队训练模型时使用的框架。注释较少,需要认真研究。
|
6 |
+
|
7 |
+
## 推理模型
|
8 |
+
|
9 |
+
1. 确保你已经正确安装本文件夹中的要求的依赖
|
10 |
+
|
11 |
+
```shell
|
12 |
+
pip install -r requirements.txt
|
13 |
+
```
|
14 |
+
|
15 |
+
2. 下载模型权重
|
16 |
+
|
17 |
+
首先,前往 SAT 镜像下载依赖。
|
18 |
+
|
19 |
+
```shell
|
20 |
+
mkdir CogVideoX-2b-sat
|
21 |
+
cd CogVideoX-2b-sat
|
22 |
+
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
23 |
+
mv 'index.html?dl=1' vae.zip
|
24 |
+
unzip vae.zip
|
25 |
+
wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
|
26 |
+
mv 'index.html?dl=1' transformer.zip
|
27 |
+
unzip transformer.zip
|
28 |
+
```
|
29 |
+
|
30 |
+
然后,解压文件,模型结构应该如下
|
31 |
+
|
32 |
+
```
|
33 |
+
.
|
34 |
+
├── transformer
|
35 |
+
│ ├── 1000
|
36 |
+
│ │ └── mp_rank_00_model_states.pt
|
37 |
+
│ └── latest
|
38 |
+
└── vae
|
39 |
+
└── 3d-vae.pt
|
40 |
+
```
|
41 |
+
|
42 |
+
接着,克隆 T5 模型,该模型不用做训练和微调,但是必须使用。
|
43 |
+
|
44 |
+
```shell
|
45 |
+
git lfs install
|
46 |
+
git clone https://huggingface.co/google/t5-v1_1-xxl.git
|
47 |
+
```
|
48 |
+
|
49 |
+
**我们不需要使用tf_model.h5**文件。该文件可以删除。
|
50 |
+
|
51 |
+
3. 修改`configs/cogvideox_2b_infer.yaml`中的文件。
|
52 |
+
|
53 |
+
```yaml
|
54 |
+
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
|
55 |
+
|
56 |
+
conditioner_config:
|
57 |
+
target: sgm.modules.GeneralConditioner
|
58 |
+
params:
|
59 |
+
emb_models:
|
60 |
+
- is_trainable: false
|
61 |
+
input_key: txt
|
62 |
+
ucg_rate: 0.1
|
63 |
+
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
64 |
+
params:
|
65 |
+
model_dir: "google/t5-v1_1-xxl" ## T5 模型路径
|
66 |
+
max_length: 226
|
67 |
+
|
68 |
+
first_stage_config:
|
69 |
+
target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
|
70 |
+
params:
|
71 |
+
cp_size: 1
|
72 |
+
ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE 模型路径
|
73 |
+
|
74 |
+
```
|
75 |
+
|
76 |
+
+ 如果使用 txt 保存多个提示词,请参考`configs/test.txt`
|
77 |
+
进行修改。每一行一个提示词。如果您不知道如何书写提示词,可以先使用[此代码](../inference/convert_demo.py)调用 LLM进行润色。
|
78 |
+
+ 如果使用命令行作为输入,请修改
|
79 |
+
|
80 |
+
```yaml
|
81 |
+
input_type: cli
|
82 |
+
```
|
83 |
+
|
84 |
+
这样就可以从命令行输入提示词。
|
85 |
+
|
86 |
+
如果你希望修改输出视频的地址,你可以修改:
|
87 |
+
|
88 |
+
```yaml
|
89 |
+
output_dir: outputs/
|
90 |
+
```
|
91 |
+
|
92 |
+
默认保存在`.outputs/`文件夹下。
|
93 |
+
|
94 |
+
4. 运行推理代码,即可推理
|
95 |
+
|
96 |
+
```shell
|
97 |
+
bash inference.sh
|
98 |
+
```
|
99 |
+
|
100 |
+
## 微调模型
|
101 |
+
|
102 |
+
### 准备数据集
|
103 |
+
|
104 |
+
数据集格式应该如下:
|
105 |
+
|
106 |
+
```
|
107 |
+
.
|
108 |
+
├── labels
|
109 |
+
│ ├── 1.txt
|
110 |
+
│ ├── 2.txt
|
111 |
+
│ ├── ...
|
112 |
+
└── videos
|
113 |
+
├── 1.mp4
|
114 |
+
├── 2.mp4
|
115 |
+
├── ...
|
116 |
+
```
|
117 |
+
|
118 |
+
每个 txt 与视频同名,为视频的标签。视频与标签应该一一对应。通常情况下,不使用一个视频对应多个标签。
|
119 |
+
|
120 |
+
如果为风格微调,清准备至少50条风格相似的视频和标签,以利于拟合。
|
121 |
+
|
122 |
+
### 修改配置文件
|
123 |
+
|
124 |
+
我们支持 `Lora` 和 全参数微调两种方式。请注意,两种微调方式都仅仅对 `transformer` 部分进行微调。不改动 `VAE` 部分。`T5`仅作为
|
125 |
+
Encoder 使用。
|
126 |
+
部分。 请按照以下方式修改`configs/cogvideox_2b_sft.yaml`(全量微调) 中的文件。
|
127 |
+
|
128 |
+
```yaml
|
129 |
+
# checkpoint_activations: True ## using gradient checkpointing (配置文件中的两个checkpoint_activations都需要设置为True)
|
130 |
+
model_parallel_size: 1 # 模型并行大小
|
131 |
+
experiment_name: lora-disney # 实验名称(不要改动)
|
132 |
+
mode: finetune # 模式(不要改动)
|
133 |
+
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
|
134 |
+
no_load_rng: True # 是否加载随机数种子
|
135 |
+
train_iters: 1000 # 训练迭代次数
|
136 |
+
eval_iters: 1 # 验证迭代次数
|
137 |
+
eval_interval: 100 # 验证间隔
|
138 |
+
eval_batch_size: 1 # 验证集 batch size
|
139 |
+
save: ckpts # 模型保存路径
|
140 |
+
save_interval: 100 # 模型保存间隔
|
141 |
+
log_interval: 20 # 日志输出间隔
|
142 |
+
train_data: [ "your train data path" ]
|
143 |
+
valid_data: [ "your val data path" ] # 训练集和验证集可以相同
|
144 |
+
split: 1,0,0 # 训练集,验证集,测试集比例
|
145 |
+
num_workers: 8 # 数据加载器的工作线程数
|
146 |
+
```
|
147 |
+
|
148 |
+
如果你希望使用 Lora 微调,你还需要修改:
|
149 |
+
|
150 |
+
```yaml
|
151 |
+
model:
|
152 |
+
scale_factor: 1.15258426
|
153 |
+
disable_first_stage_autocast: true
|
154 |
+
not_trainable_prefixes: [ 'all' ] ## 解除注释
|
155 |
+
log_keys:
|
156 |
+
- txt'
|
157 |
+
|
158 |
+
lora_config: ## 解除注释
|
159 |
+
target: sat.model.finetune.lora2.LoraMixin
|
160 |
+
params:
|
161 |
+
r: 256
|
162 |
+
```
|
163 |
+
|
164 |
+
### 微调和验证
|
165 |
+
|
166 |
+
1. 运行推理代码,即可开始微调。
|
167 |
+
|
168 |
+
```shell
|
169 |
+
bash finetune.sh
|
170 |
+
```
|
171 |
+
|
172 |
+
### 转换到 Huggingface Diffusers 库支持的权重
|
173 |
+
|
174 |
+
SAT 权重格式与 Huggingface 的权重格式不同,需要转换。请运行
|
175 |
+
|
176 |
+
```shell
|
177 |
+
python ../tools/convert_weight_sat2hf.py
|
178 |
+
```
|
179 |
+
|
180 |
+
**注意** 本内容暂未测试 LORA 微调模型。
|
sat/arguments.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import json
|
5 |
+
import warnings
|
6 |
+
import omegaconf
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from sat.helpers import print_rank0
|
9 |
+
from sat import mpu
|
10 |
+
from sat.arguments import set_random_seed
|
11 |
+
from sat.arguments import add_training_args, add_evaluation_args, add_data_args
|
12 |
+
import torch.distributed
|
13 |
+
|
14 |
+
|
15 |
+
def add_model_config_args(parser):
|
16 |
+
"""Model arguments"""
|
17 |
+
|
18 |
+
group = parser.add_argument_group("model", "model configuration")
|
19 |
+
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
|
20 |
+
group.add_argument(
|
21 |
+
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
|
22 |
+
)
|
23 |
+
group.add_argument("--force-pretrain", action="store_true")
|
24 |
+
group.add_argument("--device", type=int, default=-1)
|
25 |
+
group.add_argument("--debug", action="store_true")
|
26 |
+
group.add_argument("--log-image", type=bool, default=True)
|
27 |
+
|
28 |
+
return parser
|
29 |
+
|
30 |
+
|
31 |
+
def add_sampling_config_args(parser):
|
32 |
+
"""Sampling configurations"""
|
33 |
+
|
34 |
+
group = parser.add_argument_group("sampling", "Sampling Configurations")
|
35 |
+
group.add_argument("--output-dir", type=str, default="samples")
|
36 |
+
group.add_argument("--input-dir", type=str, default=None)
|
37 |
+
group.add_argument("--input-type", type=str, default="cli")
|
38 |
+
group.add_argument("--input-file", type=str, default="input.txt")
|
39 |
+
group.add_argument("--final-size", type=int, default=2048)
|
40 |
+
group.add_argument("--sdedit", action="store_true")
|
41 |
+
group.add_argument("--grid-num-rows", type=int, default=1)
|
42 |
+
group.add_argument("--force-inference", action="store_true")
|
43 |
+
group.add_argument("--lcm_steps", type=int, default=None)
|
44 |
+
group.add_argument("--sampling-num-frames", type=int, default=32)
|
45 |
+
group.add_argument("--sampling-fps", type=int, default=8)
|
46 |
+
group.add_argument("--only-save-latents", type=bool, default=False)
|
47 |
+
group.add_argument("--only-log-video-latents", type=bool, default=False)
|
48 |
+
group.add_argument("--latent-channels", type=int, default=32)
|
49 |
+
group.add_argument("--image2video", action="store_true")
|
50 |
+
|
51 |
+
return parser
|
52 |
+
|
53 |
+
|
54 |
+
def get_args(args_list=None, parser=None):
|
55 |
+
"""Parse all the args."""
|
56 |
+
if parser is None:
|
57 |
+
parser = argparse.ArgumentParser(description="sat")
|
58 |
+
else:
|
59 |
+
assert isinstance(parser, argparse.ArgumentParser)
|
60 |
+
parser = add_model_config_args(parser)
|
61 |
+
parser = add_sampling_config_args(parser)
|
62 |
+
parser = add_training_args(parser)
|
63 |
+
parser = add_evaluation_args(parser)
|
64 |
+
parser = add_data_args(parser)
|
65 |
+
|
66 |
+
import deepspeed
|
67 |
+
|
68 |
+
parser = deepspeed.add_config_arguments(parser)
|
69 |
+
|
70 |
+
args = parser.parse_args(args_list)
|
71 |
+
args = process_config_to_args(args)
|
72 |
+
|
73 |
+
if not args.train_data:
|
74 |
+
print_rank0("No training data specified", level="WARNING")
|
75 |
+
|
76 |
+
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
|
77 |
+
if args.train_iters is None and args.epochs is None:
|
78 |
+
args.train_iters = 10000 # default 10k iters
|
79 |
+
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
|
80 |
+
|
81 |
+
args.cuda = torch.cuda.is_available()
|
82 |
+
|
83 |
+
args.rank = int(os.getenv("RANK", "0"))
|
84 |
+
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
85 |
+
if args.local_rank is None:
|
86 |
+
args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
|
87 |
+
|
88 |
+
if args.device == -1:
|
89 |
+
if torch.cuda.device_count() == 0:
|
90 |
+
args.device = "cpu"
|
91 |
+
elif args.local_rank is not None:
|
92 |
+
args.device = args.local_rank
|
93 |
+
else:
|
94 |
+
args.device = args.rank % torch.cuda.device_count()
|
95 |
+
|
96 |
+
if args.local_rank != args.device and args.mode != "inference":
|
97 |
+
raise ValueError(
|
98 |
+
"LOCAL_RANK (default 0) and args.device inconsistent. "
|
99 |
+
"This can only happens in inference mode. "
|
100 |
+
"Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
|
101 |
+
)
|
102 |
+
|
103 |
+
if args.rank == 0:
|
104 |
+
print_rank0("using world size: {}".format(args.world_size))
|
105 |
+
|
106 |
+
if args.train_data_weights is not None:
|
107 |
+
assert len(args.train_data_weights) == len(args.train_data)
|
108 |
+
|
109 |
+
if args.mode != "inference": # training with deepspeed
|
110 |
+
args.deepspeed = True
|
111 |
+
if args.deepspeed_config is None: # not specified
|
112 |
+
deepspeed_config_path = os.path.join(
|
113 |
+
os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
|
114 |
+
)
|
115 |
+
with open(deepspeed_config_path) as file:
|
116 |
+
args.deepspeed_config = json.load(file)
|
117 |
+
override_deepspeed_config = True
|
118 |
+
else:
|
119 |
+
override_deepspeed_config = False
|
120 |
+
|
121 |
+
assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
|
122 |
+
|
123 |
+
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
|
124 |
+
print_rank0("Automatically set fp16=True to use ZeRO.")
|
125 |
+
args.fp16 = True
|
126 |
+
args.bf16 = False
|
127 |
+
|
128 |
+
if args.deepspeed:
|
129 |
+
if args.checkpoint_activations:
|
130 |
+
args.deepspeed_activation_checkpointing = True
|
131 |
+
else:
|
132 |
+
args.deepspeed_activation_checkpointing = False
|
133 |
+
if args.deepspeed_config is not None:
|
134 |
+
deepspeed_config = args.deepspeed_config
|
135 |
+
|
136 |
+
if override_deepspeed_config: # not specify deepspeed_config, use args
|
137 |
+
if args.fp16:
|
138 |
+
deepspeed_config["fp16"]["enabled"] = True
|
139 |
+
elif args.bf16:
|
140 |
+
deepspeed_config["bf16"]["enabled"] = True
|
141 |
+
deepspeed_config["fp16"]["enabled"] = False
|
142 |
+
else:
|
143 |
+
deepspeed_config["fp16"]["enabled"] = False
|
144 |
+
deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
|
145 |
+
deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
|
146 |
+
optimizer_params_config = deepspeed_config["optimizer"]["params"]
|
147 |
+
optimizer_params_config["lr"] = args.lr
|
148 |
+
optimizer_params_config["weight_decay"] = args.weight_decay
|
149 |
+
else: # override args with values in deepspeed_config
|
150 |
+
if args.rank == 0:
|
151 |
+
print_rank0("Will override arguments with manually specified deepspeed_config!")
|
152 |
+
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
|
153 |
+
args.fp16 = True
|
154 |
+
else:
|
155 |
+
args.fp16 = False
|
156 |
+
if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
|
157 |
+
args.bf16 = True
|
158 |
+
else:
|
159 |
+
args.bf16 = False
|
160 |
+
if "train_micro_batch_size_per_gpu" in deepspeed_config:
|
161 |
+
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
|
162 |
+
if "gradient_accumulation_steps" in deepspeed_config:
|
163 |
+
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
|
164 |
+
else:
|
165 |
+
args.gradient_accumulation_steps = None
|
166 |
+
if "optimizer" in deepspeed_config:
|
167 |
+
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
|
168 |
+
args.lr = optimizer_params_config.get("lr", args.lr)
|
169 |
+
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
|
170 |
+
args.deepspeed_config = deepspeed_config
|
171 |
+
|
172 |
+
# initialize distributed and random seed because it always seems to be necessary.
|
173 |
+
initialize_distributed(args)
|
174 |
+
args.seed = args.seed + mpu.get_data_parallel_rank()
|
175 |
+
set_random_seed(args.seed)
|
176 |
+
return args
|
177 |
+
|
178 |
+
|
179 |
+
def initialize_distributed(args):
|
180 |
+
"""Initialize torch.distributed."""
|
181 |
+
if torch.distributed.is_initialized():
|
182 |
+
if mpu.model_parallel_is_initialized():
|
183 |
+
if args.model_parallel_size != mpu.get_model_parallel_world_size():
|
184 |
+
raise ValueError(
|
185 |
+
"model_parallel_size is inconsistent with prior configuration."
|
186 |
+
"We currently do not support changing model_parallel_size."
|
187 |
+
)
|
188 |
+
return False
|
189 |
+
else:
|
190 |
+
if args.model_parallel_size > 1:
|
191 |
+
warnings.warn(
|
192 |
+
"model_parallel_size > 1 but torch.distributed is not initialized via SAT."
|
193 |
+
"Please carefully make sure the correctness on your own."
|
194 |
+
)
|
195 |
+
mpu.initialize_model_parallel(args.model_parallel_size)
|
196 |
+
return True
|
197 |
+
# the automatic assignment of devices has been moved to arguments.py
|
198 |
+
if args.device == "cpu":
|
199 |
+
pass
|
200 |
+
else:
|
201 |
+
torch.cuda.set_device(args.device)
|
202 |
+
# Call the init process
|
203 |
+
init_method = "tcp://"
|
204 |
+
args.master_ip = os.getenv("MASTER_ADDR", "localhost")
|
205 |
+
|
206 |
+
if args.world_size == 1:
|
207 |
+
from sat.helpers import get_free_port
|
208 |
+
|
209 |
+
default_master_port = str(get_free_port())
|
210 |
+
else:
|
211 |
+
default_master_port = "6000"
|
212 |
+
args.master_port = os.getenv("MASTER_PORT", default_master_port)
|
213 |
+
init_method += args.master_ip + ":" + args.master_port
|
214 |
+
torch.distributed.init_process_group(
|
215 |
+
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
216 |
+
)
|
217 |
+
|
218 |
+
# Set the model-parallel / data-parallel communicators.
|
219 |
+
mpu.initialize_model_parallel(args.model_parallel_size)
|
220 |
+
|
221 |
+
# Set vae context parallel group equal to model parallel group
|
222 |
+
from sgm.util import set_context_parallel_group, initialize_context_parallel
|
223 |
+
|
224 |
+
if args.model_parallel_size <= 2:
|
225 |
+
set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
|
226 |
+
else:
|
227 |
+
initialize_context_parallel(2)
|
228 |
+
# mpu.initialize_model_parallel(1)
|
229 |
+
# Optional DeepSpeed Activation Checkpointing Features
|
230 |
+
if args.deepspeed:
|
231 |
+
import deepspeed
|
232 |
+
|
233 |
+
deepspeed.init_distributed(
|
234 |
+
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
235 |
+
)
|
236 |
+
# # It seems that it has no negative influence to configure it even without using checkpointing.
|
237 |
+
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
|
238 |
+
else:
|
239 |
+
# in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
|
240 |
+
try:
|
241 |
+
import deepspeed
|
242 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import (
|
243 |
+
_CUDA_RNG_STATE_TRACKER,
|
244 |
+
_MODEL_PARALLEL_RNG_TRACKER_NAME,
|
245 |
+
)
|
246 |
+
|
247 |
+
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
|
248 |
+
except Exception as e:
|
249 |
+
from sat.helpers import print_rank0
|
250 |
+
|
251 |
+
print_rank0(str(e), level="DEBUG")
|
252 |
+
|
253 |
+
return True
|
254 |
+
|
255 |
+
|
256 |
+
def process_config_to_args(args):
|
257 |
+
"""Fetch args from only --base"""
|
258 |
+
|
259 |
+
configs = [OmegaConf.load(cfg) for cfg in args.base]
|
260 |
+
config = OmegaConf.merge(*configs)
|
261 |
+
|
262 |
+
args_config = config.pop("args", OmegaConf.create())
|
263 |
+
for key in args_config:
|
264 |
+
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
|
265 |
+
arg = OmegaConf.to_object(args_config[key])
|
266 |
+
else:
|
267 |
+
arg = args_config[key]
|
268 |
+
if hasattr(args, key):
|
269 |
+
setattr(args, key, arg)
|
270 |
+
|
271 |
+
if "model" in config:
|
272 |
+
model_config = config.pop("model", OmegaConf.create())
|
273 |
+
args.model_config = model_config
|
274 |
+
if "deepspeed" in config:
|
275 |
+
deepspeed_config = config.pop("deepspeed", OmegaConf.create())
|
276 |
+
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
|
277 |
+
if "data" in config:
|
278 |
+
data_config = config.pop("data", OmegaConf.create())
|
279 |
+
args.data_config = data_config
|
280 |
+
|
281 |
+
return args
|
sat/configs/cogvideox_2b_infer.yaml
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
latent_channels: 16
|
3 |
+
mode: inference
|
4 |
+
load: "CogVideoX-2b-sat/transformer"
|
5 |
+
batch_size: 1
|
6 |
+
input_type: txt
|
7 |
+
input_file: test.txt
|
8 |
+
sampling_num_frames: 13 # Must be 13, 11 or 9
|
9 |
+
sampling_fps: 8
|
10 |
+
fp16: True
|
11 |
+
output_dir: outputs/
|
12 |
+
force_inference: True
|
13 |
+
|
14 |
+
model:
|
15 |
+
scale_factor: 1.15258426
|
16 |
+
disable_first_stage_autocast: true
|
17 |
+
log_keys:
|
18 |
+
- txt
|
19 |
+
|
20 |
+
denoiser_config:
|
21 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
22 |
+
params:
|
23 |
+
num_idx: 1000
|
24 |
+
quantize_c_noise: False
|
25 |
+
|
26 |
+
weighting_config:
|
27 |
+
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
28 |
+
scaling_config:
|
29 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
30 |
+
discretization_config:
|
31 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
32 |
+
params:
|
33 |
+
shift_scale: 3.0
|
34 |
+
|
35 |
+
network_config:
|
36 |
+
target: dit_video_concat.DiffusionTransformer
|
37 |
+
params:
|
38 |
+
time_embed_dim: 512
|
39 |
+
elementwise_affine: True
|
40 |
+
num_frames: 49
|
41 |
+
time_compressed_rate: 4
|
42 |
+
latent_width: 90
|
43 |
+
latent_height: 60
|
44 |
+
num_layers: 30
|
45 |
+
patch_size: 2
|
46 |
+
in_channels: 16
|
47 |
+
out_channels: 16
|
48 |
+
hidden_size: 1920
|
49 |
+
adm_in_channels: 256
|
50 |
+
num_attention_heads: 30
|
51 |
+
|
52 |
+
transformer_args:
|
53 |
+
vocab_size: 1
|
54 |
+
max_sequence_length: 64
|
55 |
+
layernorm_order: pre
|
56 |
+
skip_init: false
|
57 |
+
model_parallel_size: 1
|
58 |
+
is_decoder: false
|
59 |
+
|
60 |
+
modules:
|
61 |
+
pos_embed_config:
|
62 |
+
target: dit_video_concat.Basic3DPositionEmbeddingMixin
|
63 |
+
params:
|
64 |
+
text_length: 226
|
65 |
+
height_interpolation: 1.875
|
66 |
+
width_interpolation: 1.875
|
67 |
+
|
68 |
+
patch_embed_config:
|
69 |
+
target: dit_video_concat.ImagePatchEmbeddingMixin
|
70 |
+
params:
|
71 |
+
text_hidden_size: 4096
|
72 |
+
|
73 |
+
adaln_layer_config:
|
74 |
+
target: dit_video_concat.AdaLNMixin
|
75 |
+
params:
|
76 |
+
qk_ln: True
|
77 |
+
|
78 |
+
final_layer_config:
|
79 |
+
target: dit_video_concat.FinalLayerMixin
|
80 |
+
|
81 |
+
conditioner_config:
|
82 |
+
target: sgm.modules.GeneralConditioner
|
83 |
+
params:
|
84 |
+
emb_models:
|
85 |
+
- is_trainable: false
|
86 |
+
input_key: txt
|
87 |
+
ucg_rate: 0.1
|
88 |
+
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
89 |
+
params:
|
90 |
+
model_dir: "google/t5-v1_1-xxl"
|
91 |
+
max_length: 226
|
92 |
+
|
93 |
+
first_stage_config:
|
94 |
+
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
95 |
+
params:
|
96 |
+
cp_size: 1
|
97 |
+
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
|
98 |
+
ignore_keys: [ 'loss' ]
|
99 |
+
|
100 |
+
loss_config:
|
101 |
+
target: torch.nn.Identity
|
102 |
+
|
103 |
+
regularizer_config:
|
104 |
+
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
105 |
+
|
106 |
+
encoder_config:
|
107 |
+
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
108 |
+
params:
|
109 |
+
double_z: true
|
110 |
+
z_channels: 16
|
111 |
+
resolution: 256
|
112 |
+
in_channels: 3
|
113 |
+
out_ch: 3
|
114 |
+
ch: 128
|
115 |
+
ch_mult: [ 1, 2, 2, 4 ]
|
116 |
+
attn_resolutions: [ ]
|
117 |
+
num_res_blocks: 3
|
118 |
+
dropout: 0.0
|
119 |
+
gather_norm: True
|
120 |
+
|
121 |
+
decoder_config:
|
122 |
+
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
123 |
+
params:
|
124 |
+
double_z: True
|
125 |
+
z_channels: 16
|
126 |
+
resolution: 256
|
127 |
+
in_channels: 3
|
128 |
+
out_ch: 3
|
129 |
+
ch: 128
|
130 |
+
ch_mult: [ 1, 2, 2, 4 ]
|
131 |
+
attn_resolutions: [ ]
|
132 |
+
num_res_blocks: 3
|
133 |
+
dropout: 0.0
|
134 |
+
gather_norm: false
|
135 |
+
|
136 |
+
loss_fn_config:
|
137 |
+
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
138 |
+
params:
|
139 |
+
offset_noise_level: 0
|
140 |
+
sigma_sampler_config:
|
141 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
142 |
+
params:
|
143 |
+
uniform_sampling: True
|
144 |
+
num_idx: 1000
|
145 |
+
discretization_config:
|
146 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
147 |
+
params:
|
148 |
+
shift_scale: 3.0
|
149 |
+
|
150 |
+
sampler_config:
|
151 |
+
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
152 |
+
params:
|
153 |
+
num_steps: 50
|
154 |
+
verbose: True
|
155 |
+
|
156 |
+
discretization_config:
|
157 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
158 |
+
params:
|
159 |
+
shift_scale: 3.0
|
160 |
+
|
161 |
+
guider_config:
|
162 |
+
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
163 |
+
params:
|
164 |
+
scale: 6
|
165 |
+
exp: 5
|
166 |
+
num_steps: 50
|
sat/configs/cogvideox_2b_sft.yaml
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
checkpoint_activations: True ## using gradient checkpointing
|
3 |
+
model_parallel_size: 1
|
4 |
+
experiment_name: lora-disney
|
5 |
+
mode: finetune
|
6 |
+
load: "CogVideoX-2b-sat/transformer"
|
7 |
+
no_load_rng: True
|
8 |
+
train_iters: 1000
|
9 |
+
eval_iters: 1
|
10 |
+
eval_interval: 100
|
11 |
+
eval_batch_size: 1
|
12 |
+
save: ckpts
|
13 |
+
save_interval: 100
|
14 |
+
log_interval: 20
|
15 |
+
train_data: ["disney"]
|
16 |
+
valid_data: ["disney"]
|
17 |
+
split: 1,0,0
|
18 |
+
num_workers: 8
|
19 |
+
force_train: True
|
20 |
+
only_log_video_latents: True
|
21 |
+
|
22 |
+
data:
|
23 |
+
target: data_video.SFTDataset
|
24 |
+
params:
|
25 |
+
video_size: [480, 720]
|
26 |
+
fps: 8
|
27 |
+
max_num_frames: 49
|
28 |
+
skip_frms_num: 3.
|
29 |
+
|
30 |
+
deepspeed:
|
31 |
+
train_micro_batch_size_per_gpu: 1
|
32 |
+
gradient_accumulation_steps: 1
|
33 |
+
steps_per_print: 50
|
34 |
+
gradient_clipping: 0.1
|
35 |
+
zero_optimization:
|
36 |
+
stage: 2
|
37 |
+
cpu_offload: false
|
38 |
+
contiguous_gradients: false
|
39 |
+
overlap_comm: true
|
40 |
+
reduce_scatter: true
|
41 |
+
reduce_bucket_size: 1000000000
|
42 |
+
allgather_bucket_size: 1000000000
|
43 |
+
load_from_fp32_weights: false
|
44 |
+
zero_allow_untested_optimizer: true
|
45 |
+
bf16:
|
46 |
+
enabled: False
|
47 |
+
fp16:
|
48 |
+
enabled: True
|
49 |
+
loss_scale: 0
|
50 |
+
loss_scale_window: 400
|
51 |
+
hysteresis: 2
|
52 |
+
min_loss_scale: 1
|
53 |
+
optimizer:
|
54 |
+
type: sat.ops.FusedEmaAdam
|
55 |
+
params:
|
56 |
+
lr: 0.0002
|
57 |
+
betas: [0.9, 0.95]
|
58 |
+
eps: 1e-8
|
59 |
+
weight_decay: 1e-4
|
60 |
+
activation_checkpointing:
|
61 |
+
partition_activations: false
|
62 |
+
contiguous_memory_optimization: false
|
63 |
+
wall_clock_breakdown: false
|
64 |
+
|
65 |
+
|
66 |
+
model:
|
67 |
+
scale_factor: 1.15258426
|
68 |
+
disable_first_stage_autocast: true
|
69 |
+
not_trainable_prefixes: ['all'] ## Using Lora
|
70 |
+
log_keys:
|
71 |
+
- txt
|
72 |
+
|
73 |
+
denoiser_config:
|
74 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
75 |
+
params:
|
76 |
+
num_idx: 1000
|
77 |
+
quantize_c_noise: False
|
78 |
+
|
79 |
+
weighting_config:
|
80 |
+
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
81 |
+
scaling_config:
|
82 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
83 |
+
discretization_config:
|
84 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
85 |
+
params:
|
86 |
+
shift_scale: 3.0
|
87 |
+
|
88 |
+
network_config:
|
89 |
+
target: dit_video_concat.DiffusionTransformer
|
90 |
+
params:
|
91 |
+
time_embed_dim: 512
|
92 |
+
elementwise_affine: True
|
93 |
+
num_frames: 49
|
94 |
+
time_compressed_rate: 4
|
95 |
+
latent_width: 90
|
96 |
+
latent_height: 60
|
97 |
+
num_layers: 30
|
98 |
+
patch_size: 2
|
99 |
+
in_channels: 16
|
100 |
+
out_channels: 16
|
101 |
+
hidden_size: 1920
|
102 |
+
adm_in_channels: 256
|
103 |
+
num_attention_heads: 30
|
104 |
+
|
105 |
+
transformer_args:
|
106 |
+
checkpoint_activations: True ## using gradient checkpointing
|
107 |
+
vocab_size: 1
|
108 |
+
max_sequence_length: 64
|
109 |
+
layernorm_order: pre
|
110 |
+
skip_init: false
|
111 |
+
model_parallel_size: 1
|
112 |
+
is_decoder: false
|
113 |
+
|
114 |
+
modules:
|
115 |
+
pos_embed_config:
|
116 |
+
target: dit_video_concat.Basic3DPositionEmbeddingMixin
|
117 |
+
params:
|
118 |
+
text_length: 226
|
119 |
+
height_interpolation: 1.875
|
120 |
+
width_interpolation: 1.875
|
121 |
+
|
122 |
+
lora_config: ## Using Lora
|
123 |
+
target: sat.model.finetune.lora2.LoraMixin
|
124 |
+
params:
|
125 |
+
r: 128
|
126 |
+
|
127 |
+
patch_embed_config:
|
128 |
+
target: dit_video_concat.ImagePatchEmbeddingMixin
|
129 |
+
params:
|
130 |
+
text_hidden_size: 4096
|
131 |
+
|
132 |
+
adaln_layer_config:
|
133 |
+
target: dit_video_concat.AdaLNMixin
|
134 |
+
params:
|
135 |
+
qk_ln: True
|
136 |
+
|
137 |
+
final_layer_config:
|
138 |
+
target: dit_video_concat.FinalLayerMixin
|
139 |
+
|
140 |
+
conditioner_config:
|
141 |
+
target: sgm.modules.GeneralConditioner
|
142 |
+
params:
|
143 |
+
emb_models:
|
144 |
+
- is_trainable: false
|
145 |
+
input_key: txt
|
146 |
+
ucg_rate: 0.1
|
147 |
+
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
148 |
+
params:
|
149 |
+
model_dir: "google/t5-v1_1-xxl"
|
150 |
+
max_length: 226
|
151 |
+
|
152 |
+
first_stage_config:
|
153 |
+
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
154 |
+
params:
|
155 |
+
cp_size: 1
|
156 |
+
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
|
157 |
+
ignore_keys: [ 'loss' ]
|
158 |
+
|
159 |
+
loss_config:
|
160 |
+
target: torch.nn.Identity
|
161 |
+
|
162 |
+
regularizer_config:
|
163 |
+
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
164 |
+
|
165 |
+
encoder_config:
|
166 |
+
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
167 |
+
params:
|
168 |
+
double_z: true
|
169 |
+
z_channels: 16
|
170 |
+
resolution: 256
|
171 |
+
in_channels: 3
|
172 |
+
out_ch: 3
|
173 |
+
ch: 128
|
174 |
+
ch_mult: [ 1, 2, 2, 4 ]
|
175 |
+
attn_resolutions: [ ]
|
176 |
+
num_res_blocks: 3
|
177 |
+
dropout: 0.0
|
178 |
+
gather_norm: True
|
179 |
+
|
180 |
+
decoder_config:
|
181 |
+
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
182 |
+
params:
|
183 |
+
double_z: True
|
184 |
+
z_channels: 16
|
185 |
+
resolution: 256
|
186 |
+
in_channels: 3
|
187 |
+
out_ch: 3
|
188 |
+
ch: 128
|
189 |
+
ch_mult: [ 1, 2, 2, 4 ]
|
190 |
+
attn_resolutions: [ ]
|
191 |
+
num_res_blocks: 3
|
192 |
+
dropout: 0.0
|
193 |
+
gather_norm: false
|
194 |
+
|
195 |
+
loss_fn_config:
|
196 |
+
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
197 |
+
params:
|
198 |
+
offset_noise_level: 0
|
199 |
+
sigma_sampler_config:
|
200 |
+
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
201 |
+
params:
|
202 |
+
uniform_sampling: True
|
203 |
+
num_idx: 1000
|
204 |
+
discretization_config:
|
205 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
206 |
+
params:
|
207 |
+
shift_scale: 3.0
|
208 |
+
|
209 |
+
sampler_config:
|
210 |
+
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
211 |
+
params:
|
212 |
+
num_steps: 50
|
213 |
+
verbose: True
|
214 |
+
|
215 |
+
discretization_config:
|
216 |
+
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
217 |
+
params:
|
218 |
+
shift_scale: 3.0
|
219 |
+
|
220 |
+
guider_config:
|
221 |
+
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
222 |
+
params:
|
223 |
+
scale: 6
|
224 |
+
exp: 5
|
225 |
+
num_steps: 50
|
sat/configs/test.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
|
2 |
+
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
|
3 |
+
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
|
sat/data_video.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from functools import partial
|
5 |
+
import math
|
6 |
+
import torchvision.transforms as TT
|
7 |
+
from sgm.webds import MetaDistributedWebDataset
|
8 |
+
import random
|
9 |
+
from fractions import Fraction
|
10 |
+
from typing import Union, Optional, Dict, Any, Tuple
|
11 |
+
from torchvision.io.video import av
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from torchvision.io import _video_opt
|
15 |
+
from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
|
16 |
+
from torchvision.transforms.functional import center_crop, resize
|
17 |
+
from torchvision.transforms import InterpolationMode
|
18 |
+
import decord
|
19 |
+
from decord import VideoReader
|
20 |
+
from torch.utils.data import Dataset
|
21 |
+
|
22 |
+
|
23 |
+
def read_video(
|
24 |
+
filename: str,
|
25 |
+
start_pts: Union[float, Fraction] = 0,
|
26 |
+
end_pts: Optional[Union[float, Fraction]] = None,
|
27 |
+
pts_unit: str = "pts",
|
28 |
+
output_format: str = "THWC",
|
29 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
30 |
+
"""
|
31 |
+
Reads a video from a file, returning both the video frames and the audio frames
|
32 |
+
|
33 |
+
Args:
|
34 |
+
filename (str): path to the video file
|
35 |
+
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
36 |
+
The start presentation time of the video
|
37 |
+
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
38 |
+
The end presentation time
|
39 |
+
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
|
40 |
+
either 'pts' or 'sec'. Defaults to 'pts'.
|
41 |
+
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
|
45 |
+
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
46 |
+
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
47 |
+
"""
|
48 |
+
|
49 |
+
output_format = output_format.upper()
|
50 |
+
if output_format not in ("THWC", "TCHW"):
|
51 |
+
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
52 |
+
|
53 |
+
_check_av_available()
|
54 |
+
|
55 |
+
if end_pts is None:
|
56 |
+
end_pts = float("inf")
|
57 |
+
|
58 |
+
if end_pts < start_pts:
|
59 |
+
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
60 |
+
|
61 |
+
info = {}
|
62 |
+
audio_frames = []
|
63 |
+
audio_timebase = _video_opt.default_timebase
|
64 |
+
|
65 |
+
with av.open(filename, metadata_errors="ignore") as container:
|
66 |
+
if container.streams.audio:
|
67 |
+
audio_timebase = container.streams.audio[0].time_base
|
68 |
+
if container.streams.video:
|
69 |
+
video_frames = _read_from_stream(
|
70 |
+
container,
|
71 |
+
start_pts,
|
72 |
+
end_pts,
|
73 |
+
pts_unit,
|
74 |
+
container.streams.video[0],
|
75 |
+
{"video": 0},
|
76 |
+
)
|
77 |
+
video_fps = container.streams.video[0].average_rate
|
78 |
+
# guard against potentially corrupted files
|
79 |
+
if video_fps is not None:
|
80 |
+
info["video_fps"] = float(video_fps)
|
81 |
+
|
82 |
+
if container.streams.audio:
|
83 |
+
audio_frames = _read_from_stream(
|
84 |
+
container,
|
85 |
+
start_pts,
|
86 |
+
end_pts,
|
87 |
+
pts_unit,
|
88 |
+
container.streams.audio[0],
|
89 |
+
{"audio": 0},
|
90 |
+
)
|
91 |
+
info["audio_fps"] = container.streams.audio[0].rate
|
92 |
+
|
93 |
+
aframes_list = [frame.to_ndarray() for frame in audio_frames]
|
94 |
+
|
95 |
+
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
|
96 |
+
|
97 |
+
if aframes_list:
|
98 |
+
aframes = np.concatenate(aframes_list, 1)
|
99 |
+
aframes = torch.as_tensor(aframes)
|
100 |
+
if pts_unit == "sec":
|
101 |
+
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
|
102 |
+
if end_pts != float("inf"):
|
103 |
+
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
|
104 |
+
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
|
105 |
+
else:
|
106 |
+
aframes = torch.empty((1, 0), dtype=torch.float32)
|
107 |
+
|
108 |
+
if output_format == "TCHW":
|
109 |
+
# [T,H,W,C] --> [T,C,H,W]
|
110 |
+
vframes = vframes.permute(0, 3, 1, 2)
|
111 |
+
|
112 |
+
return vframes, aframes, info
|
113 |
+
|
114 |
+
|
115 |
+
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
116 |
+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
117 |
+
arr = resize(
|
118 |
+
arr,
|
119 |
+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
120 |
+
interpolation=InterpolationMode.BICUBIC,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
arr = resize(
|
124 |
+
arr,
|
125 |
+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
126 |
+
interpolation=InterpolationMode.BICUBIC,
|
127 |
+
)
|
128 |
+
|
129 |
+
h, w = arr.shape[2], arr.shape[3]
|
130 |
+
arr = arr.squeeze(0)
|
131 |
+
|
132 |
+
delta_h = h - image_size[0]
|
133 |
+
delta_w = w - image_size[1]
|
134 |
+
|
135 |
+
if reshape_mode == "random" or reshape_mode == "none":
|
136 |
+
top = np.random.randint(0, delta_h + 1)
|
137 |
+
left = np.random.randint(0, delta_w + 1)
|
138 |
+
elif reshape_mode == "center":
|
139 |
+
top, left = delta_h // 2, delta_w // 2
|
140 |
+
else:
|
141 |
+
raise NotImplementedError
|
142 |
+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
143 |
+
return arr
|
144 |
+
|
145 |
+
|
146 |
+
def pad_last_frame(tensor, num_frames):
|
147 |
+
# T, H, W, C
|
148 |
+
if tensor.shape[0] < num_frames:
|
149 |
+
last_frame = tensor[-int(num_frames - tensor.shape[1]) :]
|
150 |
+
padded_tensor = torch.cat([tensor, last_frame], dim=0)
|
151 |
+
return padded_tensor
|
152 |
+
else:
|
153 |
+
return tensor[:num_frames]
|
154 |
+
|
155 |
+
|
156 |
+
def load_video(
|
157 |
+
video_data,
|
158 |
+
sampling="uniform",
|
159 |
+
duration=None,
|
160 |
+
num_frames=4,
|
161 |
+
wanted_fps=None,
|
162 |
+
actual_fps=None,
|
163 |
+
skip_frms_num=0.0,
|
164 |
+
nb_read_frames=None,
|
165 |
+
):
|
166 |
+
decord.bridge.set_bridge("torch")
|
167 |
+
vr = VideoReader(uri=video_data, height=-1, width=-1)
|
168 |
+
if nb_read_frames is not None:
|
169 |
+
ori_vlen = nb_read_frames
|
170 |
+
else:
|
171 |
+
ori_vlen = min(int(duration * actual_fps) - 1, len(vr))
|
172 |
+
|
173 |
+
max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
|
174 |
+
start = random.randint(skip_frms_num, max_seek + 1)
|
175 |
+
end = int(start + num_frames / wanted_fps * actual_fps)
|
176 |
+
n_frms = num_frames
|
177 |
+
|
178 |
+
if sampling == "uniform":
|
179 |
+
indices = np.arange(start, end, (end - start) / n_frms).astype(int)
|
180 |
+
else:
|
181 |
+
raise NotImplementedError
|
182 |
+
|
183 |
+
# get_batch -> T, H, W, C
|
184 |
+
temp_frms = vr.get_batch(np.arange(start, end))
|
185 |
+
assert temp_frms is not None
|
186 |
+
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
187 |
+
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
188 |
+
|
189 |
+
return pad_last_frame(tensor_frms, num_frames)
|
190 |
+
|
191 |
+
|
192 |
+
import threading
|
193 |
+
|
194 |
+
|
195 |
+
def load_video_with_timeout(*args, **kwargs):
|
196 |
+
video_container = {}
|
197 |
+
|
198 |
+
def target_function():
|
199 |
+
video = load_video(*args, **kwargs)
|
200 |
+
video_container["video"] = video
|
201 |
+
|
202 |
+
thread = threading.Thread(target=target_function)
|
203 |
+
thread.start()
|
204 |
+
timeout = 20
|
205 |
+
thread.join(timeout)
|
206 |
+
|
207 |
+
if thread.is_alive():
|
208 |
+
print("Loading video timed out")
|
209 |
+
raise TimeoutError
|
210 |
+
return video_container.get("video", None).contiguous()
|
211 |
+
|
212 |
+
|
213 |
+
def process_video(
|
214 |
+
video_path,
|
215 |
+
image_size=None,
|
216 |
+
duration=None,
|
217 |
+
num_frames=4,
|
218 |
+
wanted_fps=None,
|
219 |
+
actual_fps=None,
|
220 |
+
skip_frms_num=0.0,
|
221 |
+
nb_read_frames=None,
|
222 |
+
):
|
223 |
+
"""
|
224 |
+
video_path: str or io.BytesIO
|
225 |
+
image_size: .
|
226 |
+
duration: preknow the duration to speed up by seeking to sampled start. TODO by_pass if unknown.
|
227 |
+
num_frames: wanted num_frames.
|
228 |
+
wanted_fps: .
|
229 |
+
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
230 |
+
"""
|
231 |
+
|
232 |
+
video = load_video_with_timeout(
|
233 |
+
video_path,
|
234 |
+
duration=duration,
|
235 |
+
num_frames=num_frames,
|
236 |
+
wanted_fps=wanted_fps,
|
237 |
+
actual_fps=actual_fps,
|
238 |
+
skip_frms_num=skip_frms_num,
|
239 |
+
nb_read_frames=nb_read_frames,
|
240 |
+
)
|
241 |
+
|
242 |
+
# --- copy and modify the image process ---
|
243 |
+
video = video.permute(0, 3, 1, 2) # [T, C, H, W]
|
244 |
+
|
245 |
+
# resize
|
246 |
+
if image_size is not None:
|
247 |
+
video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")
|
248 |
+
|
249 |
+
return video
|
250 |
+
|
251 |
+
|
252 |
+
def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
|
253 |
+
while True:
|
254 |
+
r = next(src)
|
255 |
+
if "mp4" in r:
|
256 |
+
video_data = r["mp4"]
|
257 |
+
elif "avi" in r:
|
258 |
+
video_data = r["avi"]
|
259 |
+
else:
|
260 |
+
print("No video data found")
|
261 |
+
continue
|
262 |
+
|
263 |
+
if txt_key not in r:
|
264 |
+
txt = ""
|
265 |
+
else:
|
266 |
+
txt = r[txt_key]
|
267 |
+
|
268 |
+
if isinstance(txt, bytes):
|
269 |
+
txt = txt.decode("utf-8")
|
270 |
+
else:
|
271 |
+
txt = str(txt)
|
272 |
+
|
273 |
+
duration = r.get("duration", None)
|
274 |
+
if duration is not None:
|
275 |
+
duration = float(duration)
|
276 |
+
else:
|
277 |
+
continue
|
278 |
+
|
279 |
+
actual_fps = r.get("fps", None)
|
280 |
+
if actual_fps is not None:
|
281 |
+
actual_fps = float(actual_fps)
|
282 |
+
else:
|
283 |
+
continue
|
284 |
+
|
285 |
+
required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
|
286 |
+
required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
|
287 |
+
|
288 |
+
if duration is not None and duration < required_duration:
|
289 |
+
continue
|
290 |
+
|
291 |
+
try:
|
292 |
+
frames = process_video(
|
293 |
+
io.BytesIO(video_data),
|
294 |
+
num_frames=num_frames,
|
295 |
+
wanted_fps=fps,
|
296 |
+
image_size=image_size,
|
297 |
+
duration=duration,
|
298 |
+
actual_fps=actual_fps,
|
299 |
+
skip_frms_num=skip_frms_num,
|
300 |
+
)
|
301 |
+
frames = (frames - 127.5) / 127.5
|
302 |
+
except Exception as e:
|
303 |
+
print(e)
|
304 |
+
continue
|
305 |
+
|
306 |
+
item = {
|
307 |
+
"mp4": frames,
|
308 |
+
"txt": txt,
|
309 |
+
"num_frames": num_frames,
|
310 |
+
"fps": fps,
|
311 |
+
}
|
312 |
+
|
313 |
+
yield item
|
314 |
+
|
315 |
+
|
316 |
+
class VideoDataset(MetaDistributedWebDataset):
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
path,
|
320 |
+
image_size,
|
321 |
+
num_frames,
|
322 |
+
fps,
|
323 |
+
skip_frms_num=0.0,
|
324 |
+
nshards=sys.maxsize,
|
325 |
+
seed=1,
|
326 |
+
meta_names=None,
|
327 |
+
shuffle_buffer=1000,
|
328 |
+
include_dirs=None,
|
329 |
+
txt_key="caption",
|
330 |
+
**kwargs,
|
331 |
+
):
|
332 |
+
if seed == -1:
|
333 |
+
seed = random.randint(0, 1000000)
|
334 |
+
if meta_names is None:
|
335 |
+
meta_names = []
|
336 |
+
|
337 |
+
if path.startswith(";"):
|
338 |
+
path, include_dirs = path.split(";", 1)
|
339 |
+
super().__init__(
|
340 |
+
path,
|
341 |
+
partial(
|
342 |
+
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
|
343 |
+
),
|
344 |
+
seed,
|
345 |
+
meta_names=meta_names,
|
346 |
+
shuffle_buffer=shuffle_buffer,
|
347 |
+
nshards=nshards,
|
348 |
+
include_dirs=include_dirs,
|
349 |
+
)
|
350 |
+
|
351 |
+
@classmethod
|
352 |
+
def create_dataset_function(cls, path, args, **kwargs):
|
353 |
+
return cls(path, **kwargs)
|
354 |
+
|
355 |
+
|
356 |
+
class SFTDataset(Dataset):
|
357 |
+
def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
|
358 |
+
"""
|
359 |
+
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
360 |
+
"""
|
361 |
+
super(SFTDataset, self).__init__()
|
362 |
+
|
363 |
+
self.videos_list = []
|
364 |
+
self.captions_list = []
|
365 |
+
self.num_frames_list = []
|
366 |
+
self.fps_list = []
|
367 |
+
|
368 |
+
decord.bridge.set_bridge("torch")
|
369 |
+
for root, dirnames, filenames in os.walk(data_dir):
|
370 |
+
for filename in filenames:
|
371 |
+
if filename.endswith(".mp4"):
|
372 |
+
video_path = os.path.join(root, filename)
|
373 |
+
vr = VideoReader(uri=video_path, height=-1, width=-1)
|
374 |
+
actual_fps = vr.get_avg_fps()
|
375 |
+
ori_vlen = len(vr)
|
376 |
+
|
377 |
+
if ori_vlen / actual_fps * fps > max_num_frames:
|
378 |
+
num_frames = max_num_frames
|
379 |
+
start = int(skip_frms_num)
|
380 |
+
end = int(start + num_frames / fps * actual_fps)
|
381 |
+
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
382 |
+
temp_frms = vr.get_batch(np.arange(start, end))
|
383 |
+
assert temp_frms is not None
|
384 |
+
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
385 |
+
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
386 |
+
else:
|
387 |
+
if ori_vlen > max_num_frames:
|
388 |
+
num_frames = max_num_frames
|
389 |
+
start = int(skip_frms_num)
|
390 |
+
end = int(ori_vlen - skip_frms_num)
|
391 |
+
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
392 |
+
temp_frms = vr.get_batch(np.arange(start, end))
|
393 |
+
assert temp_frms is not None
|
394 |
+
tensor_frms = (
|
395 |
+
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
396 |
+
)
|
397 |
+
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
398 |
+
else:
|
399 |
+
|
400 |
+
def nearest_smaller_4k_plus_1(n):
|
401 |
+
remainder = n % 4
|
402 |
+
if remainder == 0:
|
403 |
+
return n - 3
|
404 |
+
else:
|
405 |
+
return n - remainder + 1
|
406 |
+
|
407 |
+
start = int(skip_frms_num)
|
408 |
+
end = int(ori_vlen - skip_frms_num)
|
409 |
+
num_frames = nearest_smaller_4k_plus_1(
|
410 |
+
end - start
|
411 |
+
) # 3D VAE requires the number of frames to be 4k+1
|
412 |
+
end = int(start + num_frames)
|
413 |
+
temp_frms = vr.get_batch(np.arange(start, end))
|
414 |
+
assert temp_frms is not None
|
415 |
+
tensor_frms = (
|
416 |
+
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
417 |
+
)
|
418 |
+
|
419 |
+
tensor_frms = pad_last_frame(
|
420 |
+
tensor_frms, num_frames
|
421 |
+
) # the len of indices may be less than num_frames, due to round error
|
422 |
+
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
|
423 |
+
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
|
424 |
+
tensor_frms = (tensor_frms - 127.5) / 127.5
|
425 |
+
self.videos_list.append(tensor_frms)
|
426 |
+
|
427 |
+
# caption
|
428 |
+
caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
|
429 |
+
if os.path.exists(caption_path):
|
430 |
+
caption = open(caption_path, "r").read().splitlines()[0]
|
431 |
+
else:
|
432 |
+
caption = ""
|
433 |
+
self.captions_list.append(caption)
|
434 |
+
self.num_frames_list.append(num_frames)
|
435 |
+
self.fps_list.append(fps)
|
436 |
+
|
437 |
+
def __getitem__(self, index):
|
438 |
+
item = {
|
439 |
+
"mp4": self.videos_list[index],
|
440 |
+
"txt": self.captions_list[index],
|
441 |
+
"num_frames": self.num_frames_list[index],
|
442 |
+
"fps": self.fps_list[index],
|
443 |
+
}
|
444 |
+
return item
|
445 |
+
|
446 |
+
def __len__(self):
|
447 |
+
return len(self.fps_list)
|
448 |
+
|
449 |
+
@classmethod
|
450 |
+
def create_dataset_function(cls, path, args, **kwargs):
|
451 |
+
return cls(data_dir=path, **kwargs)
|
sat/diffusion_video.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Any, Dict, List, Tuple, Union, Optional
|
4 |
+
from omegaconf import ListConfig, OmegaConf
|
5 |
+
from copy import deepcopy
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from sat.helpers import print_rank0
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
from sgm.modules import UNCONDITIONAL_CONFIG
|
13 |
+
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
|
14 |
+
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
15 |
+
from sgm.util import (
|
16 |
+
default,
|
17 |
+
disabled_train,
|
18 |
+
get_obj_from_str,
|
19 |
+
instantiate_from_config,
|
20 |
+
log_txt_as_img,
|
21 |
+
)
|
22 |
+
import gc
|
23 |
+
from sat import mpu
|
24 |
+
import random
|
25 |
+
|
26 |
+
|
27 |
+
class SATVideoDiffusionEngine(nn.Module):
|
28 |
+
def __init__(self, args, **kwargs):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
model_config = args.model_config
|
32 |
+
# model args preprocess
|
33 |
+
log_keys = model_config.get("log_keys", None)
|
34 |
+
input_key = model_config.get("input_key", "mp4")
|
35 |
+
network_config = model_config.get("network_config", None)
|
36 |
+
network_wrapper = model_config.get("network_wrapper", None)
|
37 |
+
denoiser_config = model_config.get("denoiser_config", None)
|
38 |
+
sampler_config = model_config.get("sampler_config", None)
|
39 |
+
conditioner_config = model_config.get("conditioner_config", None)
|
40 |
+
first_stage_config = model_config.get("first_stage_config", None)
|
41 |
+
loss_fn_config = model_config.get("loss_fn_config", None)
|
42 |
+
scale_factor = model_config.get("scale_factor", 1.0)
|
43 |
+
latent_input = model_config.get("latent_input", False)
|
44 |
+
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
|
45 |
+
no_cond_log = model_config.get("disable_first_stage_autocast", False)
|
46 |
+
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
|
47 |
+
compile_model = model_config.get("compile_model", False)
|
48 |
+
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
|
49 |
+
lr_scale = model_config.get("lr_scale", None)
|
50 |
+
lora_train = model_config.get("lora_train", False)
|
51 |
+
self.use_pd = model_config.get("use_pd", False) # progressive distillation
|
52 |
+
|
53 |
+
self.log_keys = log_keys
|
54 |
+
self.input_key = input_key
|
55 |
+
self.not_trainable_prefixes = not_trainable_prefixes
|
56 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
57 |
+
self.lr_scale = lr_scale
|
58 |
+
self.lora_train = lora_train
|
59 |
+
self.noised_image_input = model_config.get("noised_image_input", False)
|
60 |
+
self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
|
61 |
+
self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
|
62 |
+
if args.fp16:
|
63 |
+
dtype = torch.float16
|
64 |
+
dtype_str = "fp16"
|
65 |
+
elif args.bf16:
|
66 |
+
dtype = torch.bfloat16
|
67 |
+
dtype_str = "bf16"
|
68 |
+
else:
|
69 |
+
dtype = torch.float32
|
70 |
+
dtype_str = "fp32"
|
71 |
+
self.dtype = dtype
|
72 |
+
self.dtype_str = dtype_str
|
73 |
+
|
74 |
+
network_config["params"]["dtype"] = dtype_str
|
75 |
+
model = instantiate_from_config(network_config)
|
76 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
77 |
+
model, compile_model=compile_model, dtype=dtype
|
78 |
+
)
|
79 |
+
|
80 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
81 |
+
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
|
82 |
+
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
|
83 |
+
|
84 |
+
self._init_first_stage(first_stage_config)
|
85 |
+
|
86 |
+
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
|
87 |
+
|
88 |
+
self.latent_input = latent_input
|
89 |
+
self.scale_factor = scale_factor
|
90 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
91 |
+
self.no_cond_log = no_cond_log
|
92 |
+
self.device = args.device
|
93 |
+
|
94 |
+
def disable_untrainable_params(self):
|
95 |
+
total_trainable = 0
|
96 |
+
for n, p in self.named_parameters():
|
97 |
+
if p.requires_grad == False:
|
98 |
+
continue
|
99 |
+
flag = False
|
100 |
+
for prefix in self.not_trainable_prefixes:
|
101 |
+
if n.startswith(prefix) or prefix == "all":
|
102 |
+
flag = True
|
103 |
+
break
|
104 |
+
|
105 |
+
lora_prefix = ["matrix_A", "matrix_B"]
|
106 |
+
for prefix in lora_prefix:
|
107 |
+
if prefix in n:
|
108 |
+
flag = False
|
109 |
+
break
|
110 |
+
|
111 |
+
if flag:
|
112 |
+
p.requires_grad_(False)
|
113 |
+
else:
|
114 |
+
total_trainable += p.numel()
|
115 |
+
|
116 |
+
print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
|
117 |
+
|
118 |
+
def reinit(self, parent_model=None):
|
119 |
+
# reload the initial params from previous trained modules
|
120 |
+
# you can also get access to other mixins through parent_model.get_mixin().
|
121 |
+
pass
|
122 |
+
|
123 |
+
def _init_first_stage(self, config):
|
124 |
+
model = instantiate_from_config(config).eval()
|
125 |
+
model.train = disabled_train
|
126 |
+
for param in model.parameters():
|
127 |
+
param.requires_grad = False
|
128 |
+
self.first_stage_model = model
|
129 |
+
|
130 |
+
def forward(self, x, batch):
|
131 |
+
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
132 |
+
loss_mean = loss.mean()
|
133 |
+
loss_dict = {"loss": loss_mean}
|
134 |
+
return loss_mean, loss_dict
|
135 |
+
|
136 |
+
def shared_step(self, batch: Dict) -> Any:
|
137 |
+
x = self.get_input(batch)
|
138 |
+
if self.lr_scale is not None:
|
139 |
+
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
|
140 |
+
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
|
141 |
+
lr_z = self.encode_first_stage(lr_x, batch)
|
142 |
+
batch["lr_input"] = lr_z
|
143 |
+
|
144 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
145 |
+
x = self.encode_first_stage(x, batch)
|
146 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
147 |
+
|
148 |
+
gc.collect()
|
149 |
+
torch.cuda.empty_cache()
|
150 |
+
loss, loss_dict = self(x, batch)
|
151 |
+
return loss, loss_dict
|
152 |
+
|
153 |
+
def get_input(self, batch):
|
154 |
+
return batch[self.input_key].to(self.dtype)
|
155 |
+
|
156 |
+
@torch.no_grad()
|
157 |
+
def decode_first_stage(self, z):
|
158 |
+
z = 1.0 / self.scale_factor * z
|
159 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
160 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
161 |
+
all_out = []
|
162 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
163 |
+
for n in range(n_rounds):
|
164 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
165 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
166 |
+
else:
|
167 |
+
kwargs = {}
|
168 |
+
use_cp = False
|
169 |
+
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
|
170 |
+
all_out.append(out)
|
171 |
+
out = torch.cat(all_out, dim=0)
|
172 |
+
return out
|
173 |
+
|
174 |
+
@torch.no_grad()
|
175 |
+
def encode_first_stage(self, x, batch):
|
176 |
+
frame = x.shape[2]
|
177 |
+
|
178 |
+
if frame > 1 and self.latent_input:
|
179 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
180 |
+
return x * self.scale_factor # already encoded
|
181 |
+
|
182 |
+
use_cp = False
|
183 |
+
|
184 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
185 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
186 |
+
all_out = []
|
187 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
188 |
+
for n in range(n_rounds):
|
189 |
+
out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
|
190 |
+
all_out.append(out)
|
191 |
+
z = torch.cat(all_out, dim=0)
|
192 |
+
z = self.scale_factor * z
|
193 |
+
return z
|
194 |
+
|
195 |
+
@torch.no_grad()
|
196 |
+
def sample(
|
197 |
+
self,
|
198 |
+
cond: Dict,
|
199 |
+
uc: Union[Dict, None] = None,
|
200 |
+
batch_size: int = 16,
|
201 |
+
shape: Union[None, Tuple, List] = None,
|
202 |
+
prefix=None,
|
203 |
+
concat_images=None,
|
204 |
+
**kwargs,
|
205 |
+
):
|
206 |
+
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
|
207 |
+
if hasattr(self, "seeded_noise"):
|
208 |
+
randn = self.seeded_noise(randn)
|
209 |
+
|
210 |
+
if prefix is not None:
|
211 |
+
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
|
212 |
+
|
213 |
+
# broadcast noise
|
214 |
+
mp_size = mpu.get_model_parallel_world_size()
|
215 |
+
if mp_size > 1:
|
216 |
+
global_rank = torch.distributed.get_rank() // mp_size
|
217 |
+
src = global_rank * mp_size
|
218 |
+
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
|
219 |
+
|
220 |
+
scale = None
|
221 |
+
scale_emb = None
|
222 |
+
|
223 |
+
denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser(
|
224 |
+
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
|
225 |
+
)
|
226 |
+
|
227 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
|
228 |
+
samples = samples.to(self.dtype)
|
229 |
+
return samples
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
233 |
+
"""
|
234 |
+
Defines heuristics to log different conditionings.
|
235 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
236 |
+
"""
|
237 |
+
image_h, image_w = batch[self.input_key].shape[3:]
|
238 |
+
log = dict()
|
239 |
+
|
240 |
+
for embedder in self.conditioner.embedders:
|
241 |
+
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
|
242 |
+
x = batch[embedder.input_key][:n]
|
243 |
+
if isinstance(x, torch.Tensor):
|
244 |
+
if x.dim() == 1:
|
245 |
+
# class-conditional, convert integer to string
|
246 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
247 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
248 |
+
elif x.dim() == 2:
|
249 |
+
# size and crop cond and the like
|
250 |
+
x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
|
251 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
252 |
+
else:
|
253 |
+
raise NotImplementedError()
|
254 |
+
elif isinstance(x, (List, ListConfig)):
|
255 |
+
if isinstance(x[0], str):
|
256 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
257 |
+
else:
|
258 |
+
raise NotImplementedError()
|
259 |
+
else:
|
260 |
+
raise NotImplementedError()
|
261 |
+
log[embedder.input_key] = xc
|
262 |
+
return log
|
263 |
+
|
264 |
+
@torch.no_grad()
|
265 |
+
def log_video(
|
266 |
+
self,
|
267 |
+
batch: Dict,
|
268 |
+
N: int = 8,
|
269 |
+
ucg_keys: List[str] = None,
|
270 |
+
only_log_video_latents=False,
|
271 |
+
**kwargs,
|
272 |
+
) -> Dict:
|
273 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
274 |
+
if ucg_keys:
|
275 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
276 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
277 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
ucg_keys = conditioner_input_keys
|
281 |
+
log = dict()
|
282 |
+
|
283 |
+
x = self.get_input(batch)
|
284 |
+
|
285 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
286 |
+
batch,
|
287 |
+
force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
|
288 |
+
)
|
289 |
+
|
290 |
+
sampling_kwargs = {}
|
291 |
+
|
292 |
+
N = min(x.shape[0], N)
|
293 |
+
x = x.to(self.device)[:N]
|
294 |
+
if not self.latent_input:
|
295 |
+
log["inputs"] = x.to(torch.float32)
|
296 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
297 |
+
z = self.encode_first_stage(x, batch)
|
298 |
+
if not only_log_video_latents:
|
299 |
+
log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
|
300 |
+
log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
|
301 |
+
z = z.permute(0, 2, 1, 3, 4).contiguous()
|
302 |
+
|
303 |
+
log.update(self.log_conditionings(batch, N))
|
304 |
+
|
305 |
+
for k in c:
|
306 |
+
if isinstance(c[k], torch.Tensor):
|
307 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
308 |
+
|
309 |
+
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
310 |
+
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
311 |
+
if only_log_video_latents:
|
312 |
+
latents = 1.0 / self.scale_factor * samples
|
313 |
+
log["latents"] = latents
|
314 |
+
else:
|
315 |
+
samples = self.decode_first_stage(samples).to(torch.float32)
|
316 |
+
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
317 |
+
log["samples"] = samples
|
318 |
+
return log
|
sat/dit_video_concat.py
ADDED
@@ -0,0 +1,858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from einops import rearrange, repeat
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from sat.model.base_model import BaseModel, non_conflict
|
10 |
+
from sat.model.mixins import BaseMixin
|
11 |
+
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
|
12 |
+
from sat.mpu.layers import ColumnParallelLinear
|
13 |
+
from sgm.util import instantiate_from_config
|
14 |
+
|
15 |
+
from sgm.modules.diffusionmodules.openaimodel import Timestep
|
16 |
+
from sgm.modules.diffusionmodules.util import (
|
17 |
+
linear,
|
18 |
+
timestep_embedding,
|
19 |
+
)
|
20 |
+
from sat.ops.layernorm import LayerNorm, RMSNorm
|
21 |
+
|
22 |
+
|
23 |
+
class ImagePatchEmbeddingMixin(BaseMixin):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
in_channels,
|
27 |
+
hidden_size,
|
28 |
+
patch_size,
|
29 |
+
bias=True,
|
30 |
+
text_hidden_size=None,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
|
34 |
+
if text_hidden_size is not None:
|
35 |
+
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
|
36 |
+
else:
|
37 |
+
self.text_proj = None
|
38 |
+
|
39 |
+
def word_embedding_forward(self, input_ids, **kwargs):
|
40 |
+
# now is 3d patch
|
41 |
+
images = kwargs["images"] # (b,t,c,h,w)
|
42 |
+
B, T = images.shape[:2]
|
43 |
+
emb = images.view(-1, *images.shape[2:])
|
44 |
+
emb = self.proj(emb) # ((b t),d,h/2,w/2)
|
45 |
+
emb = emb.view(B, T, *emb.shape[1:])
|
46 |
+
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
|
47 |
+
emb = rearrange(emb, "b t n d -> b (t n) d")
|
48 |
+
|
49 |
+
if self.text_proj is not None:
|
50 |
+
text_emb = self.text_proj(kwargs["encoder_outputs"])
|
51 |
+
emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
|
52 |
+
|
53 |
+
emb = emb.contiguous()
|
54 |
+
return emb # (b,n_t+t*n_i,d)
|
55 |
+
|
56 |
+
def reinit(self, parent_model=None):
|
57 |
+
w = self.proj.weight.data
|
58 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
59 |
+
nn.init.constant_(self.proj.bias, 0)
|
60 |
+
del self.transformer.word_embeddings
|
61 |
+
|
62 |
+
|
63 |
+
def get_3d_sincos_pos_embed(
|
64 |
+
embed_dim,
|
65 |
+
grid_height,
|
66 |
+
grid_width,
|
67 |
+
t_size,
|
68 |
+
cls_token=False,
|
69 |
+
height_interpolation=1.0,
|
70 |
+
width_interpolation=1.0,
|
71 |
+
time_interpolation=1.0,
|
72 |
+
):
|
73 |
+
"""
|
74 |
+
grid_size: int of the grid height and width
|
75 |
+
t_size: int of the temporal size
|
76 |
+
return:
|
77 |
+
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
78 |
+
"""
|
79 |
+
assert embed_dim % 4 == 0
|
80 |
+
embed_dim_spatial = embed_dim // 4 * 3
|
81 |
+
embed_dim_temporal = embed_dim // 4
|
82 |
+
|
83 |
+
# spatial
|
84 |
+
grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
|
85 |
+
grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
|
86 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
87 |
+
grid = np.stack(grid, axis=0)
|
88 |
+
|
89 |
+
grid = grid.reshape([2, 1, grid_height, grid_width])
|
90 |
+
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
|
91 |
+
|
92 |
+
# temporal
|
93 |
+
grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
|
94 |
+
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
|
95 |
+
|
96 |
+
# concate: [T, H, W] order
|
97 |
+
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
|
98 |
+
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
|
99 |
+
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
|
100 |
+
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
|
101 |
+
|
102 |
+
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
|
103 |
+
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
|
104 |
+
|
105 |
+
return pos_embed # [T, H*W, D]
|
106 |
+
|
107 |
+
|
108 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
|
109 |
+
"""
|
110 |
+
grid_size: int of the grid height and width
|
111 |
+
return:
|
112 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
113 |
+
"""
|
114 |
+
grid_h = np.arange(grid_height, dtype=np.float32)
|
115 |
+
grid_w = np.arange(grid_width, dtype=np.float32)
|
116 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
117 |
+
grid = np.stack(grid, axis=0)
|
118 |
+
|
119 |
+
grid = grid.reshape([2, 1, grid_height, grid_width])
|
120 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
121 |
+
if cls_token and extra_tokens > 0:
|
122 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
123 |
+
return pos_embed
|
124 |
+
|
125 |
+
|
126 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
127 |
+
assert embed_dim % 2 == 0
|
128 |
+
|
129 |
+
# use half of dimensions to encode grid_h
|
130 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
131 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
132 |
+
|
133 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
134 |
+
return emb
|
135 |
+
|
136 |
+
|
137 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
138 |
+
"""
|
139 |
+
embed_dim: output dimension for each position
|
140 |
+
pos: a list of positions to be encoded: size (M,)
|
141 |
+
out: (M, D)
|
142 |
+
"""
|
143 |
+
assert embed_dim % 2 == 0
|
144 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
145 |
+
omega /= embed_dim / 2.0
|
146 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
147 |
+
|
148 |
+
pos = pos.reshape(-1) # (M,)
|
149 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
150 |
+
|
151 |
+
emb_sin = np.sin(out) # (M, D/2)
|
152 |
+
emb_cos = np.cos(out) # (M, D/2)
|
153 |
+
|
154 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
155 |
+
return emb
|
156 |
+
|
157 |
+
|
158 |
+
class Basic3DPositionEmbeddingMixin(BaseMixin):
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
height,
|
162 |
+
width,
|
163 |
+
compressed_num_frames,
|
164 |
+
hidden_size,
|
165 |
+
text_length=0,
|
166 |
+
height_interpolation=1.0,
|
167 |
+
width_interpolation=1.0,
|
168 |
+
time_interpolation=1.0,
|
169 |
+
):
|
170 |
+
super().__init__()
|
171 |
+
self.height = height
|
172 |
+
self.width = width
|
173 |
+
self.text_length = text_length
|
174 |
+
self.compressed_num_frames = compressed_num_frames
|
175 |
+
self.spatial_length = height * width
|
176 |
+
self.num_patches = height * width * compressed_num_frames
|
177 |
+
self.pos_embedding = nn.Parameter(
|
178 |
+
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
|
179 |
+
)
|
180 |
+
self.height_interpolation = height_interpolation
|
181 |
+
self.width_interpolation = width_interpolation
|
182 |
+
self.time_interpolation = time_interpolation
|
183 |
+
|
184 |
+
def position_embedding_forward(self, position_ids, **kwargs):
|
185 |
+
if kwargs["images"].shape[1] == 1:
|
186 |
+
return self.pos_embedding[:, : self.text_length + self.spatial_length]
|
187 |
+
|
188 |
+
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
189 |
+
|
190 |
+
def reinit(self, parent_model=None):
|
191 |
+
del self.transformer.position_embeddings
|
192 |
+
pos_embed = get_3d_sincos_pos_embed(
|
193 |
+
self.pos_embedding.shape[-1],
|
194 |
+
self.height,
|
195 |
+
self.width,
|
196 |
+
self.compressed_num_frames,
|
197 |
+
height_interpolation=self.height_interpolation,
|
198 |
+
width_interpolation=self.width_interpolation,
|
199 |
+
time_interpolation=self.time_interpolation,
|
200 |
+
)
|
201 |
+
pos_embed = torch.from_numpy(pos_embed).float()
|
202 |
+
pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
|
203 |
+
self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
|
204 |
+
|
205 |
+
|
206 |
+
def broadcat(tensors, dim=-1):
|
207 |
+
num_tensors = len(tensors)
|
208 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
209 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
210 |
+
shape_len = list(shape_lens)[0]
|
211 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
212 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
213 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
214 |
+
assert all(
|
215 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
216 |
+
), "invalid dimensions for broadcastable concatentation"
|
217 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
218 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
219 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
220 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
221 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
222 |
+
return torch.cat(tensors, dim=dim)
|
223 |
+
|
224 |
+
|
225 |
+
def rotate_half(x):
|
226 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
227 |
+
x1, x2 = x.unbind(dim=-1)
|
228 |
+
x = torch.stack((-x2, x1), dim=-1)
|
229 |
+
return rearrange(x, "... d r -> ... (d r)")
|
230 |
+
|
231 |
+
|
232 |
+
class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
height,
|
236 |
+
width,
|
237 |
+
compressed_num_frames,
|
238 |
+
hidden_size,
|
239 |
+
hidden_size_head,
|
240 |
+
text_length,
|
241 |
+
theta=10000,
|
242 |
+
rot_v=False,
|
243 |
+
pnp=False,
|
244 |
+
learnable_pos_embed=False,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
self.rot_v = rot_v
|
248 |
+
|
249 |
+
dim_t = hidden_size_head // 4
|
250 |
+
dim_h = hidden_size_head // 8 * 3
|
251 |
+
dim_w = hidden_size_head // 8 * 3
|
252 |
+
|
253 |
+
# 'lang':
|
254 |
+
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
255 |
+
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
256 |
+
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
257 |
+
|
258 |
+
grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
|
259 |
+
grid_h = torch.arange(height, dtype=torch.float32)
|
260 |
+
grid_w = torch.arange(width, dtype=torch.float32)
|
261 |
+
|
262 |
+
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
263 |
+
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
264 |
+
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
265 |
+
|
266 |
+
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
267 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
268 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
269 |
+
|
270 |
+
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
271 |
+
# (T H W D)
|
272 |
+
|
273 |
+
self.pnp = pnp
|
274 |
+
|
275 |
+
if not self.pnp:
|
276 |
+
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
277 |
+
|
278 |
+
freqs = freqs.contiguous()
|
279 |
+
freqs_sin = freqs.sin()
|
280 |
+
freqs_cos = freqs.cos()
|
281 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
282 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
283 |
+
|
284 |
+
self.text_length = text_length
|
285 |
+
if learnable_pos_embed:
|
286 |
+
num_patches = height * width * compressed_num_frames + text_length
|
287 |
+
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
|
288 |
+
else:
|
289 |
+
self.pos_embedding = None
|
290 |
+
|
291 |
+
def rotary(self, t, **kwargs):
|
292 |
+
if self.pnp:
|
293 |
+
t_coords = kwargs["rope_position_ids"][:, :, 0]
|
294 |
+
x_coords = kwargs["rope_position_ids"][:, :, 1]
|
295 |
+
y_coords = kwargs["rope_position_ids"][:, :, 2]
|
296 |
+
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
|
297 |
+
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
|
298 |
+
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
|
299 |
+
|
300 |
+
else:
|
301 |
+
|
302 |
+
def reshape_freq(freqs):
|
303 |
+
frame = t.shape[2]
|
304 |
+
freqs = freqs[:frame].contiguous()
|
305 |
+
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
306 |
+
return freqs
|
307 |
+
|
308 |
+
freqs_cos = reshape_freq(self.freqs_cos)
|
309 |
+
freqs_sin = reshape_freq(self.freqs_sin)
|
310 |
+
|
311 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
312 |
+
|
313 |
+
def position_embedding_forward(self, position_ids, **kwargs):
|
314 |
+
if self.pos_embedding is not None:
|
315 |
+
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
316 |
+
else:
|
317 |
+
return None
|
318 |
+
|
319 |
+
def attention_fn(
|
320 |
+
self,
|
321 |
+
query_layer,
|
322 |
+
key_layer,
|
323 |
+
value_layer,
|
324 |
+
attention_mask,
|
325 |
+
attention_dropout=None,
|
326 |
+
log_attention_weights=None,
|
327 |
+
scaling_attention_score=True,
|
328 |
+
**kwargs,
|
329 |
+
):
|
330 |
+
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
331 |
+
|
332 |
+
if self.pnp:
|
333 |
+
query_layer = self.rotary(query_layer, **kwargs)
|
334 |
+
key_layer = self.rotary(key_layer, **kwargs)
|
335 |
+
if self.rot_v:
|
336 |
+
value_layer = self.rotary(value_layer)
|
337 |
+
else:
|
338 |
+
query_layer = torch.cat(
|
339 |
+
(
|
340 |
+
query_layer[
|
341 |
+
:,
|
342 |
+
:,
|
343 |
+
: kwargs["text_length"],
|
344 |
+
],
|
345 |
+
self.rotary(
|
346 |
+
query_layer[
|
347 |
+
:,
|
348 |
+
:,
|
349 |
+
kwargs["text_length"] :,
|
350 |
+
]
|
351 |
+
),
|
352 |
+
),
|
353 |
+
dim=2,
|
354 |
+
)
|
355 |
+
key_layer = torch.cat(
|
356 |
+
(
|
357 |
+
key_layer[
|
358 |
+
:,
|
359 |
+
:,
|
360 |
+
: kwargs["text_length"],
|
361 |
+
],
|
362 |
+
self.rotary(
|
363 |
+
key_layer[
|
364 |
+
:,
|
365 |
+
:,
|
366 |
+
kwargs["text_length"] :,
|
367 |
+
]
|
368 |
+
),
|
369 |
+
),
|
370 |
+
dim=2,
|
371 |
+
)
|
372 |
+
if self.rot_v:
|
373 |
+
value_layer = torch.cat(
|
374 |
+
(
|
375 |
+
value_layer[
|
376 |
+
:,
|
377 |
+
:,
|
378 |
+
: kwargs["text_length"],
|
379 |
+
],
|
380 |
+
self.rotary(
|
381 |
+
value_layer[
|
382 |
+
:,
|
383 |
+
:,
|
384 |
+
kwargs["text_length"] :,
|
385 |
+
]
|
386 |
+
),
|
387 |
+
),
|
388 |
+
dim=2,
|
389 |
+
)
|
390 |
+
|
391 |
+
return attention_fn_default(
|
392 |
+
query_layer,
|
393 |
+
key_layer,
|
394 |
+
value_layer,
|
395 |
+
attention_mask,
|
396 |
+
attention_dropout=attention_dropout,
|
397 |
+
log_attention_weights=log_attention_weights,
|
398 |
+
scaling_attention_score=scaling_attention_score,
|
399 |
+
**kwargs,
|
400 |
+
)
|
401 |
+
|
402 |
+
|
403 |
+
def modulate(x, shift, scale):
|
404 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
405 |
+
|
406 |
+
|
407 |
+
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
|
408 |
+
"""
|
409 |
+
x: (N, T/2 * S, patch_size**3 * C)
|
410 |
+
imgs: (N, T, H, W, C)
|
411 |
+
"""
|
412 |
+
if rope_position_ids is not None:
|
413 |
+
assert NotImplementedError
|
414 |
+
# do pix2struct unpatchify
|
415 |
+
L = x.shape[1]
|
416 |
+
x = x.reshape(shape=(x.shape[0], L, p, p, c))
|
417 |
+
x = torch.einsum("nlpqc->ncplq", x)
|
418 |
+
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
|
419 |
+
else:
|
420 |
+
b = x.shape[0]
|
421 |
+
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
|
422 |
+
|
423 |
+
return imgs
|
424 |
+
|
425 |
+
|
426 |
+
class FinalLayerMixin(BaseMixin):
|
427 |
+
def __init__(
|
428 |
+
self,
|
429 |
+
hidden_size,
|
430 |
+
time_embed_dim,
|
431 |
+
patch_size,
|
432 |
+
out_channels,
|
433 |
+
latent_width,
|
434 |
+
latent_height,
|
435 |
+
elementwise_affine,
|
436 |
+
):
|
437 |
+
super().__init__()
|
438 |
+
self.hidden_size = hidden_size
|
439 |
+
self.patch_size = patch_size
|
440 |
+
self.out_channels = out_channels
|
441 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
|
442 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
443 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
|
444 |
+
|
445 |
+
self.spatial_length = latent_width * latent_height // patch_size**2
|
446 |
+
self.latent_width = latent_width
|
447 |
+
self.latent_height = latent_height
|
448 |
+
|
449 |
+
def final_forward(self, logits, **kwargs):
|
450 |
+
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
|
451 |
+
|
452 |
+
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
|
453 |
+
x = modulate(self.norm_final(x), shift, scale)
|
454 |
+
x = self.linear(x)
|
455 |
+
|
456 |
+
return unpatchify(
|
457 |
+
x,
|
458 |
+
c=self.out_channels,
|
459 |
+
p=self.patch_size,
|
460 |
+
w=self.latent_width // self.patch_size,
|
461 |
+
h=self.latent_height // self.patch_size,
|
462 |
+
rope_position_ids=kwargs.get("rope_position_ids", None),
|
463 |
+
**kwargs,
|
464 |
+
)
|
465 |
+
|
466 |
+
def reinit(self, parent_model=None):
|
467 |
+
nn.init.xavier_uniform_(self.linear.weight)
|
468 |
+
nn.init.constant_(self.linear.bias, 0)
|
469 |
+
|
470 |
+
|
471 |
+
class SwiGLUMixin(BaseMixin):
|
472 |
+
def __init__(self, num_layers, in_features, hidden_features, bias=False):
|
473 |
+
super().__init__()
|
474 |
+
self.w2 = nn.ModuleList(
|
475 |
+
[
|
476 |
+
ColumnParallelLinear(
|
477 |
+
in_features,
|
478 |
+
hidden_features,
|
479 |
+
gather_output=False,
|
480 |
+
bias=bias,
|
481 |
+
module=self,
|
482 |
+
name="dense_h_to_4h_gate",
|
483 |
+
)
|
484 |
+
for i in range(num_layers)
|
485 |
+
]
|
486 |
+
)
|
487 |
+
|
488 |
+
def mlp_forward(self, hidden_states, **kw_args):
|
489 |
+
x = hidden_states
|
490 |
+
origin = self.transformer.layers[kw_args["layer_id"]].mlp
|
491 |
+
x1 = origin.dense_h_to_4h(x)
|
492 |
+
x2 = self.w2[kw_args["layer_id"]](x)
|
493 |
+
hidden = origin.activation_func(x2) * x1
|
494 |
+
x = origin.dense_4h_to_h(hidden)
|
495 |
+
return x
|
496 |
+
|
497 |
+
|
498 |
+
class AdaLNMixin(BaseMixin):
|
499 |
+
def __init__(
|
500 |
+
self,
|
501 |
+
width,
|
502 |
+
height,
|
503 |
+
hidden_size,
|
504 |
+
num_layers,
|
505 |
+
time_embed_dim,
|
506 |
+
compressed_num_frames,
|
507 |
+
qk_ln=True,
|
508 |
+
hidden_size_head=None,
|
509 |
+
elementwise_affine=True,
|
510 |
+
):
|
511 |
+
super().__init__()
|
512 |
+
self.num_layers = num_layers
|
513 |
+
self.width = width
|
514 |
+
self.height = height
|
515 |
+
self.compressed_num_frames = compressed_num_frames
|
516 |
+
|
517 |
+
self.adaLN_modulations = nn.ModuleList(
|
518 |
+
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
|
519 |
+
)
|
520 |
+
|
521 |
+
self.qk_ln = qk_ln
|
522 |
+
if qk_ln:
|
523 |
+
self.query_layernorm_list = nn.ModuleList(
|
524 |
+
[
|
525 |
+
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
|
526 |
+
for _ in range(num_layers)
|
527 |
+
]
|
528 |
+
)
|
529 |
+
self.key_layernorm_list = nn.ModuleList(
|
530 |
+
[
|
531 |
+
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
|
532 |
+
for _ in range(num_layers)
|
533 |
+
]
|
534 |
+
)
|
535 |
+
|
536 |
+
def layer_forward(
|
537 |
+
self,
|
538 |
+
hidden_states,
|
539 |
+
mask,
|
540 |
+
*args,
|
541 |
+
**kwargs,
|
542 |
+
):
|
543 |
+
text_length = kwargs["text_length"]
|
544 |
+
# hidden_states (b,(n_t+t*n_i),d)
|
545 |
+
text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
|
546 |
+
img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
|
547 |
+
layer = self.transformer.layers[kwargs["layer_id"]]
|
548 |
+
adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
|
549 |
+
|
550 |
+
(
|
551 |
+
shift_msa,
|
552 |
+
scale_msa,
|
553 |
+
gate_msa,
|
554 |
+
shift_mlp,
|
555 |
+
scale_mlp,
|
556 |
+
gate_mlp,
|
557 |
+
text_shift_msa,
|
558 |
+
text_scale_msa,
|
559 |
+
text_gate_msa,
|
560 |
+
text_shift_mlp,
|
561 |
+
text_scale_mlp,
|
562 |
+
text_gate_mlp,
|
563 |
+
) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
|
564 |
+
gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
|
565 |
+
gate_msa.unsqueeze(1),
|
566 |
+
gate_mlp.unsqueeze(1),
|
567 |
+
text_gate_msa.unsqueeze(1),
|
568 |
+
text_gate_mlp.unsqueeze(1),
|
569 |
+
)
|
570 |
+
|
571 |
+
# self full attention (b,(t n),d)
|
572 |
+
img_attention_input = layer.input_layernorm(img_hidden_states)
|
573 |
+
text_attention_input = layer.input_layernorm(text_hidden_states)
|
574 |
+
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
|
575 |
+
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
|
576 |
+
|
577 |
+
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
|
578 |
+
attention_output = layer.attention(attention_input, mask, **kwargs)
|
579 |
+
text_attention_output = attention_output[:, :text_length] # (b,n,d)
|
580 |
+
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
|
581 |
+
|
582 |
+
if self.transformer.layernorm_order == "sandwich":
|
583 |
+
text_attention_output = layer.third_layernorm(text_attention_output)
|
584 |
+
img_attention_output = layer.third_layernorm(img_attention_output)
|
585 |
+
img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
|
586 |
+
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
|
587 |
+
|
588 |
+
# mlp (b,(t n),d)
|
589 |
+
img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
|
590 |
+
text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
|
591 |
+
img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
|
592 |
+
text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
|
593 |
+
mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d
|
594 |
+
mlp_output = layer.mlp(mlp_input, **kwargs)
|
595 |
+
img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
|
596 |
+
text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
|
597 |
+
if self.transformer.layernorm_order == "sandwich":
|
598 |
+
text_mlp_output = layer.fourth_layernorm(text_mlp_output)
|
599 |
+
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
|
600 |
+
|
601 |
+
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
|
602 |
+
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
|
603 |
+
|
604 |
+
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
|
605 |
+
return hidden_states
|
606 |
+
|
607 |
+
def reinit(self, parent_model=None):
|
608 |
+
for layer in self.adaLN_modulations:
|
609 |
+
nn.init.constant_(layer[-1].weight, 0)
|
610 |
+
nn.init.constant_(layer[-1].bias, 0)
|
611 |
+
|
612 |
+
@non_conflict
|
613 |
+
def attention_fn(
|
614 |
+
self,
|
615 |
+
query_layer,
|
616 |
+
key_layer,
|
617 |
+
value_layer,
|
618 |
+
attention_mask,
|
619 |
+
attention_dropout=None,
|
620 |
+
log_attention_weights=None,
|
621 |
+
scaling_attention_score=True,
|
622 |
+
old_impl=attention_fn_default,
|
623 |
+
**kwargs,
|
624 |
+
):
|
625 |
+
if self.qk_ln:
|
626 |
+
query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
|
627 |
+
key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
|
628 |
+
query_layer = query_layernorm(query_layer)
|
629 |
+
key_layer = key_layernorm(key_layer)
|
630 |
+
|
631 |
+
return old_impl(
|
632 |
+
query_layer,
|
633 |
+
key_layer,
|
634 |
+
value_layer,
|
635 |
+
attention_mask,
|
636 |
+
attention_dropout=attention_dropout,
|
637 |
+
log_attention_weights=log_attention_weights,
|
638 |
+
scaling_attention_score=scaling_attention_score,
|
639 |
+
**kwargs,
|
640 |
+
)
|
641 |
+
|
642 |
+
|
643 |
+
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
644 |
+
|
645 |
+
|
646 |
+
class DiffusionTransformer(BaseModel):
|
647 |
+
def __init__(
|
648 |
+
self,
|
649 |
+
transformer_args,
|
650 |
+
num_frames,
|
651 |
+
time_compressed_rate,
|
652 |
+
latent_width,
|
653 |
+
latent_height,
|
654 |
+
patch_size,
|
655 |
+
in_channels,
|
656 |
+
out_channels,
|
657 |
+
hidden_size,
|
658 |
+
num_layers,
|
659 |
+
num_attention_heads,
|
660 |
+
elementwise_affine,
|
661 |
+
time_embed_dim=None,
|
662 |
+
num_classes=None,
|
663 |
+
modules={},
|
664 |
+
input_time="adaln",
|
665 |
+
adm_in_channels=None,
|
666 |
+
parallel_output=True,
|
667 |
+
height_interpolation=1.0,
|
668 |
+
width_interpolation=1.0,
|
669 |
+
time_interpolation=1.0,
|
670 |
+
use_SwiGLU=False,
|
671 |
+
use_RMSNorm=False,
|
672 |
+
zero_init_y_embed=False,
|
673 |
+
**kwargs,
|
674 |
+
):
|
675 |
+
self.latent_width = latent_width
|
676 |
+
self.latent_height = latent_height
|
677 |
+
self.patch_size = patch_size
|
678 |
+
self.num_frames = num_frames
|
679 |
+
self.time_compressed_rate = time_compressed_rate
|
680 |
+
self.spatial_length = latent_width * latent_height // patch_size**2
|
681 |
+
self.in_channels = in_channels
|
682 |
+
self.out_channels = out_channels
|
683 |
+
self.hidden_size = hidden_size
|
684 |
+
self.model_channels = hidden_size
|
685 |
+
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
|
686 |
+
self.num_classes = num_classes
|
687 |
+
self.adm_in_channels = adm_in_channels
|
688 |
+
self.input_time = input_time
|
689 |
+
self.num_layers = num_layers
|
690 |
+
self.num_attention_heads = num_attention_heads
|
691 |
+
self.is_decoder = transformer_args.is_decoder
|
692 |
+
self.elementwise_affine = elementwise_affine
|
693 |
+
self.height_interpolation = height_interpolation
|
694 |
+
self.width_interpolation = width_interpolation
|
695 |
+
self.time_interpolation = time_interpolation
|
696 |
+
self.inner_hidden_size = hidden_size * 4
|
697 |
+
self.zero_init_y_embed = zero_init_y_embed
|
698 |
+
try:
|
699 |
+
self.dtype = str_to_dtype[kwargs.pop("dtype")]
|
700 |
+
except:
|
701 |
+
self.dtype = torch.float32
|
702 |
+
|
703 |
+
if use_SwiGLU:
|
704 |
+
kwargs["activation_func"] = F.silu
|
705 |
+
elif "activation_func" not in kwargs:
|
706 |
+
approx_gelu = nn.GELU(approximate="tanh")
|
707 |
+
kwargs["activation_func"] = approx_gelu
|
708 |
+
|
709 |
+
if use_RMSNorm:
|
710 |
+
kwargs["layernorm"] = RMSNorm
|
711 |
+
else:
|
712 |
+
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
|
713 |
+
|
714 |
+
transformer_args.num_layers = num_layers
|
715 |
+
transformer_args.hidden_size = hidden_size
|
716 |
+
transformer_args.num_attention_heads = num_attention_heads
|
717 |
+
transformer_args.parallel_output = parallel_output
|
718 |
+
super().__init__(args=transformer_args, transformer=None, **kwargs)
|
719 |
+
|
720 |
+
module_configs = modules
|
721 |
+
self._build_modules(module_configs)
|
722 |
+
|
723 |
+
if use_SwiGLU:
|
724 |
+
self.add_mixin(
|
725 |
+
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
|
726 |
+
)
|
727 |
+
|
728 |
+
def _build_modules(self, module_configs):
|
729 |
+
model_channels = self.hidden_size
|
730 |
+
# time_embed_dim = model_channels * 4
|
731 |
+
time_embed_dim = self.time_embed_dim
|
732 |
+
self.time_embed = nn.Sequential(
|
733 |
+
linear(model_channels, time_embed_dim),
|
734 |
+
nn.SiLU(),
|
735 |
+
linear(time_embed_dim, time_embed_dim),
|
736 |
+
)
|
737 |
+
|
738 |
+
if self.num_classes is not None:
|
739 |
+
if isinstance(self.num_classes, int):
|
740 |
+
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
|
741 |
+
elif self.num_classes == "continuous":
|
742 |
+
print("setting up linear c_adm embedding layer")
|
743 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
744 |
+
elif self.num_classes == "timestep":
|
745 |
+
self.label_emb = nn.Sequential(
|
746 |
+
Timestep(model_channels),
|
747 |
+
nn.Sequential(
|
748 |
+
linear(model_channels, time_embed_dim),
|
749 |
+
nn.SiLU(),
|
750 |
+
linear(time_embed_dim, time_embed_dim),
|
751 |
+
),
|
752 |
+
)
|
753 |
+
elif self.num_classes == "sequential":
|
754 |
+
assert self.adm_in_channels is not None
|
755 |
+
self.label_emb = nn.Sequential(
|
756 |
+
nn.Sequential(
|
757 |
+
linear(self.adm_in_channels, time_embed_dim),
|
758 |
+
nn.SiLU(),
|
759 |
+
linear(time_embed_dim, time_embed_dim),
|
760 |
+
)
|
761 |
+
)
|
762 |
+
if self.zero_init_y_embed:
|
763 |
+
nn.init.constant_(self.label_emb[0][2].weight, 0)
|
764 |
+
nn.init.constant_(self.label_emb[0][2].bias, 0)
|
765 |
+
else:
|
766 |
+
raise ValueError()
|
767 |
+
|
768 |
+
pos_embed_config = module_configs["pos_embed_config"]
|
769 |
+
self.add_mixin(
|
770 |
+
"pos_embed",
|
771 |
+
instantiate_from_config(
|
772 |
+
pos_embed_config,
|
773 |
+
height=self.latent_height // self.patch_size,
|
774 |
+
width=self.latent_width // self.patch_size,
|
775 |
+
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
776 |
+
hidden_size=self.hidden_size,
|
777 |
+
),
|
778 |
+
reinit=True,
|
779 |
+
)
|
780 |
+
|
781 |
+
patch_embed_config = module_configs["patch_embed_config"]
|
782 |
+
self.add_mixin(
|
783 |
+
"patch_embed",
|
784 |
+
instantiate_from_config(
|
785 |
+
patch_embed_config,
|
786 |
+
patch_size=self.patch_size,
|
787 |
+
hidden_size=self.hidden_size,
|
788 |
+
in_channels=self.in_channels,
|
789 |
+
),
|
790 |
+
reinit=True,
|
791 |
+
)
|
792 |
+
if self.input_time == "adaln":
|
793 |
+
adaln_layer_config = module_configs["adaln_layer_config"]
|
794 |
+
self.add_mixin(
|
795 |
+
"adaln_layer",
|
796 |
+
instantiate_from_config(
|
797 |
+
adaln_layer_config,
|
798 |
+
height=self.latent_height // self.patch_size,
|
799 |
+
width=self.latent_width // self.patch_size,
|
800 |
+
hidden_size=self.hidden_size,
|
801 |
+
num_layers=self.num_layers,
|
802 |
+
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
803 |
+
hidden_size_head=self.hidden_size // self.num_attention_heads,
|
804 |
+
time_embed_dim=self.time_embed_dim,
|
805 |
+
elementwise_affine=self.elementwise_affine,
|
806 |
+
),
|
807 |
+
)
|
808 |
+
else:
|
809 |
+
raise NotImplementedError
|
810 |
+
|
811 |
+
final_layer_config = module_configs["final_layer_config"]
|
812 |
+
self.add_mixin(
|
813 |
+
"final_layer",
|
814 |
+
instantiate_from_config(
|
815 |
+
final_layer_config,
|
816 |
+
hidden_size=self.hidden_size,
|
817 |
+
patch_size=self.patch_size,
|
818 |
+
out_channels=self.out_channels,
|
819 |
+
time_embed_dim=self.time_embed_dim,
|
820 |
+
latent_width=self.latent_width,
|
821 |
+
latent_height=self.latent_height,
|
822 |
+
elementwise_affine=self.elementwise_affine,
|
823 |
+
),
|
824 |
+
reinit=True,
|
825 |
+
)
|
826 |
+
|
827 |
+
if "lora_config" in module_configs:
|
828 |
+
lora_config = module_configs["lora_config"]
|
829 |
+
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
|
830 |
+
|
831 |
+
return
|
832 |
+
|
833 |
+
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
834 |
+
b, t, d, h, w = x.shape
|
835 |
+
if x.dtype != self.dtype:
|
836 |
+
x = x.to(self.dtype)
|
837 |
+
assert (y is not None) == (
|
838 |
+
self.num_classes is not None
|
839 |
+
), "must specify y if and only if the model is class-conditional"
|
840 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
|
841 |
+
emb = self.time_embed(t_emb)
|
842 |
+
|
843 |
+
if self.num_classes is not None:
|
844 |
+
# assert y.shape[0] == x.shape[0]
|
845 |
+
assert x.shape[0] % y.shape[0] == 0
|
846 |
+
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
|
847 |
+
emb = emb + self.label_emb(y)
|
848 |
+
|
849 |
+
kwargs["seq_length"] = t * h * w // (self.patch_size**2)
|
850 |
+
kwargs["images"] = x
|
851 |
+
kwargs["emb"] = emb
|
852 |
+
kwargs["encoder_outputs"] = context
|
853 |
+
kwargs["text_length"] = context.shape[1]
|
854 |
+
|
855 |
+
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
|
856 |
+
output = super().forward(**kwargs)[0]
|
857 |
+
|
858 |
+
return output
|
sat/finetune.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
4 |
+
|
5 |
+
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
6 |
+
|
7 |
+
run_cmd="$environs python train_video.py --base configs/cogvideox_2b_sft.yaml --seed $RANDOM"
|
8 |
+
|
9 |
+
echo ${run_cmd}
|
10 |
+
eval ${run_cmd}
|
11 |
+
|
12 |
+
echo "DONE on `hostname`"
|
sat/inference.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
4 |
+
|
5 |
+
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
6 |
+
|
7 |
+
run_cmd="$environs python sample_video.py --base configs/cogvideox_2b_infer.yaml"
|
8 |
+
|
9 |
+
echo ${run_cmd}
|
10 |
+
eval ${run_cmd}
|
11 |
+
|
12 |
+
echo "DONE on `hostname`"
|
sat/requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/spacegoing/SwissArmyTransformer.git
|
2 |
+
diffusers>=0.29.2
|
3 |
+
omegaconf>=2.3.0
|
4 |
+
torch>=2.3.1
|
5 |
+
torchvision>=0.19.0
|
6 |
+
pytorch_lightning>=2.3.3
|
7 |
+
kornia>=0.7.3
|
8 |
+
beartype>=0.18.5
|
9 |
+
numpy>=2.0.1
|
10 |
+
fsspec>=2024.5.0
|
11 |
+
safetensors>=0.4.3
|
12 |
+
imageio-ffmpeg>=0.5.1
|
13 |
+
imageio>=2.34.2
|
14 |
+
scipy>=1.14.0
|
15 |
+
decord>=0.6.0
|
16 |
+
wandb>=0.17.5
|
17 |
+
deepspeed>=0.14.4
|
sat/sample_video.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
from typing import List, Union
|
5 |
+
from tqdm import tqdm
|
6 |
+
from omegaconf import ListConfig
|
7 |
+
import imageio
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
from einops import rearrange
|
12 |
+
import torchvision.transforms as TT
|
13 |
+
|
14 |
+
from sat.model.base_model import get_model
|
15 |
+
from sat.training.model_io import load_checkpoint
|
16 |
+
from sat import mpu
|
17 |
+
|
18 |
+
from diffusion_video import SATVideoDiffusionEngine
|
19 |
+
from arguments import get_args
|
20 |
+
from torchvision.transforms.functional import center_crop, resize
|
21 |
+
from torchvision.transforms import InterpolationMode
|
22 |
+
|
23 |
+
|
24 |
+
def read_from_cli():
|
25 |
+
cnt = 0
|
26 |
+
try:
|
27 |
+
while True:
|
28 |
+
x = input("Please input English text (Ctrl-D quit): ")
|
29 |
+
yield x.strip(), cnt
|
30 |
+
cnt += 1
|
31 |
+
except EOFError as e:
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
def read_from_file(p, rank=0, world_size=1):
|
36 |
+
with open(p, "r") as fin:
|
37 |
+
cnt = -1
|
38 |
+
for l in fin:
|
39 |
+
cnt += 1
|
40 |
+
if cnt % world_size != rank:
|
41 |
+
continue
|
42 |
+
yield l.strip(), cnt
|
43 |
+
|
44 |
+
|
45 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
46 |
+
return list(set([x.input_key for x in conditioner.embedders]))
|
47 |
+
|
48 |
+
|
49 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
|
50 |
+
batch = {}
|
51 |
+
batch_uc = {}
|
52 |
+
|
53 |
+
for key in keys:
|
54 |
+
if key == "txt":
|
55 |
+
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
56 |
+
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
57 |
+
else:
|
58 |
+
batch[key] = value_dict[key]
|
59 |
+
|
60 |
+
if T is not None:
|
61 |
+
batch["num_video_frames"] = T
|
62 |
+
|
63 |
+
for key in batch.keys():
|
64 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
65 |
+
batch_uc[key] = torch.clone(batch[key])
|
66 |
+
return batch, batch_uc
|
67 |
+
|
68 |
+
|
69 |
+
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
|
70 |
+
os.makedirs(save_path, exist_ok=True)
|
71 |
+
|
72 |
+
for i, vid in enumerate(video_batch):
|
73 |
+
gif_frames = []
|
74 |
+
for frame in vid:
|
75 |
+
frame = rearrange(frame, "c h w -> h w c")
|
76 |
+
frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
|
77 |
+
gif_frames.append(frame)
|
78 |
+
now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
|
79 |
+
with imageio.get_writer(now_save_path, fps=fps) as writer:
|
80 |
+
for frame in gif_frames:
|
81 |
+
writer.append_data(frame)
|
82 |
+
|
83 |
+
|
84 |
+
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
85 |
+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
86 |
+
arr = resize(
|
87 |
+
arr,
|
88 |
+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
89 |
+
interpolation=InterpolationMode.BICUBIC,
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
arr = resize(
|
93 |
+
arr,
|
94 |
+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
95 |
+
interpolation=InterpolationMode.BICUBIC,
|
96 |
+
)
|
97 |
+
|
98 |
+
h, w = arr.shape[2], arr.shape[3]
|
99 |
+
arr = arr.squeeze(0)
|
100 |
+
|
101 |
+
delta_h = h - image_size[0]
|
102 |
+
delta_w = w - image_size[1]
|
103 |
+
|
104 |
+
if reshape_mode == "random" or reshape_mode == "none":
|
105 |
+
top = np.random.randint(0, delta_h + 1)
|
106 |
+
left = np.random.randint(0, delta_w + 1)
|
107 |
+
elif reshape_mode == "center":
|
108 |
+
top, left = delta_h // 2, delta_w // 2
|
109 |
+
else:
|
110 |
+
raise NotImplementedError
|
111 |
+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
112 |
+
return arr
|
113 |
+
|
114 |
+
|
115 |
+
def sampling_main(args, model_cls):
|
116 |
+
if isinstance(model_cls, type):
|
117 |
+
model = get_model(args, model_cls)
|
118 |
+
else:
|
119 |
+
model = model_cls
|
120 |
+
|
121 |
+
load_checkpoint(model, args)
|
122 |
+
model.eval()
|
123 |
+
|
124 |
+
if args.input_type == "cli":
|
125 |
+
data_iter = read_from_cli()
|
126 |
+
elif args.input_type == "txt":
|
127 |
+
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
|
128 |
+
print("rank and world_size", rank, world_size)
|
129 |
+
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
|
130 |
+
else:
|
131 |
+
raise NotImplementedError
|
132 |
+
|
133 |
+
image_size = [480, 720]
|
134 |
+
|
135 |
+
sample_func = model.sample
|
136 |
+
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
|
137 |
+
num_samples = [1]
|
138 |
+
force_uc_zero_embeddings = ["txt"]
|
139 |
+
device = model.device
|
140 |
+
with torch.no_grad():
|
141 |
+
for text, cnt in tqdm(data_iter):
|
142 |
+
# reload model on GPU
|
143 |
+
model.to(device)
|
144 |
+
print("rank:", rank, "start to process", text, cnt)
|
145 |
+
# TODO: broadcast image2video
|
146 |
+
value_dict = {
|
147 |
+
"prompt": text,
|
148 |
+
"negative_prompt": "",
|
149 |
+
"num_frames": torch.tensor(T).unsqueeze(0),
|
150 |
+
}
|
151 |
+
|
152 |
+
batch, batch_uc = get_batch(
|
153 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
154 |
+
)
|
155 |
+
for key in batch:
|
156 |
+
if isinstance(batch[key], torch.Tensor):
|
157 |
+
print(key, batch[key].shape)
|
158 |
+
elif isinstance(batch[key], list):
|
159 |
+
print(key, [len(l) for l in batch[key]])
|
160 |
+
else:
|
161 |
+
print(key, batch[key])
|
162 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
163 |
+
batch,
|
164 |
+
batch_uc=batch_uc,
|
165 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
166 |
+
)
|
167 |
+
|
168 |
+
for k in c:
|
169 |
+
if not k == "crossattn":
|
170 |
+
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
|
171 |
+
for index in range(args.batch_size):
|
172 |
+
# reload model on GPU
|
173 |
+
model.to(device)
|
174 |
+
samples_z = sample_func(
|
175 |
+
c,
|
176 |
+
uc=uc,
|
177 |
+
batch_size=1,
|
178 |
+
shape=(T, C, H // F, W // F),
|
179 |
+
)
|
180 |
+
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
|
181 |
+
|
182 |
+
# Unload the model from GPU to save GPU memory
|
183 |
+
model.to("cpu")
|
184 |
+
torch.cuda.empty_cache()
|
185 |
+
first_stage_model = model.first_stage_model
|
186 |
+
first_stage_model = first_stage_model.to(device)
|
187 |
+
|
188 |
+
latent = 1.0 / model.scale_factor * samples_z
|
189 |
+
|
190 |
+
# Decode latent serial to save GPU memory
|
191 |
+
recons = []
|
192 |
+
loop_num = (T - 1) // 2
|
193 |
+
for i in range(loop_num):
|
194 |
+
if i == 0:
|
195 |
+
start_frame, end_frame = 0, 3
|
196 |
+
else:
|
197 |
+
start_frame, end_frame = i * 2 + 1, i * 2 + 3
|
198 |
+
if i == loop_num - 1:
|
199 |
+
clear_fake_cp_cache = True
|
200 |
+
else:
|
201 |
+
clear_fake_cp_cache = False
|
202 |
+
with torch.no_grad():
|
203 |
+
recon = first_stage_model.decode(
|
204 |
+
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
|
205 |
+
)
|
206 |
+
|
207 |
+
recons.append(recon)
|
208 |
+
|
209 |
+
recon = torch.cat(recons, dim=2).to(torch.float32)
|
210 |
+
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
|
211 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
212 |
+
|
213 |
+
save_path = os.path.join(
|
214 |
+
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
215 |
+
)
|
216 |
+
if mpu.get_model_parallel_rank() == 0:
|
217 |
+
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
218 |
+
|
219 |
+
|
220 |
+
if __name__ == "__main__":
|
221 |
+
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
|
222 |
+
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
|
223 |
+
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
|
224 |
+
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
|
225 |
+
py_parser = argparse.ArgumentParser(add_help=False)
|
226 |
+
known, args_list = py_parser.parse_known_args()
|
227 |
+
|
228 |
+
args = get_args(args_list)
|
229 |
+
args = argparse.Namespace(**vars(args), **vars(known))
|
230 |
+
del args.deepspeed_config
|
231 |
+
args.model_config.first_stage_config.params.cp_size = 1
|
232 |
+
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
|
233 |
+
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
|
234 |
+
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
|
235 |
+
|
236 |
+
sampling_main(args, model_cls=SATVideoDiffusionEngine)
|
sat/sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import AutoencodingEngine
|
2 |
+
from .util import get_configs_path, instantiate_from_config
|
3 |
+
|
4 |
+
__version__ = "0.1.0"
|
sat/sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
warm_up_steps,
|
12 |
+
lr_min,
|
13 |
+
lr_max,
|
14 |
+
lr_start,
|
15 |
+
max_decay_steps,
|
16 |
+
verbosity_interval=0,
|
17 |
+
):
|
18 |
+
self.lr_warm_up_steps = warm_up_steps
|
19 |
+
self.lr_start = lr_start
|
20 |
+
self.lr_min = lr_min
|
21 |
+
self.lr_max = lr_max
|
22 |
+
self.lr_max_decay_steps = max_decay_steps
|
23 |
+
self.last_lr = 0.0
|
24 |
+
self.verbosity_interval = verbosity_interval
|
25 |
+
|
26 |
+
def schedule(self, n, **kwargs):
|
27 |
+
if self.verbosity_interval > 0:
|
28 |
+
if n % self.verbosity_interval == 0:
|
29 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
+
if n < self.lr_warm_up_steps:
|
31 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
32 |
+
self.last_lr = lr
|
33 |
+
return lr
|
34 |
+
else:
|
35 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
36 |
+
t = min(t, 1.0)
|
37 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
|
38 |
+
self.last_lr = lr
|
39 |
+
return lr
|
40 |
+
|
41 |
+
def __call__(self, n, **kwargs):
|
42 |
+
return self.schedule(n, **kwargs)
|
43 |
+
|
44 |
+
|
45 |
+
class LambdaWarmUpCosineScheduler2:
|
46 |
+
"""
|
47 |
+
supports repeated iterations, configurable via lists
|
48 |
+
note: use with a base_lr of 1.0.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
52 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
53 |
+
self.lr_warm_up_steps = warm_up_steps
|
54 |
+
self.f_start = f_start
|
55 |
+
self.f_min = f_min
|
56 |
+
self.f_max = f_max
|
57 |
+
self.cycle_lengths = cycle_lengths
|
58 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
59 |
+
self.last_f = 0.0
|
60 |
+
self.verbosity_interval = verbosity_interval
|
61 |
+
|
62 |
+
def find_in_interval(self, n):
|
63 |
+
interval = 0
|
64 |
+
for cl in self.cum_cycles[1:]:
|
65 |
+
if n <= cl:
|
66 |
+
return interval
|
67 |
+
interval += 1
|
68 |
+
|
69 |
+
def schedule(self, n, **kwargs):
|
70 |
+
cycle = self.find_in_interval(n)
|
71 |
+
n = n - self.cum_cycles[cycle]
|
72 |
+
if self.verbosity_interval > 0:
|
73 |
+
if n % self.verbosity_interval == 0:
|
74 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
75 |
+
if n < self.lr_warm_up_steps[cycle]:
|
76 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
77 |
+
self.last_f = f
|
78 |
+
return f
|
79 |
+
else:
|
80 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
81 |
+
t = min(t, 1.0)
|
82 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
|
83 |
+
self.last_f = f
|
84 |
+
return f
|
85 |
+
|
86 |
+
def __call__(self, n, **kwargs):
|
87 |
+
return self.schedule(n, **kwargs)
|
88 |
+
|
89 |
+
|
90 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
91 |
+
def schedule(self, n, **kwargs):
|
92 |
+
cycle = self.find_in_interval(n)
|
93 |
+
n = n - self.cum_cycles[cycle]
|
94 |
+
if self.verbosity_interval > 0:
|
95 |
+
if n % self.verbosity_interval == 0:
|
96 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
97 |
+
|
98 |
+
if n < self.lr_warm_up_steps[cycle]:
|
99 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
100 |
+
self.last_f = f
|
101 |
+
return f
|
102 |
+
else:
|
103 |
+
f = (
|
104 |
+
self.f_min[cycle]
|
105 |
+
+ (self.f_max[cycle] - self.f_min[cycle])
|
106 |
+
* (self.cycle_lengths[cycle] - n)
|
107 |
+
/ (self.cycle_lengths[cycle])
|
108 |
+
)
|
109 |
+
self.last_f = f
|
110 |
+
return f
|
sat/sgm/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .autoencoder import AutoencodingEngine
|
sat/sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
import random
|
5 |
+
from abc import abstractmethod
|
6 |
+
from contextlib import contextmanager
|
7 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
import torch
|
12 |
+
import torch.distributed
|
13 |
+
import torch.nn as nn
|
14 |
+
from einops import rearrange
|
15 |
+
from packaging import version
|
16 |
+
|
17 |
+
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
18 |
+
from ..modules.ema import LitEma
|
19 |
+
from ..util import (
|
20 |
+
default,
|
21 |
+
get_nested_attribute,
|
22 |
+
get_obj_from_str,
|
23 |
+
instantiate_from_config,
|
24 |
+
initialize_context_parallel,
|
25 |
+
get_context_parallel_group,
|
26 |
+
get_context_parallel_group_rank,
|
27 |
+
is_context_parallel_initialized,
|
28 |
+
)
|
29 |
+
from ..modules.cp_enc_dec import _conv_split, _conv_gather
|
30 |
+
|
31 |
+
logpy = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
class AbstractAutoencoder(pl.LightningModule):
|
35 |
+
"""
|
36 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
37 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
38 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
ema_decay: Union[None, float] = None,
|
44 |
+
monitor: Union[None, str] = None,
|
45 |
+
input_key: str = "jpg",
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
self.input_key = input_key
|
50 |
+
self.use_ema = ema_decay is not None
|
51 |
+
if monitor is not None:
|
52 |
+
self.monitor = monitor
|
53 |
+
|
54 |
+
if self.use_ema:
|
55 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
56 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
57 |
+
|
58 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
59 |
+
self.automatic_optimization = False
|
60 |
+
|
61 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
62 |
+
if ckpt is None:
|
63 |
+
return
|
64 |
+
if isinstance(ckpt, str):
|
65 |
+
ckpt = {
|
66 |
+
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
67 |
+
"params": {"ckpt_path": ckpt},
|
68 |
+
}
|
69 |
+
engine = instantiate_from_config(ckpt)
|
70 |
+
engine(self)
|
71 |
+
|
72 |
+
@abstractmethod
|
73 |
+
def get_input(self, batch) -> Any:
|
74 |
+
raise NotImplementedError()
|
75 |
+
|
76 |
+
def on_train_batch_end(self, *args, **kwargs):
|
77 |
+
# for EMA computation
|
78 |
+
if self.use_ema:
|
79 |
+
self.model_ema(self)
|
80 |
+
|
81 |
+
@contextmanager
|
82 |
+
def ema_scope(self, context=None):
|
83 |
+
if self.use_ema:
|
84 |
+
self.model_ema.store(self.parameters())
|
85 |
+
self.model_ema.copy_to(self)
|
86 |
+
if context is not None:
|
87 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
88 |
+
try:
|
89 |
+
yield None
|
90 |
+
finally:
|
91 |
+
if self.use_ema:
|
92 |
+
self.model_ema.restore(self.parameters())
|
93 |
+
if context is not None:
|
94 |
+
logpy.info(f"{context}: Restored training weights")
|
95 |
+
|
96 |
+
@abstractmethod
|
97 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
98 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
99 |
+
|
100 |
+
@abstractmethod
|
101 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
102 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
103 |
+
|
104 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
105 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
106 |
+
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
107 |
+
|
108 |
+
def configure_optimizers(self) -> Any:
|
109 |
+
raise NotImplementedError()
|
110 |
+
|
111 |
+
|
112 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
113 |
+
"""
|
114 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
115 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
116 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
*args,
|
122 |
+
encoder_config: Dict,
|
123 |
+
decoder_config: Dict,
|
124 |
+
loss_config: Dict,
|
125 |
+
regularizer_config: Dict,
|
126 |
+
optimizer_config: Union[Dict, None] = None,
|
127 |
+
lr_g_factor: float = 1.0,
|
128 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
129 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
130 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
131 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
132 |
+
disc_start_iter: int = 0,
|
133 |
+
diff_boost_factor: float = 3.0,
|
134 |
+
ckpt_engine: Union[None, str, dict] = None,
|
135 |
+
ckpt_path: Optional[str] = None,
|
136 |
+
additional_decode_keys: Optional[List[str]] = None,
|
137 |
+
**kwargs,
|
138 |
+
):
|
139 |
+
super().__init__(*args, **kwargs)
|
140 |
+
self.automatic_optimization = False # pytorch lightning
|
141 |
+
|
142 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
143 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
144 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
145 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
|
146 |
+
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
|
147 |
+
self.diff_boost_factor = diff_boost_factor
|
148 |
+
self.disc_start_iter = disc_start_iter
|
149 |
+
self.lr_g_factor = lr_g_factor
|
150 |
+
self.trainable_ae_params = trainable_ae_params
|
151 |
+
if self.trainable_ae_params is not None:
|
152 |
+
self.ae_optimizer_args = default(
|
153 |
+
ae_optimizer_args,
|
154 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
155 |
+
)
|
156 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
157 |
+
else:
|
158 |
+
self.ae_optimizer_args = [{}] # makes type consitent
|
159 |
+
|
160 |
+
self.trainable_disc_params = trainable_disc_params
|
161 |
+
if self.trainable_disc_params is not None:
|
162 |
+
self.disc_optimizer_args = default(
|
163 |
+
disc_optimizer_args,
|
164 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
165 |
+
)
|
166 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
167 |
+
else:
|
168 |
+
self.disc_optimizer_args = [{}] # makes type consitent
|
169 |
+
|
170 |
+
if ckpt_path is not None:
|
171 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
172 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
173 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
174 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
175 |
+
|
176 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
177 |
+
# assuming unified data format, dataloader returns a dict.
|
178 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
179 |
+
# format (e.g., bchw instead if bhwc)
|
180 |
+
return batch[self.input_key]
|
181 |
+
|
182 |
+
def get_autoencoder_params(self) -> list:
|
183 |
+
params = []
|
184 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
185 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
186 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
187 |
+
params += list(self.regularization.get_trainable_parameters())
|
188 |
+
params = params + list(self.encoder.parameters())
|
189 |
+
params = params + list(self.decoder.parameters())
|
190 |
+
return params
|
191 |
+
|
192 |
+
def get_discriminator_params(self) -> list:
|
193 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
194 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
195 |
+
else:
|
196 |
+
params = []
|
197 |
+
return params
|
198 |
+
|
199 |
+
def get_last_layer(self):
|
200 |
+
return self.decoder.get_last_layer()
|
201 |
+
|
202 |
+
def encode(
|
203 |
+
self,
|
204 |
+
x: torch.Tensor,
|
205 |
+
return_reg_log: bool = False,
|
206 |
+
unregularized: bool = False,
|
207 |
+
**kwargs,
|
208 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
209 |
+
z = self.encoder(x, **kwargs)
|
210 |
+
if unregularized:
|
211 |
+
return z, dict()
|
212 |
+
z, reg_log = self.regularization(z)
|
213 |
+
if return_reg_log:
|
214 |
+
return z, reg_log
|
215 |
+
return z
|
216 |
+
|
217 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
218 |
+
x = self.decoder(z, **kwargs)
|
219 |
+
return x
|
220 |
+
|
221 |
+
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
222 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
223 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
224 |
+
return z, dec, reg_log
|
225 |
+
|
226 |
+
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
227 |
+
x = self.get_input(batch)
|
228 |
+
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
229 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
230 |
+
if hasattr(self.loss, "forward_keys"):
|
231 |
+
extra_info = {
|
232 |
+
"z": z,
|
233 |
+
"optimizer_idx": optimizer_idx,
|
234 |
+
"global_step": self.global_step,
|
235 |
+
"last_layer": self.get_last_layer(),
|
236 |
+
"split": "train",
|
237 |
+
"regularization_log": regularization_log,
|
238 |
+
"autoencoder": self,
|
239 |
+
}
|
240 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
241 |
+
else:
|
242 |
+
extra_info = dict()
|
243 |
+
|
244 |
+
if optimizer_idx == 0:
|
245 |
+
# autoencode
|
246 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
247 |
+
if isinstance(out_loss, tuple):
|
248 |
+
aeloss, log_dict_ae = out_loss
|
249 |
+
else:
|
250 |
+
# simple loss function
|
251 |
+
aeloss = out_loss
|
252 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
253 |
+
|
254 |
+
self.log_dict(
|
255 |
+
log_dict_ae,
|
256 |
+
prog_bar=False,
|
257 |
+
logger=True,
|
258 |
+
on_step=True,
|
259 |
+
on_epoch=True,
|
260 |
+
sync_dist=False,
|
261 |
+
)
|
262 |
+
self.log(
|
263 |
+
"loss",
|
264 |
+
aeloss.mean().detach(),
|
265 |
+
prog_bar=True,
|
266 |
+
logger=False,
|
267 |
+
on_epoch=False,
|
268 |
+
on_step=True,
|
269 |
+
)
|
270 |
+
return aeloss
|
271 |
+
elif optimizer_idx == 1:
|
272 |
+
# discriminator
|
273 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
274 |
+
# -> discriminator always needs to return a tuple
|
275 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
276 |
+
return discloss
|
277 |
+
else:
|
278 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
279 |
+
|
280 |
+
def training_step(self, batch: dict, batch_idx: int):
|
281 |
+
opts = self.optimizers()
|
282 |
+
if not isinstance(opts, list):
|
283 |
+
# Non-adversarial case
|
284 |
+
opts = [opts]
|
285 |
+
optimizer_idx = batch_idx % len(opts)
|
286 |
+
if self.global_step < self.disc_start_iter:
|
287 |
+
optimizer_idx = 0
|
288 |
+
opt = opts[optimizer_idx]
|
289 |
+
opt.zero_grad()
|
290 |
+
with opt.toggle_model():
|
291 |
+
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
|
292 |
+
self.manual_backward(loss)
|
293 |
+
opt.step()
|
294 |
+
|
295 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
296 |
+
log_dict = self._validation_step(batch, batch_idx)
|
297 |
+
with self.ema_scope():
|
298 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
299 |
+
log_dict.update(log_dict_ema)
|
300 |
+
return log_dict
|
301 |
+
|
302 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
303 |
+
x = self.get_input(batch)
|
304 |
+
|
305 |
+
z, xrec, regularization_log = self(x)
|
306 |
+
if hasattr(self.loss, "forward_keys"):
|
307 |
+
extra_info = {
|
308 |
+
"z": z,
|
309 |
+
"optimizer_idx": 0,
|
310 |
+
"global_step": self.global_step,
|
311 |
+
"last_layer": self.get_last_layer(),
|
312 |
+
"split": "val" + postfix,
|
313 |
+
"regularization_log": regularization_log,
|
314 |
+
"autoencoder": self,
|
315 |
+
}
|
316 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
317 |
+
else:
|
318 |
+
extra_info = dict()
|
319 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
320 |
+
if isinstance(out_loss, tuple):
|
321 |
+
aeloss, log_dict_ae = out_loss
|
322 |
+
else:
|
323 |
+
# simple loss function
|
324 |
+
aeloss = out_loss
|
325 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
326 |
+
full_log_dict = log_dict_ae
|
327 |
+
|
328 |
+
if "optimizer_idx" in extra_info:
|
329 |
+
extra_info["optimizer_idx"] = 1
|
330 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
331 |
+
full_log_dict.update(log_dict_disc)
|
332 |
+
self.log(
|
333 |
+
f"val{postfix}/loss/rec",
|
334 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
335 |
+
sync_dist=True,
|
336 |
+
)
|
337 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
338 |
+
return full_log_dict
|
339 |
+
|
340 |
+
def get_param_groups(
|
341 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
342 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
343 |
+
groups = []
|
344 |
+
num_params = 0
|
345 |
+
for names, args in zip(parameter_names, optimizer_args):
|
346 |
+
params = []
|
347 |
+
for pattern_ in names:
|
348 |
+
pattern_params = []
|
349 |
+
pattern = re.compile(pattern_)
|
350 |
+
for p_name, param in self.named_parameters():
|
351 |
+
if re.match(pattern, p_name):
|
352 |
+
pattern_params.append(param)
|
353 |
+
num_params += param.numel()
|
354 |
+
if len(pattern_params) == 0:
|
355 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
356 |
+
params.extend(pattern_params)
|
357 |
+
groups.append({"params": params, **args})
|
358 |
+
return groups, num_params
|
359 |
+
|
360 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
361 |
+
if self.trainable_ae_params is None:
|
362 |
+
ae_params = self.get_autoencoder_params()
|
363 |
+
else:
|
364 |
+
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
365 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
366 |
+
if self.trainable_disc_params is None:
|
367 |
+
disc_params = self.get_discriminator_params()
|
368 |
+
else:
|
369 |
+
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
370 |
+
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
371 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
372 |
+
ae_params,
|
373 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
374 |
+
self.optimizer_config,
|
375 |
+
)
|
376 |
+
opts = [opt_ae]
|
377 |
+
if len(disc_params) > 0:
|
378 |
+
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
379 |
+
opts.append(opt_disc)
|
380 |
+
|
381 |
+
return opts
|
382 |
+
|
383 |
+
@torch.no_grad()
|
384 |
+
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
385 |
+
log = dict()
|
386 |
+
additional_decode_kwargs = {}
|
387 |
+
x = self.get_input(batch)
|
388 |
+
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
389 |
+
|
390 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
391 |
+
log["inputs"] = x
|
392 |
+
log["reconstructions"] = xrec
|
393 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
394 |
+
diff.clamp_(0, 1.0)
|
395 |
+
log["diff"] = 2.0 * diff - 1.0
|
396 |
+
# diff_boost shows location of small errors, by boosting their
|
397 |
+
# brightness.
|
398 |
+
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
399 |
+
if hasattr(self.loss, "log_images"):
|
400 |
+
log.update(self.loss.log_images(x, xrec))
|
401 |
+
with self.ema_scope():
|
402 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
403 |
+
log["reconstructions_ema"] = xrec_ema
|
404 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
405 |
+
diff_ema.clamp_(0, 1.0)
|
406 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
407 |
+
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
408 |
+
if additional_log_kwargs:
|
409 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
410 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
411 |
+
log_str = "reconstructions-" + "-".join(
|
412 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
413 |
+
)
|
414 |
+
log[log_str] = xrec_add
|
415 |
+
return log
|
416 |
+
|
417 |
+
|
418 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
419 |
+
def __init__(self, embed_dim: int, **kwargs):
|
420 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
421 |
+
ddconfig = kwargs.pop("ddconfig")
|
422 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
423 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
424 |
+
super().__init__(
|
425 |
+
encoder_config={
|
426 |
+
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
427 |
+
"params": ddconfig,
|
428 |
+
},
|
429 |
+
decoder_config={
|
430 |
+
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
431 |
+
"params": ddconfig,
|
432 |
+
},
|
433 |
+
**kwargs,
|
434 |
+
)
|
435 |
+
self.quant_conv = torch.nn.Conv2d(
|
436 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
437 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
438 |
+
1,
|
439 |
+
)
|
440 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
441 |
+
self.embed_dim = embed_dim
|
442 |
+
|
443 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
444 |
+
|
445 |
+
def get_autoencoder_params(self) -> list:
|
446 |
+
params = super().get_autoencoder_params()
|
447 |
+
return params
|
448 |
+
|
449 |
+
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
450 |
+
if self.max_batch_size is None:
|
451 |
+
z = self.encoder(x)
|
452 |
+
z = self.quant_conv(z)
|
453 |
+
else:
|
454 |
+
N = x.shape[0]
|
455 |
+
bs = self.max_batch_size
|
456 |
+
n_batches = int(math.ceil(N / bs))
|
457 |
+
z = list()
|
458 |
+
for i_batch in range(n_batches):
|
459 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
460 |
+
z_batch = self.quant_conv(z_batch)
|
461 |
+
z.append(z_batch)
|
462 |
+
z = torch.cat(z, 0)
|
463 |
+
|
464 |
+
z, reg_log = self.regularization(z)
|
465 |
+
if return_reg_log:
|
466 |
+
return z, reg_log
|
467 |
+
return z
|
468 |
+
|
469 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
470 |
+
if self.max_batch_size is None:
|
471 |
+
dec = self.post_quant_conv(z)
|
472 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
473 |
+
else:
|
474 |
+
N = z.shape[0]
|
475 |
+
bs = self.max_batch_size
|
476 |
+
n_batches = int(math.ceil(N / bs))
|
477 |
+
dec = list()
|
478 |
+
for i_batch in range(n_batches):
|
479 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
480 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
481 |
+
dec.append(dec_batch)
|
482 |
+
dec = torch.cat(dec, 0)
|
483 |
+
|
484 |
+
return dec
|
485 |
+
|
486 |
+
|
487 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
488 |
+
def __init__(self, *args, **kwargs):
|
489 |
+
super().__init__(*args, **kwargs)
|
490 |
+
|
491 |
+
def get_input(self, x: Any) -> Any:
|
492 |
+
return x
|
493 |
+
|
494 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
495 |
+
return x
|
496 |
+
|
497 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
498 |
+
return
|
499 |
+
|
500 |
+
|
501 |
+
class VideoAutoencodingEngine(AutoencodingEngine):
|
502 |
+
def __init__(
|
503 |
+
self,
|
504 |
+
ckpt_path: Union[None, str] = None,
|
505 |
+
ignore_keys: Union[Tuple, list] = (),
|
506 |
+
image_video_weights=[1, 1],
|
507 |
+
only_train_decoder=False,
|
508 |
+
context_parallel_size=0,
|
509 |
+
**kwargs,
|
510 |
+
):
|
511 |
+
super().__init__(**kwargs)
|
512 |
+
self.context_parallel_size = context_parallel_size
|
513 |
+
if ckpt_path is not None:
|
514 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
515 |
+
|
516 |
+
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
517 |
+
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
518 |
+
|
519 |
+
def get_input(self, batch: dict) -> torch.Tensor:
|
520 |
+
if self.context_parallel_size > 0:
|
521 |
+
if not is_context_parallel_initialized():
|
522 |
+
initialize_context_parallel(self.context_parallel_size)
|
523 |
+
|
524 |
+
batch = batch[self.input_key]
|
525 |
+
|
526 |
+
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
527 |
+
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
|
528 |
+
|
529 |
+
batch = _conv_split(batch, dim=2, kernel_size=1)
|
530 |
+
return batch
|
531 |
+
|
532 |
+
return batch[self.input_key]
|
533 |
+
|
534 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
535 |
+
if ckpt is None:
|
536 |
+
return
|
537 |
+
self.init_from_ckpt(ckpt)
|
538 |
+
|
539 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
540 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
541 |
+
keys = list(sd.keys())
|
542 |
+
for k in keys:
|
543 |
+
for ik in ignore_keys:
|
544 |
+
if k.startswith(ik):
|
545 |
+
del sd[k]
|
546 |
+
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
547 |
+
print("Missing keys: ", missing_keys)
|
548 |
+
print("Unexpected keys: ", unexpected_keys)
|
549 |
+
print(f"Restored from {path}")
|
550 |
+
|
551 |
+
|
552 |
+
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
553 |
+
def __init__(
|
554 |
+
self,
|
555 |
+
cp_size=0,
|
556 |
+
*args,
|
557 |
+
**kwargs,
|
558 |
+
):
|
559 |
+
self.cp_size = cp_size
|
560 |
+
return super().__init__(*args, **kwargs)
|
561 |
+
|
562 |
+
def encode(
|
563 |
+
self,
|
564 |
+
x: torch.Tensor,
|
565 |
+
return_reg_log: bool = False,
|
566 |
+
unregularized: bool = False,
|
567 |
+
input_cp: bool = False,
|
568 |
+
output_cp: bool = False,
|
569 |
+
use_cp: bool = True,
|
570 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
571 |
+
if self.cp_size <= 1:
|
572 |
+
use_cp = False
|
573 |
+
if self.cp_size > 0 and use_cp and not input_cp:
|
574 |
+
if not is_context_parallel_initialized:
|
575 |
+
initialize_context_parallel(self.cp_size)
|
576 |
+
|
577 |
+
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
578 |
+
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
|
579 |
+
|
580 |
+
x = _conv_split(x, dim=2, kernel_size=1)
|
581 |
+
|
582 |
+
if return_reg_log:
|
583 |
+
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
584 |
+
else:
|
585 |
+
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
586 |
+
|
587 |
+
if self.cp_size > 0 and use_cp and not output_cp:
|
588 |
+
z = _conv_gather(z, dim=2, kernel_size=1)
|
589 |
+
|
590 |
+
if return_reg_log:
|
591 |
+
return z, reg_log
|
592 |
+
return z
|
593 |
+
|
594 |
+
def decode(
|
595 |
+
self,
|
596 |
+
z: torch.Tensor,
|
597 |
+
input_cp: bool = False,
|
598 |
+
output_cp: bool = False,
|
599 |
+
use_cp: bool = True,
|
600 |
+
**kwargs,
|
601 |
+
):
|
602 |
+
if self.cp_size <= 1:
|
603 |
+
use_cp = False
|
604 |
+
if self.cp_size > 0 and use_cp and not input_cp:
|
605 |
+
if not is_context_parallel_initialized:
|
606 |
+
initialize_context_parallel(self.cp_size)
|
607 |
+
|
608 |
+
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
609 |
+
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
|
610 |
+
|
611 |
+
z = _conv_split(z, dim=2, kernel_size=1)
|
612 |
+
|
613 |
+
x = super().decode(z, use_cp=use_cp, **kwargs)
|
614 |
+
|
615 |
+
if self.cp_size > 0 and use_cp and not output_cp:
|
616 |
+
x = _conv_gather(x, dim=2, kernel_size=1)
|
617 |
+
|
618 |
+
return x
|
619 |
+
|
620 |
+
def forward(
|
621 |
+
self,
|
622 |
+
x: torch.Tensor,
|
623 |
+
input_cp: bool = False,
|
624 |
+
latent_cp: bool = False,
|
625 |
+
output_cp: bool = False,
|
626 |
+
**additional_decode_kwargs,
|
627 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
628 |
+
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
|
629 |
+
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
|
630 |
+
return z, dec, reg_log
|
sat/sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders.modules import GeneralConditioner
|
2 |
+
|
3 |
+
UNCONDITIONAL_CONFIG = {
|
4 |
+
"target": "sgm.modules.GeneralConditioner",
|
5 |
+
"params": {"emb_models": []},
|
6 |
+
}
|
sat/sgm/modules/attention.py
ADDED
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from inspect import isfunction
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
from packaging import version
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
12 |
+
SDP_IS_AVAILABLE = True
|
13 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
14 |
+
|
15 |
+
BACKEND_MAP = {
|
16 |
+
SDPBackend.MATH: {
|
17 |
+
"enable_math": True,
|
18 |
+
"enable_flash": False,
|
19 |
+
"enable_mem_efficient": False,
|
20 |
+
},
|
21 |
+
SDPBackend.FLASH_ATTENTION: {
|
22 |
+
"enable_math": False,
|
23 |
+
"enable_flash": True,
|
24 |
+
"enable_mem_efficient": False,
|
25 |
+
},
|
26 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
27 |
+
"enable_math": False,
|
28 |
+
"enable_flash": False,
|
29 |
+
"enable_mem_efficient": True,
|
30 |
+
},
|
31 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
32 |
+
}
|
33 |
+
else:
|
34 |
+
from contextlib import nullcontext
|
35 |
+
|
36 |
+
SDP_IS_AVAILABLE = False
|
37 |
+
sdp_kernel = nullcontext
|
38 |
+
BACKEND_MAP = {}
|
39 |
+
print(
|
40 |
+
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
41 |
+
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
42 |
+
)
|
43 |
+
|
44 |
+
try:
|
45 |
+
import xformers
|
46 |
+
import xformers.ops
|
47 |
+
|
48 |
+
XFORMERS_IS_AVAILABLE = True
|
49 |
+
except:
|
50 |
+
XFORMERS_IS_AVAILABLE = False
|
51 |
+
print("no module 'xformers'. Processing without...")
|
52 |
+
|
53 |
+
from .diffusionmodules.util import checkpoint
|
54 |
+
|
55 |
+
|
56 |
+
def exists(val):
|
57 |
+
return val is not None
|
58 |
+
|
59 |
+
|
60 |
+
def uniq(arr):
|
61 |
+
return {el: True for el in arr}.keys()
|
62 |
+
|
63 |
+
|
64 |
+
def default(val, d):
|
65 |
+
if exists(val):
|
66 |
+
return val
|
67 |
+
return d() if isfunction(d) else d
|
68 |
+
|
69 |
+
|
70 |
+
def max_neg_value(t):
|
71 |
+
return -torch.finfo(t.dtype).max
|
72 |
+
|
73 |
+
|
74 |
+
def init_(tensor):
|
75 |
+
dim = tensor.shape[-1]
|
76 |
+
std = 1 / math.sqrt(dim)
|
77 |
+
tensor.uniform_(-std, std)
|
78 |
+
return tensor
|
79 |
+
|
80 |
+
|
81 |
+
# feedforward
|
82 |
+
class GEGLU(nn.Module):
|
83 |
+
def __init__(self, dim_in, dim_out):
|
84 |
+
super().__init__()
|
85 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
89 |
+
return x * F.gelu(gate)
|
90 |
+
|
91 |
+
|
92 |
+
class FeedForward(nn.Module):
|
93 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
94 |
+
super().__init__()
|
95 |
+
inner_dim = int(dim * mult)
|
96 |
+
dim_out = default(dim_out, dim)
|
97 |
+
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
98 |
+
|
99 |
+
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
return self.net(x)
|
103 |
+
|
104 |
+
|
105 |
+
def zero_module(module):
|
106 |
+
"""
|
107 |
+
Zero out the parameters of a module and return it.
|
108 |
+
"""
|
109 |
+
for p in module.parameters():
|
110 |
+
p.detach().zero_()
|
111 |
+
return module
|
112 |
+
|
113 |
+
|
114 |
+
def Normalize(in_channels):
|
115 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
116 |
+
|
117 |
+
|
118 |
+
class LinearAttention(nn.Module):
|
119 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
120 |
+
super().__init__()
|
121 |
+
self.heads = heads
|
122 |
+
hidden_dim = dim_head * heads
|
123 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
124 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
b, c, h, w = x.shape
|
128 |
+
qkv = self.to_qkv(x)
|
129 |
+
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
130 |
+
k = k.softmax(dim=-1)
|
131 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
132 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
133 |
+
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
134 |
+
return self.to_out(out)
|
135 |
+
|
136 |
+
|
137 |
+
class SpatialSelfAttention(nn.Module):
|
138 |
+
def __init__(self, in_channels):
|
139 |
+
super().__init__()
|
140 |
+
self.in_channels = in_channels
|
141 |
+
|
142 |
+
self.norm = Normalize(in_channels)
|
143 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
144 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
145 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
146 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
h_ = x
|
150 |
+
h_ = self.norm(h_)
|
151 |
+
q = self.q(h_)
|
152 |
+
k = self.k(h_)
|
153 |
+
v = self.v(h_)
|
154 |
+
|
155 |
+
# compute attention
|
156 |
+
b, c, h, w = q.shape
|
157 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
158 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
159 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
160 |
+
|
161 |
+
w_ = w_ * (int(c) ** (-0.5))
|
162 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
163 |
+
|
164 |
+
# attend to values
|
165 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
166 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
167 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
168 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
169 |
+
h_ = self.proj_out(h_)
|
170 |
+
|
171 |
+
return x + h_
|
172 |
+
|
173 |
+
|
174 |
+
class CrossAttention(nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
query_dim,
|
178 |
+
context_dim=None,
|
179 |
+
heads=8,
|
180 |
+
dim_head=64,
|
181 |
+
dropout=0.0,
|
182 |
+
backend=None,
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
inner_dim = dim_head * heads
|
186 |
+
context_dim = default(context_dim, query_dim)
|
187 |
+
|
188 |
+
self.scale = dim_head**-0.5
|
189 |
+
self.heads = heads
|
190 |
+
|
191 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
192 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
193 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
194 |
+
|
195 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
196 |
+
self.backend = backend
|
197 |
+
|
198 |
+
def forward(
|
199 |
+
self,
|
200 |
+
x,
|
201 |
+
context=None,
|
202 |
+
mask=None,
|
203 |
+
additional_tokens=None,
|
204 |
+
n_times_crossframe_attn_in_self=0,
|
205 |
+
):
|
206 |
+
h = self.heads
|
207 |
+
|
208 |
+
if additional_tokens is not None:
|
209 |
+
# get the number of masked tokens at the beginning of the output sequence
|
210 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
211 |
+
# add additional token
|
212 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
213 |
+
|
214 |
+
q = self.to_q(x)
|
215 |
+
context = default(context, x)
|
216 |
+
k = self.to_k(context)
|
217 |
+
v = self.to_v(context)
|
218 |
+
|
219 |
+
if n_times_crossframe_attn_in_self:
|
220 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
221 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
222 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
223 |
+
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
224 |
+
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
225 |
+
|
226 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
227 |
+
|
228 |
+
## old
|
229 |
+
"""
|
230 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
231 |
+
del q, k
|
232 |
+
|
233 |
+
if exists(mask):
|
234 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
235 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
236 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
237 |
+
sim.masked_fill_(~mask, max_neg_value)
|
238 |
+
|
239 |
+
# attention, what we cannot get enough of
|
240 |
+
sim = sim.softmax(dim=-1)
|
241 |
+
|
242 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
243 |
+
"""
|
244 |
+
## new
|
245 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
246 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
247 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
248 |
+
|
249 |
+
del q, k, v
|
250 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
251 |
+
|
252 |
+
if additional_tokens is not None:
|
253 |
+
# remove additional token
|
254 |
+
out = out[:, n_tokens_to_mask:]
|
255 |
+
return self.to_out(out)
|
256 |
+
|
257 |
+
|
258 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
259 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
260 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
|
261 |
+
super().__init__()
|
262 |
+
print(
|
263 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
264 |
+
f"{heads} heads with a dimension of {dim_head}."
|
265 |
+
)
|
266 |
+
inner_dim = dim_head * heads
|
267 |
+
context_dim = default(context_dim, query_dim)
|
268 |
+
|
269 |
+
self.heads = heads
|
270 |
+
self.dim_head = dim_head
|
271 |
+
|
272 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
273 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
274 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
+
|
276 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
277 |
+
self.attention_op: Optional[Any] = None
|
278 |
+
|
279 |
+
def forward(
|
280 |
+
self,
|
281 |
+
x,
|
282 |
+
context=None,
|
283 |
+
mask=None,
|
284 |
+
additional_tokens=None,
|
285 |
+
n_times_crossframe_attn_in_self=0,
|
286 |
+
):
|
287 |
+
if additional_tokens is not None:
|
288 |
+
# get the number of masked tokens at the beginning of the output sequence
|
289 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
290 |
+
# add additional token
|
291 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
292 |
+
q = self.to_q(x)
|
293 |
+
context = default(context, x)
|
294 |
+
k = self.to_k(context)
|
295 |
+
v = self.to_v(context)
|
296 |
+
|
297 |
+
if n_times_crossframe_attn_in_self:
|
298 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
299 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
300 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
301 |
+
k = repeat(
|
302 |
+
k[::n_times_crossframe_attn_in_self],
|
303 |
+
"b ... -> (b n) ...",
|
304 |
+
n=n_times_crossframe_attn_in_self,
|
305 |
+
)
|
306 |
+
v = repeat(
|
307 |
+
v[::n_times_crossframe_attn_in_self],
|
308 |
+
"b ... -> (b n) ...",
|
309 |
+
n=n_times_crossframe_attn_in_self,
|
310 |
+
)
|
311 |
+
|
312 |
+
b, _, _ = q.shape
|
313 |
+
q, k, v = map(
|
314 |
+
lambda t: t.unsqueeze(3)
|
315 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
316 |
+
.permute(0, 2, 1, 3)
|
317 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
318 |
+
.contiguous(),
|
319 |
+
(q, k, v),
|
320 |
+
)
|
321 |
+
|
322 |
+
# actually compute the attention, what we cannot get enough of
|
323 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
324 |
+
|
325 |
+
# TODO: Use this directly in the attention operation, as a bias
|
326 |
+
if exists(mask):
|
327 |
+
raise NotImplementedError
|
328 |
+
out = (
|
329 |
+
out.unsqueeze(0)
|
330 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
331 |
+
.permute(0, 2, 1, 3)
|
332 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
333 |
+
)
|
334 |
+
if additional_tokens is not None:
|
335 |
+
# remove additional token
|
336 |
+
out = out[:, n_tokens_to_mask:]
|
337 |
+
return self.to_out(out)
|
338 |
+
|
339 |
+
|
340 |
+
class BasicTransformerBlock(nn.Module):
|
341 |
+
ATTENTION_MODES = {
|
342 |
+
"softmax": CrossAttention, # vanilla attention
|
343 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
344 |
+
}
|
345 |
+
|
346 |
+
def __init__(
|
347 |
+
self,
|
348 |
+
dim,
|
349 |
+
n_heads,
|
350 |
+
d_head,
|
351 |
+
dropout=0.0,
|
352 |
+
context_dim=None,
|
353 |
+
gated_ff=True,
|
354 |
+
checkpoint=True,
|
355 |
+
disable_self_attn=False,
|
356 |
+
attn_mode="softmax",
|
357 |
+
sdp_backend=None,
|
358 |
+
):
|
359 |
+
super().__init__()
|
360 |
+
assert attn_mode in self.ATTENTION_MODES
|
361 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
362 |
+
print(
|
363 |
+
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
364 |
+
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
365 |
+
)
|
366 |
+
attn_mode = "softmax"
|
367 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
368 |
+
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
|
369 |
+
if not XFORMERS_IS_AVAILABLE:
|
370 |
+
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
371 |
+
else:
|
372 |
+
print("Falling back to xformers efficient attention.")
|
373 |
+
attn_mode = "softmax-xformers"
|
374 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
375 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
376 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
377 |
+
else:
|
378 |
+
assert sdp_backend is None
|
379 |
+
self.disable_self_attn = disable_self_attn
|
380 |
+
self.attn1 = attn_cls(
|
381 |
+
query_dim=dim,
|
382 |
+
heads=n_heads,
|
383 |
+
dim_head=d_head,
|
384 |
+
dropout=dropout,
|
385 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
386 |
+
backend=sdp_backend,
|
387 |
+
) # is a self-attention if not self.disable_self_attn
|
388 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
389 |
+
self.attn2 = attn_cls(
|
390 |
+
query_dim=dim,
|
391 |
+
context_dim=context_dim,
|
392 |
+
heads=n_heads,
|
393 |
+
dim_head=d_head,
|
394 |
+
dropout=dropout,
|
395 |
+
backend=sdp_backend,
|
396 |
+
) # is self-attn if context is none
|
397 |
+
self.norm1 = nn.LayerNorm(dim)
|
398 |
+
self.norm2 = nn.LayerNorm(dim)
|
399 |
+
self.norm3 = nn.LayerNorm(dim)
|
400 |
+
self.checkpoint = checkpoint
|
401 |
+
if self.checkpoint:
|
402 |
+
print(f"{self.__class__.__name__} is using checkpointing")
|
403 |
+
|
404 |
+
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
405 |
+
kwargs = {"x": x}
|
406 |
+
|
407 |
+
if context is not None:
|
408 |
+
kwargs.update({"context": context})
|
409 |
+
|
410 |
+
if additional_tokens is not None:
|
411 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
412 |
+
|
413 |
+
if n_times_crossframe_attn_in_self:
|
414 |
+
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
|
415 |
+
|
416 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
417 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
418 |
+
|
419 |
+
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
420 |
+
x = (
|
421 |
+
self.attn1(
|
422 |
+
self.norm1(x),
|
423 |
+
context=context if self.disable_self_attn else None,
|
424 |
+
additional_tokens=additional_tokens,
|
425 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
426 |
+
)
|
427 |
+
+ x
|
428 |
+
)
|
429 |
+
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
|
430 |
+
x = self.ff(self.norm3(x)) + x
|
431 |
+
return x
|
432 |
+
|
433 |
+
|
434 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
435 |
+
ATTENTION_MODES = {
|
436 |
+
"softmax": CrossAttention, # vanilla attention
|
437 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
|
438 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
439 |
+
}
|
440 |
+
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
dim,
|
444 |
+
n_heads,
|
445 |
+
d_head,
|
446 |
+
dropout=0.0,
|
447 |
+
context_dim=None,
|
448 |
+
gated_ff=True,
|
449 |
+
checkpoint=True,
|
450 |
+
attn_mode="softmax",
|
451 |
+
):
|
452 |
+
super().__init__()
|
453 |
+
assert attn_mode in self.ATTENTION_MODES
|
454 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
455 |
+
self.attn1 = attn_cls(
|
456 |
+
query_dim=dim,
|
457 |
+
heads=n_heads,
|
458 |
+
dim_head=d_head,
|
459 |
+
dropout=dropout,
|
460 |
+
context_dim=context_dim,
|
461 |
+
)
|
462 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
463 |
+
self.norm1 = nn.LayerNorm(dim)
|
464 |
+
self.norm2 = nn.LayerNorm(dim)
|
465 |
+
self.checkpoint = checkpoint
|
466 |
+
|
467 |
+
def forward(self, x, context=None):
|
468 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
469 |
+
|
470 |
+
def _forward(self, x, context=None):
|
471 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
472 |
+
x = self.ff(self.norm2(x)) + x
|
473 |
+
return x
|
474 |
+
|
475 |
+
|
476 |
+
class SpatialTransformer(nn.Module):
|
477 |
+
"""
|
478 |
+
Transformer block for image-like data.
|
479 |
+
First, project the input (aka embedding)
|
480 |
+
and reshape to b, t, d.
|
481 |
+
Then apply standard transformer action.
|
482 |
+
Finally, reshape to image
|
483 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
484 |
+
"""
|
485 |
+
|
486 |
+
def __init__(
|
487 |
+
self,
|
488 |
+
in_channels,
|
489 |
+
n_heads,
|
490 |
+
d_head,
|
491 |
+
depth=1,
|
492 |
+
dropout=0.0,
|
493 |
+
context_dim=None,
|
494 |
+
disable_self_attn=False,
|
495 |
+
use_linear=False,
|
496 |
+
attn_type="softmax",
|
497 |
+
use_checkpoint=True,
|
498 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
499 |
+
sdp_backend=None,
|
500 |
+
):
|
501 |
+
super().__init__()
|
502 |
+
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
503 |
+
from omegaconf import ListConfig
|
504 |
+
|
505 |
+
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
506 |
+
context_dim = [context_dim]
|
507 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
508 |
+
if depth != len(context_dim):
|
509 |
+
print(
|
510 |
+
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
511 |
+
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
512 |
+
)
|
513 |
+
# depth does not match context dims.
|
514 |
+
assert all(
|
515 |
+
map(lambda x: x == context_dim[0], context_dim)
|
516 |
+
), "need homogenous context_dim to match depth automatically"
|
517 |
+
context_dim = depth * [context_dim[0]]
|
518 |
+
elif context_dim is None:
|
519 |
+
context_dim = [None] * depth
|
520 |
+
self.in_channels = in_channels
|
521 |
+
inner_dim = n_heads * d_head
|
522 |
+
self.norm = Normalize(in_channels)
|
523 |
+
if not use_linear:
|
524 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
525 |
+
else:
|
526 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
527 |
+
|
528 |
+
self.transformer_blocks = nn.ModuleList(
|
529 |
+
[
|
530 |
+
BasicTransformerBlock(
|
531 |
+
inner_dim,
|
532 |
+
n_heads,
|
533 |
+
d_head,
|
534 |
+
dropout=dropout,
|
535 |
+
context_dim=context_dim[d],
|
536 |
+
disable_self_attn=disable_self_attn,
|
537 |
+
attn_mode=attn_type,
|
538 |
+
checkpoint=use_checkpoint,
|
539 |
+
sdp_backend=sdp_backend,
|
540 |
+
)
|
541 |
+
for d in range(depth)
|
542 |
+
]
|
543 |
+
)
|
544 |
+
if not use_linear:
|
545 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
546 |
+
else:
|
547 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
548 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
549 |
+
self.use_linear = use_linear
|
550 |
+
|
551 |
+
def forward(self, x, context=None):
|
552 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
553 |
+
if not isinstance(context, list):
|
554 |
+
context = [context]
|
555 |
+
b, c, h, w = x.shape
|
556 |
+
x_in = x
|
557 |
+
x = self.norm(x)
|
558 |
+
if not self.use_linear:
|
559 |
+
x = self.proj_in(x)
|
560 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
561 |
+
if self.use_linear:
|
562 |
+
x = self.proj_in(x)
|
563 |
+
for i, block in enumerate(self.transformer_blocks):
|
564 |
+
if i > 0 and len(context) == 1:
|
565 |
+
i = 0 # use same context for each block
|
566 |
+
x = block(x, context=context[i])
|
567 |
+
if self.use_linear:
|
568 |
+
x = self.proj_out(x)
|
569 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
570 |
+
if not self.use_linear:
|
571 |
+
x = self.proj_out(x)
|
572 |
+
return x + x_in
|
sat/sgm/modules/autoencoding/__init__.py
ADDED
File without changes
|
sat/sgm/modules/autoencoding/losses/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = [
|
2 |
+
"GeneralLPIPSWithDiscriminator",
|
3 |
+
"LatentLPIPS",
|
4 |
+
]
|
5 |
+
|
6 |
+
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
7 |
+
from .lpips import LatentLPIPS
|
8 |
+
from .video_loss import VideoAutoencoderLoss
|
sat/sgm/modules/autoencoding/losses/discriminator_loss.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
from matplotlib import colormaps
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
|
11 |
+
from ....util import default, instantiate_from_config
|
12 |
+
from ..lpips.loss.lpips import LPIPS
|
13 |
+
from ..lpips.model.model import weights_init
|
14 |
+
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
15 |
+
|
16 |
+
|
17 |
+
class GeneralLPIPSWithDiscriminator(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
disc_start: int,
|
21 |
+
logvar_init: float = 0.0,
|
22 |
+
disc_num_layers: int = 3,
|
23 |
+
disc_in_channels: int = 3,
|
24 |
+
disc_factor: float = 1.0,
|
25 |
+
disc_weight: float = 1.0,
|
26 |
+
perceptual_weight: float = 1.0,
|
27 |
+
disc_loss: str = "hinge",
|
28 |
+
scale_input_to_tgt_size: bool = False,
|
29 |
+
dims: int = 2,
|
30 |
+
learn_logvar: bool = False,
|
31 |
+
regularization_weights: Union[None, Dict[str, float]] = None,
|
32 |
+
additional_log_keys: Optional[List[str]] = None,
|
33 |
+
discriminator_config: Optional[Dict] = None,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.dims = dims
|
37 |
+
if self.dims > 2:
|
38 |
+
print(
|
39 |
+
f"running with dims={dims}. This means that for perceptual loss "
|
40 |
+
f"calculation, the LPIPS loss will be applied to each frame "
|
41 |
+
f"independently."
|
42 |
+
)
|
43 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
44 |
+
assert disc_loss in ["hinge", "vanilla"]
|
45 |
+
self.perceptual_loss = LPIPS().eval()
|
46 |
+
self.perceptual_weight = perceptual_weight
|
47 |
+
# output log variance
|
48 |
+
self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
|
49 |
+
self.learn_logvar = learn_logvar
|
50 |
+
|
51 |
+
discriminator_config = default(
|
52 |
+
discriminator_config,
|
53 |
+
{
|
54 |
+
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
55 |
+
"params": {
|
56 |
+
"input_nc": disc_in_channels,
|
57 |
+
"n_layers": disc_num_layers,
|
58 |
+
"use_actnorm": False,
|
59 |
+
},
|
60 |
+
},
|
61 |
+
)
|
62 |
+
|
63 |
+
self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
|
64 |
+
self.discriminator_iter_start = disc_start
|
65 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
66 |
+
self.disc_factor = disc_factor
|
67 |
+
self.discriminator_weight = disc_weight
|
68 |
+
self.regularization_weights = default(regularization_weights, {})
|
69 |
+
|
70 |
+
self.forward_keys = [
|
71 |
+
"optimizer_idx",
|
72 |
+
"global_step",
|
73 |
+
"last_layer",
|
74 |
+
"split",
|
75 |
+
"regularization_log",
|
76 |
+
]
|
77 |
+
|
78 |
+
self.additional_log_keys = set(default(additional_log_keys, []))
|
79 |
+
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
80 |
+
|
81 |
+
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
82 |
+
return self.discriminator.parameters()
|
83 |
+
|
84 |
+
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
85 |
+
if self.learn_logvar:
|
86 |
+
yield self.logvar
|
87 |
+
yield from ()
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
|
91 |
+
# calc logits of real/fake
|
92 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
93 |
+
if len(logits_real.shape) < 4:
|
94 |
+
# Non patch-discriminator
|
95 |
+
return dict()
|
96 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
97 |
+
# -> (b, 1, h, w)
|
98 |
+
|
99 |
+
# parameters for colormapping
|
100 |
+
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
101 |
+
cmap = colormaps["PiYG"] # diverging colormap
|
102 |
+
|
103 |
+
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
104 |
+
"""(b, 1, ...) -> (b, 3, ...)"""
|
105 |
+
logits = (logits + high) / (2 * high)
|
106 |
+
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
107 |
+
# -> (b, 1, ..., 3)
|
108 |
+
logits = torch.from_numpy(logits_np).to(logits.device)
|
109 |
+
return rearrange(logits, "b 1 ... c -> b c ...")
|
110 |
+
|
111 |
+
logits_real = torch.nn.functional.interpolate(
|
112 |
+
logits_real,
|
113 |
+
size=inputs.shape[-2:],
|
114 |
+
mode="nearest",
|
115 |
+
antialias=False,
|
116 |
+
)
|
117 |
+
logits_fake = torch.nn.functional.interpolate(
|
118 |
+
logits_fake,
|
119 |
+
size=reconstructions.shape[-2:],
|
120 |
+
mode="nearest",
|
121 |
+
antialias=False,
|
122 |
+
)
|
123 |
+
|
124 |
+
# alpha value of logits for overlay
|
125 |
+
alpha_real = torch.abs(logits_real) / high
|
126 |
+
alpha_fake = torch.abs(logits_fake) / high
|
127 |
+
# -> (b, 1, h, w) in range [0, 0.5]
|
128 |
+
# alpha value of lines don't really matter, since the values are the same
|
129 |
+
# for both images and logits anyway
|
130 |
+
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
131 |
+
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
132 |
+
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
133 |
+
# -> (1, h, w)
|
134 |
+
# blend logits and images together
|
135 |
+
|
136 |
+
# prepare logits for plotting
|
137 |
+
logits_real = to_colormap(logits_real)
|
138 |
+
logits_fake = to_colormap(logits_fake)
|
139 |
+
# resize logits
|
140 |
+
# -> (b, 3, h, w)
|
141 |
+
|
142 |
+
# make some grids
|
143 |
+
# add all logits to one plot
|
144 |
+
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
145 |
+
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
146 |
+
# I just love how torchvision calls the number of columns `nrow`
|
147 |
+
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
148 |
+
# -> (3, h, w)
|
149 |
+
|
150 |
+
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
151 |
+
grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
|
152 |
+
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
153 |
+
# -> (3, h, w) in range [0, 1]
|
154 |
+
|
155 |
+
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
156 |
+
|
157 |
+
# Create labeled colorbar
|
158 |
+
dpi = 100
|
159 |
+
height = 128 / dpi
|
160 |
+
width = grid_logits.shape[2] / dpi
|
161 |
+
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
162 |
+
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
163 |
+
plt.colorbar(
|
164 |
+
img,
|
165 |
+
cax=ax,
|
166 |
+
orientation="horizontal",
|
167 |
+
fraction=0.9,
|
168 |
+
aspect=width / height,
|
169 |
+
pad=0.0,
|
170 |
+
)
|
171 |
+
img.set_visible(False)
|
172 |
+
fig.tight_layout()
|
173 |
+
fig.canvas.draw()
|
174 |
+
# manually convert figure to numpy
|
175 |
+
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
176 |
+
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
177 |
+
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
178 |
+
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
179 |
+
|
180 |
+
# Add colorbar to plot
|
181 |
+
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
182 |
+
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
183 |
+
return {
|
184 |
+
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
185 |
+
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
186 |
+
}
|
187 |
+
|
188 |
+
def calculate_adaptive_weight(
|
189 |
+
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
190 |
+
) -> torch.Tensor:
|
191 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
192 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
193 |
+
|
194 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
195 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
196 |
+
d_weight = d_weight * self.discriminator_weight
|
197 |
+
return d_weight
|
198 |
+
|
199 |
+
def forward(
|
200 |
+
self,
|
201 |
+
inputs: torch.Tensor,
|
202 |
+
reconstructions: torch.Tensor,
|
203 |
+
*, # added because I changed the order here
|
204 |
+
regularization_log: Dict[str, torch.Tensor],
|
205 |
+
optimizer_idx: int,
|
206 |
+
global_step: int,
|
207 |
+
last_layer: torch.Tensor,
|
208 |
+
split: str = "train",
|
209 |
+
weights: Union[None, float, torch.Tensor] = None,
|
210 |
+
) -> Tuple[torch.Tensor, dict]:
|
211 |
+
if self.scale_input_to_tgt_size:
|
212 |
+
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
|
213 |
+
|
214 |
+
if self.dims > 2:
|
215 |
+
inputs, reconstructions = map(
|
216 |
+
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
217 |
+
(inputs, reconstructions),
|
218 |
+
)
|
219 |
+
|
220 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
221 |
+
if self.perceptual_weight > 0:
|
222 |
+
frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
|
223 |
+
|
224 |
+
from sgm.modules.autoencoding.losses.video_loss import pick_video_frame
|
225 |
+
|
226 |
+
input_frames = pick_video_frame(inputs, frame_indices)
|
227 |
+
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
228 |
+
|
229 |
+
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
230 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
231 |
+
|
232 |
+
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
233 |
+
|
234 |
+
# now the GAN part
|
235 |
+
if optimizer_idx == 0:
|
236 |
+
# generator update
|
237 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
238 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
239 |
+
g_loss = -torch.mean(logits_fake)
|
240 |
+
if self.training:
|
241 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
242 |
+
else:
|
243 |
+
d_weight = torch.tensor(1.0)
|
244 |
+
else:
|
245 |
+
d_weight = torch.tensor(0.0)
|
246 |
+
g_loss = torch.tensor(0.0, requires_grad=True)
|
247 |
+
|
248 |
+
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
249 |
+
log = dict()
|
250 |
+
for k in regularization_log:
|
251 |
+
if k in self.regularization_weights:
|
252 |
+
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
253 |
+
if k in self.additional_log_keys:
|
254 |
+
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
255 |
+
|
256 |
+
log.update(
|
257 |
+
{
|
258 |
+
f"{split}/loss/total": loss.clone().detach().mean(),
|
259 |
+
f"{split}/loss/nll": nll_loss.detach().mean(),
|
260 |
+
f"{split}/loss/rec": rec_loss.detach().mean(),
|
261 |
+
f"{split}/loss/percep": p_loss.detach().mean(),
|
262 |
+
f"{split}/loss/rec": rec_loss.detach().mean(),
|
263 |
+
f"{split}/loss/g": g_loss.detach().mean(),
|
264 |
+
f"{split}/scalars/logvar": self.logvar.detach(),
|
265 |
+
f"{split}/scalars/d_weight": d_weight.detach(),
|
266 |
+
}
|
267 |
+
)
|
268 |
+
|
269 |
+
return loss, log
|
270 |
+
elif optimizer_idx == 1:
|
271 |
+
# second pass for discriminator update
|
272 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
273 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
274 |
+
|
275 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
276 |
+
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
277 |
+
else:
|
278 |
+
d_loss = torch.tensor(0.0, requires_grad=True)
|
279 |
+
|
280 |
+
log = {
|
281 |
+
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
282 |
+
f"{split}/logits/real": logits_real.detach().mean(),
|
283 |
+
f"{split}/logits/fake": logits_fake.detach().mean(),
|
284 |
+
}
|
285 |
+
return d_loss, log
|
286 |
+
else:
|
287 |
+
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
288 |
+
|
289 |
+
def get_nll_loss(
|
290 |
+
self,
|
291 |
+
rec_loss: torch.Tensor,
|
292 |
+
weights: Optional[Union[float, torch.Tensor]] = None,
|
293 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
294 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
295 |
+
weighted_nll_loss = nll_loss
|
296 |
+
if weights is not None:
|
297 |
+
weighted_nll_loss = weights * nll_loss
|
298 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
299 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
300 |
+
|
301 |
+
return nll_loss, weighted_nll_loss
|
sat/sgm/modules/autoencoding/losses/lpips.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from ....util import default, instantiate_from_config
|
5 |
+
from ..lpips.loss.lpips import LPIPS
|
6 |
+
|
7 |
+
|
8 |
+
class LatentLPIPS(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
decoder_config,
|
12 |
+
perceptual_weight=1.0,
|
13 |
+
latent_weight=1.0,
|
14 |
+
scale_input_to_tgt_size=False,
|
15 |
+
scale_tgt_to_input_size=False,
|
16 |
+
perceptual_weight_on_inputs=0.0,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
20 |
+
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
21 |
+
self.init_decoder(decoder_config)
|
22 |
+
self.perceptual_loss = LPIPS().eval()
|
23 |
+
self.perceptual_weight = perceptual_weight
|
24 |
+
self.latent_weight = latent_weight
|
25 |
+
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
26 |
+
|
27 |
+
def init_decoder(self, config):
|
28 |
+
self.decoder = instantiate_from_config(config)
|
29 |
+
if hasattr(self.decoder, "encoder"):
|
30 |
+
del self.decoder.encoder
|
31 |
+
|
32 |
+
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
33 |
+
log = dict()
|
34 |
+
loss = (latent_inputs - latent_predictions) ** 2
|
35 |
+
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
36 |
+
image_reconstructions = None
|
37 |
+
if self.perceptual_weight > 0.0:
|
38 |
+
image_reconstructions = self.decoder.decode(latent_predictions)
|
39 |
+
image_targets = self.decoder.decode(latent_inputs)
|
40 |
+
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
|
41 |
+
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
42 |
+
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
43 |
+
|
44 |
+
if self.perceptual_weight_on_inputs > 0.0:
|
45 |
+
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
|
46 |
+
if self.scale_input_to_tgt_size:
|
47 |
+
image_inputs = torch.nn.functional.interpolate(
|
48 |
+
image_inputs,
|
49 |
+
image_reconstructions.shape[2:],
|
50 |
+
mode="bicubic",
|
51 |
+
antialias=True,
|
52 |
+
)
|
53 |
+
elif self.scale_tgt_to_input_size:
|
54 |
+
image_reconstructions = torch.nn.functional.interpolate(
|
55 |
+
image_reconstructions,
|
56 |
+
image_inputs.shape[2:],
|
57 |
+
mode="bicubic",
|
58 |
+
antialias=True,
|
59 |
+
)
|
60 |
+
|
61 |
+
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
|
62 |
+
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
63 |
+
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
64 |
+
return loss, log
|