eysho committed on
Commit
e276be2
1 Parent(s): 7de1cfa

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. .github/ISSUE_TEMPLATE/bug_report.yaml +51 -0
  3. .github/ISSUE_TEMPLATE/feature-request.yaml +34 -0
  4. .github/PULL_REQUEST_TEMPLATE/pr_template.md +34 -0
  5. .gitignore +9 -0
  6. LICENSE +201 -0
  7. Model_License +71 -0
  8. README.md +157 -6
  9. README_zh.md +149 -0
  10. gradio_demo.py +254 -0
  11. inference/cli_demo.py +127 -0
  12. inference/cli_vae_demo.py +103 -0
  13. inference/convert_demo.py +92 -0
  14. inference/web_demo.py +214 -0
  15. pyproject.toml +27 -0
  16. requirements.txt +11 -0
  17. resources/CogVideoX.pdf +3 -0
  18. resources/WECHAT.md +7 -0
  19. resources/contribute.md +50 -0
  20. resources/contribute_zh.md +45 -0
  21. resources/logo.svg +298 -0
  22. resources/videos/1.mp4 +0 -0
  23. resources/videos/2.mp4 +3 -0
  24. resources/videos/3.mp4 +0 -0
  25. resources/videos/4.mp4 +0 -0
  26. resources/web_demo.png +3 -0
  27. resources/wechat.jpg +0 -0
  28. sat/README.md +182 -0
  29. sat/README_zh.md +180 -0
  30. sat/arguments.py +281 -0
  31. sat/configs/cogvideox_2b_infer.yaml +166 -0
  32. sat/configs/cogvideox_2b_sft.yaml +225 -0
  33. sat/configs/test.txt +3 -0
  34. sat/data_video.py +451 -0
  35. sat/diffusion_video.py +318 -0
  36. sat/dit_video_concat.py +858 -0
  37. sat/finetune.sh +12 -0
  38. sat/inference.sh +12 -0
  39. sat/requirements.txt +17 -0
  40. sat/sample_video.py +236 -0
  41. sat/sgm/__init__.py +4 -0
  42. sat/sgm/lr_scheduler.py +110 -0
  43. sat/sgm/models/__init__.py +1 -0
  44. sat/sgm/models/autoencoder.py +630 -0
  45. sat/sgm/modules/__init__.py +6 -0
  46. sat/sgm/modules/attention.py +572 -0
  47. sat/sgm/modules/autoencoding/__init__.py +0 -0
  48. sat/sgm/modules/autoencoding/losses/__init__.py +8 -0
  49. sat/sgm/modules/autoencoding/losses/discriminator_loss.py +301 -0
  50. sat/sgm/modules/autoencoding/losses/lpips.py +64 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ resources/CogVideoX.pdf filter=lfs diff=lfs merge=lfs -text
37
+ resources/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ resources/web_demo.png filter=lfs diff=lfs merge=lfs -text
39
+ tools/caption/assests/cogvlm2-video-example.png filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,51 @@
1
+ name: "\U0001F41B Bug Report"
2
+ description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
3
+ body:
4
+ - type: textarea
5
+ id: system-info
6
+ attributes:
7
+ label: System Info / 系統信息
8
+ description: Your operating environment / 您的运行环境信息
9
+ placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
10
+ validations:
11
+ required: true
12
+
13
+ - type: checkboxes
14
+ id: information-scripts-examples
15
+ attributes:
16
+ label: Information / 问题信息
17
+ description: 'The problem arises when using: / 问题出现在'
18
+ options:
19
+ - label: "The official example scripts / 官方的示例脚本"
20
+ - label: "My own modified scripts / 我自己修改的脚本和任务"
21
+
22
+ - type: textarea
23
+ id: reproduction
24
+ validations:
25
+ required: true
26
+ attributes:
27
+ label: Reproduction / 复现过程
28
+ description: |
29
+ Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
30
+ If you have code snippets, error messages, stack traces, please provide them here as well.
31
+ Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
32
+ Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
33
+
34
+ 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
35
+ 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
36
+ 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
37
+ 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
38
+ placeholder: |
39
+ Steps to reproduce the behavior/复现Bug的步骤:
40
+
41
+ 1.
42
+ 2.
43
+ 3.
44
+
45
+ - type: textarea
46
+ id: expected-behavior
47
+ validations:
48
+ required: true
49
+ attributes:
50
+ label: Expected behavior / 期待表现
51
+ description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
.github/ISSUE_TEMPLATE/feature-request.yaml ADDED
@@ -0,0 +1,34 @@
1
+ name: "\U0001F680 Feature request"
2
+ description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
3
+ labels: [ "feature" ]
4
+ body:
5
+ - type: textarea
6
+ id: feature-request
7
+ validations:
8
+ required: true
9
+ attributes:
10
+ label: Feature request / 功能建议
11
+ description: |
12
+ A brief description of the functional proposal. Links to corresponding papers and code are desirable.
13
+ 对功能建议的简述。最好提供对应的论文和代码链接。
14
+
15
+ - type: textarea
16
+ id: motivation
17
+ validations:
18
+ required: true
19
+ attributes:
20
+ label: Motivation / 动机
21
+ description: |
22
+ Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
23
+ 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
24
+
25
+ - type: textarea
26
+ id: contribution
27
+ validations:
28
+ required: true
29
+ attributes:
30
+ label: Your contribution / 您的贡献
31
+ description: |
32
+
33
+ Your PR link or any other link you can help with.
34
+ 您的PR链接或者其他您能提供帮助的链接。
.github/PULL_REQUEST_TEMPLATE/pr_template.md ADDED
@@ -0,0 +1,34 @@
1
+ # Raise valuable PR / 提出有价值的PR
2
+
3
+ ## Caution / 注意事项:
4
+ Users should keep the following points in mind when submitting PRs:
5
+
6
+ 1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
7
+ 2. The proposed PR should be focused; if there are multiple ideas or optimizations, split them into separate PRs.
8
+
9
+ 用户在提交PR时候应该注意以下几点:
10
+
11
+ 1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
12
+ 2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
13
+
14
+ ## PRs that should not be proposed / 不应该提出的PR
15
+
16
+ If a developer proposes a PR about any of the following, it may be closed or rejected.
17
+
18
+ 1. PRs that do not describe the proposed improvement.
19
+ 2. PRs that mix multiple unrelated issues.
20
+ 3. PRs that largely duplicate existing PRs.
21
+
22
+ 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
23
+
24
+ 1. 没有说明改进方案的。
25
+ 2. 多个不同类型的问题合并在一个PR中的。
26
+ 3. 提出的PR与已经存在的PR高度重复的。
27
+
28
+
29
+ ## Check your PR / 检查您的PR
30
+ - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
31
+ - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
32
+ - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
33
+ - [ ] Did you write new required tests? / 您是否编写了新的必要测试?
34
+ - [ ] Is your PR limited to a single issue? / 您的PR是否仅针对一个问题?
.gitignore ADDED
@@ -0,0 +1,9 @@
1
+ output/
2
+ *__pycache__/
3
+ samples*/
4
+ runs/
5
+ checkpoints/
6
+ master_ip
7
+ logs/
8
+ *.DS_Store
9
+ .idea
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 CogVideo Model Team @ Zhipu AI
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Model_License ADDED
@@ -0,0 +1,71 @@
1
+ The CogVideoX License
2
+
3
+ 1. Definitions
4
+
5
+ “Licensor” means the CogVideoX Model Team that distributes its Software.
6
+
7
+ “Software” means the CogVideoX model parameters made available under this license.
8
+
9
+ 2. License Grant
10
+
11
+ Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
12
+ This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form .
13
+ Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
14
+ If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
15
+ The above copyright statement and this license statement should be included in all copies or significant portions of this software.
16
+
17
+ 3. Restriction
18
+
19
+ You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
20
+
21
+ You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
22
+
23
+ 4. Disclaimer
24
+
25
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
26
+
27
+ 5. Limitation of Liability
28
+
29
+ EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
30
+
31
+ 6. Dispute Resolution
32
+
33
+ This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
34
+
35
+ Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
36
+
37
+ 1. 定义
38
+
39
+ “许可方”是指分发其软件的 CogVideoX 模型团队。
40
+
41
+ “软件”是指根据本许可提供的 CogVideoX 模型参数。
42
+
43
+ 2. 许可授予
44
+
45
+ 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。
46
+ 本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。
47
+
48
+ 经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
49
+ 在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。
50
+ 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
51
+
52
+ 3.限制
53
+
54
+ 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
55
+
56
+ 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
57
+
58
+ 4.免责声明
59
+
60
+ 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
61
+ 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
62
+
63
+ 5. 责任限制
64
+
65
+ 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
66
+
67
+ 6.争议解决
68
+
69
+ 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
70
+
71
+ 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
README.md CHANGED
@@ -1,12 +1,163 @@
1
  ---
2
  title: CogVideo
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: CogVideo
3
+ app_file: gradio_demo.py
4
  sdk: gradio
5
  sdk_version: 4.41.0
6
  ---
7
+ # CogVideo && CogVideoX
8
 
9
+ [中文阅读](./README_zh.md)
10
+
11
+ <div align="center">
12
+ <img src=resources/logo.svg width="50%"/>
13
+ </div>
14
+ <p align="center">
15
+ 🤗 Experience on <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a>
16
+ </p>
17
+ <p align="center">
18
+ 📚 Check here to view <a href="resources/CogVideoX.pdf" target="_blank">Paper</a>
19
+ </p>
20
+ <p align="center">
21
+ 👋 Join our <a href="resources/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
22
+ </p>
23
+ <p align="center">
24
+ 📍 Visit <a href="https://chatglm.cn/video?fr=osm_cogvideox">清影</a> and <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">API Platform</a> to experience larger-scale commercial video generation models.
25
+ </p>
26
+
27
+ ## Update and News
28
+
29
+ - 🔥 **News**: ``2024/8/6``: We have also open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct
30
+ videos almost losslessly.
31
+ - 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video
32
+ generation models.
33
+ - 🌱 **Source**: ``2022/5/19``: We open-sourced CogVideo (now available on the `CogVideo` branch), the **first** open-source pretrained text-to-video model; see the [ICLR'23 CogVideo Paper](https://arxiv.org/abs/2205.15868) for technical details.
34
+
35
+ **More powerful models with larger parameter sizes are on the way~ Stay tuned!**
36
+
37
+ ## CogVideoX-2B Gallery
38
+
39
+ <div align="center">
40
+ <video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
41
+ <p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
42
+ </div>
43
+
44
+ <div align="center">
45
+ <video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
46
+ <p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
47
+ </div>
48
+
49
+ <div align="center">
50
+ <video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
51
+ <p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
52
+ </div>
53
+
54
+ <div align="center">
55
+ <video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
56
+ <p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
57
+ </div>
58
+
59
+ ## Model Introduction
60
+
61
+ CogVideoX is an open-source video generation model that shares its origins with
62
+ [清影](https://chatglm.cn/video?fr=osm_cogvideox).
63
+
64
+ The table below shows the list of video generation models we currently provide,
65
+ along with related basic information:
66
+
67
+ | Model Name | CogVideoX-2B |
68
+ |-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
69
+ | Prompt Language | English |
70
+ | GPU Memory Required for Inference (FP16) | 18GB if using [SAT](https://github.com/THUDM/SwissArmyTransformer); 36GB if using diffusers (will be optimized before the PR is merged) |
71
+ | GPU Memory Required for Fine-tuning(bs=1) | 40GB |
72
+ | Prompt Max Length | 226 Tokens |
73
+ | Video Length | 6 seconds |
74
+ | Frames Per Second | 8 frames |
75
+ | Resolution | 720 * 480 |
76
+ | Quantized Inference | Not Supported |
77
+ | Multi-card Inference | Not Supported |
78
+ | Download Link (HF diffusers Model) | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) [💫 WiseModel](https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b) |
79
+ | Download Link (SAT Model) | [SAT](./sat/README.md) |
80
+
81
+ ## Project Structure
82
+
83
+ This repository helps developers quickly get started with basic usage and fine-tuning examples
84
+ for the open-source **CogVideoX** model.
85
+
86
+ ### Inference
87
+
88
+ + [cli_demo](inference/cli_demo.py): A more detailed walkthrough of the inference code, explaining the significance of common parameters (a minimal usage sketch follows this section).
89
+ + [cli_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; this currently requires 71GB of memory and will be optimized in the future.
90
+ + [convert_demo](inference/convert_demo.py): How to convert user input into a format suitable for CogVideoX. Because CogVideoX is trained on long captions, the input text must be rewritten by an LLM to match the training distribution. The script uses GLM-4 by default, but it can be replaced with any other LLM such as GPT or Gemini.
91
+ + [web_demo](inference/web_demo.py): A simple Streamlit web application demonstrating how to generate videos with the CogVideoX-2B model.
92
+
93
+ <div style="text-align: center;">
94
+ <img src="resources/web_demo.png" style="width: 100%; height: auto;" />
95
+ </div>
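
As referenced in the `cli_demo` entry above, the sketch below shows a minimal text-to-video call with the 🤗 Diffusers pipeline. The model name and parameters follow the table above, and `imageio` is used for export to avoid the "green screen" issue noted in `cli_demo.py`; this is a sketch, not the exact demo code.

```python
import imageio
import numpy as np
import torch
from diffusers import CogVideoXPipeline

# Load the 2B checkpoint in fp16 and move it to the GPU.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# Roughly 6 seconds of 720x480 video at 8 fps.
frames = pipe(
    prompt="A golden retriever runs across a sunlit meadow in slow motion.",
    num_inference_steps=50,
    guidance_scale=6.0,
    num_videos_per_prompt=1,
).frames[0]

# Write the frames with imageio, mirroring export_to_video_imageio in the demos.
with imageio.get_writer("output.mp4", fps=8) as writer:
    for frame in frames:
        writer.append_data(np.array(frame))
```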
96
+
97
+ ### sat
98
+
99
+ + [sat_demo](sat/README.md): Contains the inference and fine-tuning code for the SAT weights. It is the
100
+ recommended starting point for building on the CogVideoX model structure; researchers can use this code for
101
+ rapid prototyping and development.
102
+
103
+ ### Tools
104
+
105
+ This folder contains some tools for model conversion / caption generation, etc.
106
+
107
+ + [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Convert SAT model weights to Huggingface model weights.
108
+ + [caption_demo](tools/caption): Captioning tool, a model that understands videos and describes them in text.
109
+
110
+ ## Project Plan
111
+
112
+ - [x] Open source CogVideoX model
113
+ - [x] Open source 3D Causal VAE used in CogVideoX.
114
+ - [x] CogVideoX model inference example (CLI / Web Demo)
115
+ - [x] CogVideoX online experience demo (Huggingface Space)
116
+ - [x] CogVideoX open source model API interface example (Huggingface)
117
+ - [x] CogVideoX model fine-tuning example (SAT)
118
+ - [ ] CogVideoX model fine-tuning example (Huggingface / SAT)
119
+ - [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite)
120
+ - [x] Release CogVideoX technical report
121
+
122
+ We welcome your contributions. You can click [here](resources/contribute.md) for more information.
123
+
124
+ ## Model License
125
+
126
+ The code in this repository is released under the [Apache 2.0 License](LICENSE).
127
+
128
+ The model weights and implementation code are released under the [CogVideoX LICENSE](Model_License).
129
+
130
+ ## CogVideo(ICLR'23)
131
+ The official repo for the paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) is on the [CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo)
132
+
133
+ **CogVideo is able to generate relatively high-frame-rate videos.**
134
+ A 4-second clip of 32 frames is shown below.
135
+
136
+ ![High-frame-rate sample](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/appendix-sample-highframerate.png)
137
+
138
+ ![Intro images](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/intro-image.png)
139
+ <div align="center">
140
+ <video src="https://github.com/user-attachments/assets/2fa19651-e925-4a2a-b8d6-b3f216d490ba" width="80%" controls autoplay></video>
141
+ </div>
142
+
143
+
144
+ The demo for CogVideo is at [https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/), where you can get hands-on practice on text-to-video generation. *The original input is in Chinese.*
145
+
146
+
147
+ ## Citation
148
+
149
+ 🌟 If you find our work helpful, please leave us a star and cite our paper.
150
+
151
+ ```
152
+ @article{yang2024cogvideox,
153
+ title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
154
+ author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
155
+ year={2024},
156
+ }
157
+ @article{hong2022cogvideo,
158
+ title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
159
+ author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
160
+ journal={arXiv preprint arXiv:2205.15868},
161
+ year={2022}
162
+ }
163
+ ```
README_zh.md ADDED
@@ -0,0 +1,149 @@
1
+ # CogVideo && CogVideoX
2
+
3
+ [Read this in English.](./README.md)
4
+
5
+
6
+ <div align="center">
7
+ <img src=resources/logo.svg width="50%"/>
8
+ </div>
9
+ <p align="center">
10
+ 🤗 在 <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a> 体验视频生成模型
11
+ </p>
12
+ <p align="center">
13
+ 📚 查看 <a href="resources/CogVideoX.pdf" target="_blank">论文</a>
14
+ </p>
15
+ <p align="center">
16
+ 👋 加入我们的 <a href="resources/WECHAT.md" target="_blank">微信</a> 和 <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
17
+ </p>
18
+ <p align="center">
19
+ 📍 前往<a href="https://chatglm.cn/video?fr=osm_cogvideox"> 清影</a> 和 <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9"> API平台</a> 体验更大规模的商业版视频生成模型。
20
+ </p>
21
+
22
+ ## 项目更新
23
+
24
+ - 🔥 **News**: ``2024/8/6``: 我们开源 **3D Causal VAE**,用于 **CogVideoX-2B**,可以几乎无损地重构视频。
25
+ - 🔥 **News**: ``2024/8/6``: 我们开源 CogVideoX 系列视频生成模型的第一个模型, **CogVideoX-2B**。
26
+ - 🌱 **Source**: ```2022/5/19```: 我们开源了 CogVideo 视频生成模型(现在你可以在 `CogVideo` 分支中看到),这是首个开源的基于 Transformer 的大型文本生成视频模型,您可以访问 [ICLR'23 论文](https://arxiv.org/abs/2205.15868) 查看技术细节。
27
+ **性能更强,参数量更大的模型正在到来的路上~,欢迎关注**
28
+
29
+ ## CogVideoX-2B 视频作品
30
+
31
+ <div align="center">
32
+ <video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
33
+ <p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
34
+ </div>
35
+
36
+ <div align="center">
37
+ <video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
38
+ <p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
39
+ </div>
40
+
41
+ <div align="center">
42
+ <video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
43
+ <p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
44
+ </div>
45
+
46
+ <div align="center">
47
+ <video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
48
+ <p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
49
+ </div>
50
+
51
+ ## 模型介绍
52
+
53
+ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源版本视频生成模型。
54
+
55
+ 下表展示目前我们提供的视频生成模型列表,以及相关基础信息:
56
+
57
+ | 模型名字 | CogVideoX-2B |
58
+ |---------------------|--------------------------------------------------------------------------------------------------------------------------------------|
59
+ | 提示词语言 | English |
60
+ | 推理显存消耗 (FP-16) | 36GB using diffusers (will be optimized before the PR is merged) and 18GB using [SAT](https://github.com/THUDM/SwissArmyTransformer) |
61
+ | 微调显存消耗 (bs=1) | 42GB |
62
+ | 提示词长度上限 | 226 Tokens |
63
+ | 视频长度 | 6 seconds |
64
+ | 帧率(每秒) | 8 frames |
65
+ | 视频分辨率 | 720 * 480 |
66
+ | 量化推理 | 不支持 |
67
+ | 多卡推理 | 不支持 |
68
+ | 下载地址 (Diffusers 模型) | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) |
69
+ | 下载地址 (SAT 模型) | [SAT](./sat/README_zh.md) |
70
+
71
+ ## 项目结构
72
+
73
+ 本开源仓库将带领开发者快速上手 **CogVideoX** 开源模型的基础调用方式、微调示例。
74
+
75
+ ### inference
76
+
77
+ + [cli_demo](inference/cli_demo.py): 更详细的推理代码讲解,常见参数的意义,在这里都会提及。
78
+ + [cli_vae_demo](inference/cli_vae_demo.py): 单独执行VAE的推理代码,目前需要71GB显存,将来会优化。
79
+ + [convert_demo](inference/convert_demo.py): 如何将用户的输入转换成适合 CogVideoX的长输入。因为CogVideoX是在长文本上训练的,所以我们需要把输入文本的分布通过LLM转换为和训练一致的长文本。脚本中默认使用GLM4,也可以替换为GPT、Gemini等任意大语言模型。
80
+ + [web_demo](inference/web_demo.py): 一个简单的streamlit网页应用,展示如何使用 CogVideoX-2B 模型生成视频。
81
+
82
+ <div style="text-align: center;">
83
+ <img src="resources/web_demo.png" style="width: 100%; height: auto;" />
84
+ </div>
85
+
86
+ ### sat
87
+
88
+ + [sat_demo](sat/README_zh.md): 包含了 SAT 权重的推理代码和微调代码,推荐基于 CogVideoX
89
+ 模型结构进行改进,创新的研究者可以使用该代码以更好地进行快速的堆叠和开发。
90
+
91
+ ### tools
92
+
93
+ 本文件夹包含了一些工具,用于模型的转换 / Caption 等工作。
94
+
95
+ + [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): 将 SAT 模型权重转换为 Huggingface 模型权重。
96
+ + [caption_demo](tools/caption/README_zh.md): Caption 工具,对视频理解并用文字输出的模型。
97
+
98
+ ## 项目规划
99
+
100
+ - [x] CogVideoX 模型开源
101
+ - [x] CogVideoX 模型推理示例 (CLI / Web Demo)
102
+ - [x] CogVideoX 在线体验示例 (Huggingface Space)
103
+ - [x] CogVideoX 开源模型API接口示例 (Huggingface)
104
+ - [x] CogVideoX 模型微调示例 (SAT)
105
+ - [ ] CogVideoX 模型微调示例 (Huggingface / SAT)
106
+ - [ ] CogVideoX-Pro 开源(适配 CogVideoX-2B 套件)
107
+ - [ ] CogVideoX 技术报告公开
108
+
109
+ 我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
110
+
111
+ ## 模型协议
112
+
113
+ 本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。
114
+
115
+ 本模型权重和模型实现代码根据 [CogVideoX LICENSE](Model_License) 许可证发布。
116
+
117
+ ## CogVideo(ICLR'23)
118
+ [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) 的官方repo位于[CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo)。
119
+
120
+ **CogVideo可以生成高帧率视频,下面展示了一个32帧的4秒视频。**
121
+
122
+ ![High-frame-rate sample](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/appendix-sample-highframerate.png)
123
+
124
+ ![Intro images](https://raw.githubusercontent.com/THUDM/CogVideo/CogVideo/assets/intro-image.png)
125
+
126
+
127
+ <div align="center">
128
+ <video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
129
+ </div>
130
+
131
+ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/)。您可以在这里体验文本到视频生成。*原始输入为中文。*
132
+
133
+ ## 引用
134
+
135
+ 🌟 如果您发现我们的工作有所帮助,欢迎引用我们的文章,留下宝贵的stars
136
+
137
+ ```
138
+ @article{yang2024cogvideox,
139
+ title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
140
+ author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
141
+ year={2024},
142
+ }
143
+ @article{hong2022cogvideo,
144
+ title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
145
+ author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
146
+ journal={arXiv preprint arXiv:2205.15868},
147
+ year={2022}
148
+ }
149
+ ```
gradio_demo.py ADDED
@@ -0,0 +1,254 @@
1
+ import os
2
+ import tempfile
3
+ import threading
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import CogVideoXPipeline
10
+ from datetime import datetime, timedelta
11
+ from openai import OpenAI
12
+ import spaces
13
+ import imageio
14
+ import moviepy.editor as mp
15
+ from typing import List, Union
16
+ import PIL
17
+
18
+ dtype = torch.bfloat16
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
21
+
22
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
23
+
24
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
25
+ There are a few rules to follow:
26
+
27
+ You will only ever output a single video description per user request.
28
+
29
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
30
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
31
+
32
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
33
+ """
34
+
35
+
36
+ def export_to_video_imageio(
37
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
38
+ ) -> str:
39
+ """
40
+ Export the video frames to a video file using the imageio library, to avoid the "green screen" issue (seen, for example, with CogVideoX output)
41
+ """
42
+ if output_video_path is None:
43
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
44
+
45
+ if isinstance(video_frames[0], PIL.Image.Image):
46
+ video_frames = [np.array(frame) for frame in video_frames]
47
+
48
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
49
+ for frame in video_frames:
50
+ writer.append_data(frame)
51
+
52
+ return output_video_path
53
+
54
+
55
+ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
56
+ if not os.environ.get("OPENAI_API_KEY"):
57
+ return prompt
58
+ client = OpenAI()
59
+ text = prompt.strip()
60
+
61
+ for i in range(retry_times):
62
+ response = client.chat.completions.create(
63
+ messages=[
64
+ {"role": "system", "content": sys_prompt},
65
+ {"role": "user",
66
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
67
+ {"role": "assistant",
68
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
69
+ {"role": "user",
70
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
71
+ {"role": "assistant",
72
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
73
+ {"role": "user",
74
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
75
+ {"role": "assistant",
76
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
77
+ {"role": "user",
78
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
79
+ ],
80
+ model="glm-4-0520",
81
+ temperature=0.01,
82
+ top_p=0.7,
83
+ stream=False,
84
+ max_tokens=250,
85
+ )
86
+ if response.choices:
87
+ return response.choices[0].message.content
88
+ return prompt
89
+
90
+
91
+ @spaces.GPU(duration=240)
92
+ def infer(
93
+ prompt: str,
94
+ num_inference_steps: int,
95
+ guidance_scale: float,
96
+ progress=gr.Progress(track_tqdm=True)
97
+ ):
98
+ torch.cuda.empty_cache()
99
+
100
+ prompt_embeds, _ = pipe.encode_prompt(
101
+ prompt=prompt,
102
+ negative_prompt=None,
103
+ do_classifier_free_guidance=True,
104
+ num_videos_per_prompt=1,
105
+ max_sequence_length=226,
106
+ device=device,
107
+ dtype=dtype,
108
+ )
109
+
110
+ video = pipe(
111
+ num_inference_steps=num_inference_steps,
112
+ guidance_scale=guidance_scale,
113
+ prompt_embeds=prompt_embeds,
114
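+ # The zeroed embeddings act as the unconditional (negative) branch for classifier-free guidance.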
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),
115
+ ).frames[0]
116
+
117
+
118
+ return video
119
+
120
+
121
+ def save_video(tensor):
122
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
123
+ video_path = f"./output/{timestamp}.mp4"
124
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
125
+ export_to_video_imageio(tensor[1:], video_path)
126
+ return video_path
127
+
128
+ def convert_to_gif(video_path):
129
+ clip = mp.VideoFileClip(video_path)
130
+ clip = clip.set_fps(8)
131
+ clip = clip.resize(height=240)
132
+ gif_path = video_path.replace('.mp4', '.gif')
133
+ clip.write_gif(gif_path, fps=8)
134
+ return gif_path
135
+
136
+
137
+ def delete_old_files():
138
+ while True:
139
+ now = datetime.now()
140
+ cutoff = now - timedelta(minutes=10)
141
+ output_dir = './output'
142
+ for filename in os.listdir(output_dir):
143
+ file_path = os.path.join(output_dir, filename)
144
+ if os.path.isfile(file_path):
145
+ file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
146
+ if file_mtime < cutoff:
147
+ os.remove(file_path)
148
+ time.sleep(600) # Sleep for 10 minutes
149
+
150
+
151
+ threading.Thread(target=delete_old_files, daemon=True).start()
152
+
153
+ with gr.Blocks() as demo:
154
+ gr.Markdown("""
155
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
156
+ CogVideoX-2B Huggingface Space🤗
157
+ </div>
158
+ <div style="text-align: center;">
159
+ <a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 Model Hub</a> |
160
+ <a href="https://github.com/THUDM/CogVideo">🌐 Github</a>
161
+ </div>
162
+
163
+ <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
164
+ ⚠️ This demo is for academic research and experiential use only.
165
+ Users should strictly adhere to local laws and ethics.
166
+ </div>
167
+ """)
168
+ with gr.Row():
169
+ with gr.Column():
170
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
171
+ with gr.Row():
172
+ gr.Markdown(
173
+ "✨ Upon pressing the Enhance Prompt button, we will use the [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
174
+ enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
175
+
176
+ with gr.Column():
177
+ gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
178
+ "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
179
+ "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>")
180
+ with gr.Row():
181
+ num_inference_steps = gr.Number(label="Inference Steps", value=50)
182
+ guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
183
+ generate_button = gr.Button("🎬 Generate Video")
184
+
185
+ with gr.Column():
186
+ video_output = gr.Video(label="CogVideoX Generated Video", width=720, height=480)
187
+ with gr.Row():
188
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
189
+ download_gif_button = gr.File(label="📥 Download GIF", visible=False)
190
+
191
+ gr.Markdown("""
192
+ <table border="1" style="width: 100%; text-align: left; margin-top: 20px;">
193
+ <tr>
194
+ <th>Prompt</th>
195
+ <th>Video URL</th>
196
+ <th>Inference Steps</th>
197
+ <th>Guidance Scale</th>
198
+ </tr>
199
+ <tr>
200
+ <td>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</td>
201
+ <td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/1.mp4">Video 1</a></td>
202
+ <td>50</td>
203
+ <td>6</td>
204
+ </tr>
205
+ <tr>
206
+ <td>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</td>
207
+ <td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/2.mp4">Video 2</a></td>
208
+ <td>50</td>
209
+ <td>6</td>
210
+ </tr>
211
+ <tr>
212
+ <td>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</td>
213
+ <td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/3.mp4">Video 3</a></td>
214
+ <td>50</td>
215
+ <td>6</td>
216
+ </tr>
217
+ <tr>
218
+ <td>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</td>
219
+ <td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/4.mp4">Video 4</a></td>
220
+ <td>50</td>
221
+ <td>6</td>
222
+ </tr>
223
+ </table>
224
+ """)
225
+
226
+
227
+ def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
228
+ tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
229
+ video_path = save_video(tensor)
230
+ video_update = gr.update(visible=True, value=video_path)
231
+ gif_path = convert_to_gif(video_path)
232
+ gif_update = gr.update(visible=True, value=gif_path)
233
+
234
+ return video_path, video_update, gif_update
235
+
236
+
237
+ def enhance_prompt_func(prompt):
238
+ return convert_prompt(prompt, retry_times=1)
239
+
240
+
241
+ generate_button.click(
242
+ generate,
243
+ inputs=[prompt, num_inference_steps, guidance_scale],
244
+ outputs=[video_output, download_video_button, download_gif_button]
245
+ )
246
+
247
+ enhance_button.click(
248
+ enhance_prompt_func,
249
+ inputs=[prompt],
250
+ outputs=[prompt]
251
+ )
252
+
253
+ if __name__ == "__main__":
254
+ demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
inference/cli_demo.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ This script demonstrates how to generate a video from a text prompt using CogVideoX with the 🤗 Hugging Face Diffusers pipeline.
3
+
4
+ Note:
5
+ This script requires the `diffusers>=0.30.0` library to be installed.
6
+ If the video exported using OpenCV appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
7
+
8
+ Run the script:
9
+ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
10
+
11
+ """
12
+
13
+ import argparse
14
+ import tempfile
15
+ from typing import Union, List
16
+
17
+ import PIL
18
+ import imageio
19
+ import numpy as np
20
+ import torch
21
+ from diffusers import CogVideoXPipeline
22
+
23
+
24
+ def export_to_video_imageio(
25
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
26
+ ) -> str:
27
+ """
28
+ Export the video frames to a video file using the imageio library to avoid the "green screen" issue (seen, for example, with CogVideoX outputs)
29
+ """
30
+ if output_video_path is None:
31
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
32
+ if isinstance(video_frames[0], PIL.Image.Image):
33
+ video_frames = [np.array(frame) for frame in video_frames]
34
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
35
+ for frame in video_frames:
36
+ writer.append_data(frame)
37
+ return output_video_path
38
+
39
+
40
+ def generate_video(
41
+ prompt: str,
42
+ model_path: str,
43
+ output_path: str = "./output.mp4",
44
+ num_inference_steps: int = 50,
45
+ guidance_scale: float = 6.0,
46
+ num_videos_per_prompt: int = 1,
47
+ device: str = "cuda",
48
+ dtype: torch.dtype = torch.float16,
49
+ ):
50
+ """
51
+ Generates a video based on the given prompt and saves it to the specified path.
52
+
53
+ Parameters:
54
+ - prompt (str): The description of the video to be generated.
55
+ - model_path (str): The path of the pre-trained model to be used.
56
+ - output_path (str): The path where the generated video will be saved.
57
+ - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
58
+ - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
59
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
60
+ - device (str): The device to use for computation (e.g., "cuda" or "cpu").
61
+ - dtype (torch.dtype): The data type for computation (default is torch.float16).
62
+ """
63
+
64
+ # Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
65
+ pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
66
+
67
+ # Encode the prompt to get the prompt embeddings
68
+ prompt_embeds, _ = pipe.encode_prompt(
69
+ prompt=prompt, # The textual description for video generation
70
+ negative_prompt=None, # The negative prompt to guide the video generation
71
+ do_classifier_free_guidance=True, # Whether to use classifier-free guidance
72
+ num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
73
+ max_sequence_length=226, # Maximum length of the sequence, must be 226
74
+ device=device, # Device to use for computation
75
+ dtype=dtype, # Data type for computation
76
+ )
77
+
78
+ # Generate the video frames using the pipeline
79
+ video = pipe(
80
+ num_inference_steps=num_inference_steps, # Number of inference steps
81
+ guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
82
+ prompt_embeds=prompt_embeds, # Encoded prompt embeddings
83
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds), # Negative prompts are not supported, so zero embeddings are passed
84
+ ).frames[0]
85
+
86
+ # Export the generated frames to a video file. fps must be 8
87
+ export_to_video_imageio(video, output_path, fps=8)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
92
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
93
+ parser.add_argument(
94
+ "--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
95
+ )
96
+ parser.add_argument(
97
+ "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
98
+ )
99
+ parser.add_argument(
100
+ "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
101
+ )
102
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
103
+ parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
104
+ parser.add_argument(
105
+ "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
110
+ )
111
+
112
+ args = parser.parse_args()
113
+
114
+ # Convert the dtype argument to torch.dtype; BF16 is not recommended.
115
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
116
+
117
+ # main function to generate video.
118
+ generate_video(
119
+ prompt=args.prompt,
120
+ model_path=args.model_path,
121
+ output_path=args.output_path,
122
+ num_inference_steps=args.num_inference_steps,
123
+ guidance_scale=args.guidance_scale,
124
+ num_videos_per_prompt=args.num_videos_per_prompt,
125
+ device=args.device,
126
+ dtype=dtype,
127
+ )
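For quick reference, here is a minimal sketch of calling the helper above from Python rather than the command line. The import assumes the `inference/` folder is on `PYTHONPATH` (or that you run from inside it); the checkpoint path mirrors the CLI default.

```python
# Minimal sketch: call generate_video() programmatically (assumes inference/ is importable).
import torch

from cli_demo import generate_video

generate_video(
    prompt="A girl riding a bike.",
    model_path="THUDM/CogVideoX-2b",  # same default checkpoint as the CLI
    output_path="./output.mp4",
    num_inference_steps=50,
    guidance_scale=6.0,
    dtype=torch.float16,  # BF16 is not recommended, as noted above
)
```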
inference/cli_vae_demo.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
3
+
4
+ Note:
5
+ This script requires the `diffusers>=0.30.0` library to be installed.
6
+ If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
7
+ Encoding a 6-second video at 720p resolution costs about 71 GB of GPU memory.
8
+
9
+ Run the script:
10
+ $ python cli_vae_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
11
+
12
+ """
13
+
14
+ import argparse
15
+ import torch
16
+ import imageio
17
+ import numpy as np
18
+ from diffusers import AutoencoderKLCogVideoX
19
+ from torchvision import transforms
20
+
21
+
22
+ def vae_demo(model_path, video_path, dtype, device):
23
+ """
24
+ Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
25
+
26
+ Parameters:
27
+ - model_path (str): The path to the pre-trained model.
28
+ - video_path (str): The path to the video file.
29
+ - dtype (torch.dtype): The data type for computation.
30
+ - device (str): The device to use for computation (e.g., "cuda" or "cpu").
31
+
32
+ Returns:
33
+ - torch.Tensor: The encoded video frames.
34
+ """
35
+ # Load the pre-trained model
36
+ model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
37
+
38
+ # Load video frames
39
+ video_reader = imageio.get_reader(video_path, "ffmpeg")
40
+ frames = []
41
+ for frame in video_reader:
42
+ frames.append(frame)
43
+ video_reader.close()
44
+
45
+ # Transform frames to Tensor
46
+ transform = transforms.Compose(
47
+ [
48
+ transforms.ToTensor(),
49
+ ]
50
+ )
51
+ frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
52
+
53
+ # Add batch dimension and reshape to [1, 3, 49, 480, 720]
54
+ frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
55
+
56
+ # Run the model with Encoder and Decoder
57
+ with torch.no_grad():
58
+ output = model(frames_tensor)
59
+
60
+ return output
61
+
62
+
63
+ def save_video(tensor, output_path):
64
+ """
65
+ Saves the encoded video frames to a video file.
66
+
67
+ Parameters:
68
+ - tensor (torch.Tensor): The encoded video frames.
69
+ - output_path (str): The path to save the output video.
70
+ """
71
+ # Remove batch dimension and permute back to [49, 480, 720, 3]
72
+ frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
73
+
74
+ # Clip values to [0, 1] and convert to uint8
75
+ frames = np.clip(frames, 0, 1)
76
+ frames = (frames * 255).astype(np.uint8)
77
+
78
+ # Save frames to video
79
+ writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
80
+ for frame in frames:
81
+ writer.append_data(frame)
82
+ writer.close()
83
+
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser(description="Encode and decode a video with the CogVideoX VAE")
87
+ parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
88
+ parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
89
+ parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
90
+ parser.add_argument(
91
+ "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
92
+ )
93
+ parser.add_argument(
94
+ "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
95
+ )
96
+ args = parser.parse_args()
97
+
98
+ # Set device and dtype
99
+ device = torch.device(args.device)
100
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
101
+
102
+ output = vae_demo(args.model_path, args.video_path, dtype, device)
103
+ save_video(output, args.output_path)
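As a usage sketch (not part of the script itself), the same VAE round trip can be driven from Python; the paths are placeholders, and the GPU-memory note in the docstring still applies.

```python
# Minimal sketch: encode/decode a video with the helpers above (placeholder paths).
import torch

from cli_vae_demo import vae_demo, save_video

output = vae_demo("THUDM/CogVideoX-2b", "path/to/video.mp4", torch.float16, "cuda")
save_video(output, ".")  # writes ./output.mp4
```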
inference/convert_demo.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+
3
+ The CogVideoX model is pre-trained and fine-tuned on longer, more detailed prompts. Therefore, it requires highly granular and detailed prompts as input. This script transforms short user inputs into prompts suitable for CogVideoX, enabling higher-quality video generation.
4
+
5
+ This step is not mandatory; the model will still function correctly and without errors even if the prompts are not refined using this script. However, we strongly recommend using it to ensure the generation of high-quality videos.
6
+
7
+ Note:
8
+ Please set the OPENAI_API_KEY (and OPENAI_BASE_URL, if needed) environment variables before running this script.
9
+
10
+ Run the script:
11
+ $ python convert_demo.py --prompt "A girl riding a bike." # Using OpenAI's API
12
+ """
13
+
14
+ import argparse
15
+
16
+ from openai import OpenAI
17
+
18
+
19
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
20
+
21
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
22
+ There are a few rules to follow:
23
+
24
+ You will only ever output a single video description per user request.
25
+
26
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
27
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
28
+
29
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
30
+ """
31
+
32
+
33
+ def convert_prompt(prompt: str, retry_times: int = 3):
34
+ """
35
+ Convert a prompt to a format that can be used by the model for inference
36
+ """
37
+
38
+ client = OpenAI()
39
+ text = prompt.strip()
40
+
41
+ for i in range(retry_times):
42
+ response = client.chat.completions.create(
43
+ messages=[
44
+ {"role": "system", "content": f"{sys_prompt}"},
45
+ {
46
+ "role": "user",
47
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " a girl is on the beach"',
48
+ },
49
+ {
50
+ "role": "assistant",
51
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
52
+ },
53
+ {
54
+ "role": "user",
55
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A man jogging on a football field"',
56
+ },
57
+ {
58
+ "role": "assistant",
59
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
60
+ },
61
+ {
62
+ "role": "user",
63
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
64
+ },
65
+ {
66
+ "role": "assistant",
67
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
68
+ },
69
+ {
70
+ "role": "user",
71
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
72
+ },
73
+ ],
74
+ model="glm-4-0520", # glm-4-0520 and gpt-4o have been tested
75
+ temperature=0.01,
76
+ top_p=0.7,
77
+ stream=False,
78
+ max_tokens=250,
79
+ )
80
+ if response.choices:
81
+ return response.choices[0].message.content
82
+ return prompt
83
+
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
88
+ parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
89
+ args = parser.parse_args()
90
+
91
+ converted_prompt = convert_prompt(args.prompt, args.retry_times)
92
+ print(converted_prompt)
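A minimal sketch of using the converter from Python instead of the CLI; it assumes `OPENAI_API_KEY` (and `OPENAI_BASE_URL`, if needed) are already exported, as described in the docstring.

```python
# Minimal sketch: refine a short prompt programmatically (requires OPENAI_API_KEY in the environment).
from convert_demo import convert_prompt

detailed_prompt = convert_prompt("A girl riding a bike.", retry_times=3)
print(detailed_prompt)
```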
inference/web_demo.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ This script is used to create a Streamlit web application for generating videos using the CogVideoX model.
3
+
4
+ Run the script using Streamlit:
5
+ $ export OPENAI_API_KEY=your OpenAI Key or ZhipuAI Key
6
+ $ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/ # required when using ZhipuAI; omit it when using OpenAI
7
+ $ streamlit run web_demo.py
8
+ """
9
+
10
+ import base64
11
+ import json
12
+ import os
13
+ import time
14
+ from datetime import datetime
15
+ from typing import List
16
+
17
+ import imageio
18
+ import numpy as np
19
+ import streamlit as st
20
+ import torch
21
+ from convert_demo import convert_prompt
22
+ from diffusers import CogVideoXPipeline
23
+
24
+
25
+ model_path: str = "THUDM/CogVideoX-2b"
26
+
27
+
28
+ # Load the model at the start
29
+ @st.cache_resource
30
+ def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
31
+ """
32
+ Load the CogVideoX model.
33
+
34
+ Args:
35
+ - model_path (str): Path to the model.
36
+ - dtype (torch.dtype): Data type for model.
37
+ - device (str): Device to load the model on.
38
+
39
+ Returns:
40
+ - CogVideoXPipeline: Loaded model pipeline.
41
+ """
42
+ return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
43
+
44
+
45
+ # Define a function to generate video based on the provided prompt and model path
46
+ def generate_video(
47
+ pipe: CogVideoXPipeline,
48
+ prompt: str,
49
+ num_inference_steps: int = 50,
50
+ guidance_scale: float = 6.0,
51
+ num_videos_per_prompt: int = 1,
52
+ device: str = "cuda",
53
+ dtype: torch.dtype = torch.float16,
54
+ ) -> List[np.ndarray]:
55
+ """
56
+ Generate a video based on the provided prompt and model path.
57
+
58
+ Args:
59
+ - pipe (CogVideoXPipeline): The pipeline for generating videos.
60
+ - prompt (str): Text prompt for video generation.
61
+ - num_inference_steps (int): Number of inference steps.
62
+ - guidance_scale (float): Guidance scale for generation.
63
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
64
+ - device (str): Device to run the generation on.
65
+ - dtype (torch.dtype): Data type for the model.
66
+
67
+ Returns:
68
+ - List[np.ndarray]: Generated video frames.
69
+ """
70
+ prompt_embeds, _ = pipe.encode_prompt(
71
+ prompt=prompt,
72
+ negative_prompt=None,
73
+ do_classifier_free_guidance=True,
74
+ num_videos_per_prompt=num_videos_per_prompt,
75
+ max_sequence_length=226,
76
+ device=device,
77
+ dtype=dtype,
78
+ )
79
+
80
+ # Generate video
81
+ video = pipe(
82
+ num_inference_steps=num_inference_steps,
83
+ guidance_scale=guidance_scale,
84
+ prompt_embeds=prompt_embeds,
85
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),
86
+ ).frames[0]
87
+ return video
88
+
89
+
90
+ def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
91
+ """
92
+ Save the generated video to a file.
93
+
94
+ Args:
95
+ - video (List[np.ndarray]): Video frames.
96
+ - path (str): Path to save the video.
97
+ - fps (int): Frames per second for the video.
98
+ """
99
+ # Remove the first frame
100
+ video = video[1:]
101
+
102
+ writer = imageio.get_writer(path, fps=fps, codec="libx264")
103
+ for frame in video:
104
+ np_frame = np.array(frame)
105
+ writer.append_data(np_frame)
106
+
107
+ writer.close()
108
+
109
+
110
+ def save_metadata(
111
+ prompt: str,
112
+ converted_prompt: str,
113
+ num_inference_steps: int,
114
+ guidance_scale: float,
115
+ num_videos_per_prompt: int,
116
+ path: str,
117
+ ) -> None:
118
+ """
119
+ Save metadata to a JSON file.
120
+
121
+ Args:
122
+ - prompt (str): Original prompt.
123
+ - converted_prompt (str): Converted prompt.
124
+ - num_inference_steps (int): Number of inference steps.
125
+ - guidance_scale (float): Guidance scale.
126
+ - num_videos_per_prompt (int): Number of videos per prompt.
127
+ - path (str): Path to save the metadata.
128
+ """
129
+ metadata = {
130
+ "prompt": prompt,
131
+ "converted_prompt": converted_prompt,
132
+ "num_inference_steps": num_inference_steps,
133
+ "guidance_scale": guidance_scale,
134
+ "num_videos_per_prompt": num_videos_per_prompt,
135
+ }
136
+ with open(path, "w") as f:
137
+ json.dump(metadata, f, indent=4)
138
+
139
+
140
+ def main() -> None:
141
+ """
142
+ Main function to run the Streamlit web application.
143
+ """
144
+ st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
145
+ st.write("# CogVideoX 🎥")
146
+ dtype: torch.dtype = torch.float16
147
+ device: str = "cuda"
148
+
149
+ global pipe
150
+ pipe = load_model(model_path, dtype, device)
151
+
152
+ with st.sidebar:
153
+ st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️")
154
+ num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
155
+ guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
156
+ num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
157
+
158
+ share_links_container = st.empty()
159
+
160
+ prompt: str = st.chat_input("Prompt")
161
+
162
+ if prompt:
163
+ # Optional but recommended: refine the prompt with an LLM first
164
+ with st.spinner("Refining prompts..."):
165
+ converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
166
+ if converted_prompt is None:
167
+ st.error("Failed to refine the prompt; using the original one.")
168
+
169
+ st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}")
170
+ torch.cuda.empty_cache()
171
+
172
+ with st.spinner("Generating Video..."):
173
+ start_time = time.time()
174
+ video_paths = []
175
+
176
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
177
+ output_dir = f"./output/{timestamp}"
178
+ os.makedirs(output_dir, exist_ok=True)
179
+
180
+ metadata_path = os.path.join(output_dir, "config.json")
181
+ save_metadata(
182
+ prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
183
+ )
184
+
185
+ for i in range(num_videos_per_prompt):
186
+ video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
187
+
188
+ video = generate_video(
189
+ pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
190
+ )
191
+ save_video(video, video_path, fps=8)
192
+ video_paths.append(video_path)
193
+ with open(video_path, "rb") as video_file:
194
+ video_bytes: bytes = video_file.read()
195
+ st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
196
+ torch.cuda.empty_cache()
197
+
198
+ used_time: float = time.time() - start_time
199
+ st.success(f"Videos generated in {used_time:.2f} seconds.")
200
+
201
+ # Create download links in the sidebar
202
+ with share_links_container:
203
+ st.sidebar.write("### Download Links:")
204
+ for video_path in video_paths:
205
+ video_name = os.path.basename(video_path)
206
+ with open(video_path, "rb") as f:
207
+ video_bytes: bytes = f.read()
208
+ b64_video = base64.b64encode(video_bytes).decode()
209
+ href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
210
+ st.sidebar.markdown(href, unsafe_allow_html=True)
211
+
212
+
213
+ if __name__ == "__main__":
214
+ main()
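The helpers above are plain functions, so they can also be reused outside Streamlit. A minimal sketch (single video, hypothetical prompt), assuming the module is importable and a CUDA GPU is available:

```python
# Minimal sketch: reuse load_model / generate_video / save_video without the Streamlit UI.
import torch

from web_demo import load_model, generate_video, save_video

pipe = load_model("THUDM/CogVideoX-2b", torch.float16, "cuda")
frames = generate_video(pipe, "A girl riding a bike.", num_inference_steps=50, guidance_scale=6.0)
save_video(frames, "./output.mp4", fps=8)
```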
pyproject.toml ADDED
@@ -0,0 +1,27 @@
1
+ [tool.ruff]
2
+ line-length = 119
3
+
4
+ [tool.ruff.lint]
5
+ # Never enforce `E501` (line length violations).
6
+ ignore = ["C901", "E501", "E741", "F402", "F823"]
7
+ select = ["C", "E", "F", "I", "W"]
8
+
9
+ # Ignore import violations in all `__init__.py` files.
10
+ [tool.ruff.lint.per-file-ignores]
11
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
12
+
13
+ [tool.ruff.lint.isort]
14
+ lines-after-imports = 2
15
+
16
+ [tool.ruff.format]
17
+ # Like Black, use double quotes for strings.
18
+ quote-style = "double"
19
+
20
+ # Like Black, indent with spaces, rather than tabs.
21
+ indent-style = "space"
22
+
23
+ # Like Black, respect magic trailing commas.
24
+ skip-magic-trailing-comma = false
25
+
26
+ # Like Black, automatically detect the appropriate line ending.
27
+ line-ending = "auto"
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ git+https://github.com/huggingface/diffusers.git@d1c575ad7ee0390c2735f50cc59a79aae666567a#egg=diffusers
2
+ torch==2.4.0
3
+ torchvision==0.19.0
4
+ streamlit==1.37.0
5
+ opencv-python
6
+ imageio-ffmpeg==0.5.1
7
+ openai==1.38.0
8
+ transformers==4.43.4
9
+ accelerate==0.33.0
10
+ sentencepiece==0.2.0
11
+ pillow==9.5.0
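After installing the pinned packages above, a small check (a convenience sketch, not part of the repository) can confirm that the stack imports and that a GPU is visible:

```python
# Minimal sketch: verify the environment installed from requirements.txt.
import torch
import diffusers
import transformers

print("torch:", torch.__version__)
print("diffusers:", diffusers.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```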
resources/CogVideoX.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25ba30aafcd9604178c6d7adbd17f2bf1b251f3d29d1d29498e576075cb67c4e
3
+ size 31028426
resources/WECHAT.md ADDED
@@ -0,0 +1,7 @@
1
+ <div align="center">
2
+ <img src=wechat.jpg width="60%"/>
3
+
4
+ <p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
5
+ <p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
6
+ </div>
7
+
resources/contribute.md ADDED
@@ -0,0 +1,50 @@
1
+ # Contribution Guide
2
+
3
+ There may still be many incomplete aspects in this project.
4
+
5
+ We look forward to your contributions to the repository in the following areas. If you complete any of the work mentioned below
6
+ and are willing to submit a PR and share it with the community, we will, upon review,
7
+ acknowledge your contribution on the project homepage.
8
+
9
+ ## Model Algorithms
10
+
11
+ - Support for model quantization inference (Int4, Int8, etc. quantization engineering)
12
+ - Support for multi-card inference / model inference concurrency engineering
13
+ - Support for non-CUDA architecture inference devices
14
+
15
+ ## Model Engineering / Secondary Development
16
+
17
+ - Model fine-tuning examples / best prompt practices
18
+ - Video super-resolution/frame interpolation for enhancing video generation quality.
19
+ - Any peripheral tools for the model
20
+ - Any minimal complete open-source projects using the CogVideoX open-source model
21
+
22
+ ## Code Standards
23
+
24
+ Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
25
+ style. You can organize the code according to the following specifications:
26
+
27
+ 1. Install the `ruff` tool
28
+
29
+ ```shell
30
+ pip install ruff
31
+ ```
32
+
33
+ Then, run the `ruff` tool
34
+
35
+ ```shell
36
+ ruff check tools sat inference
37
+ ```
38
+
39
+ Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
40
+
41
+ ```shell
42
+ ruff format tools sat inference
43
+ ```
44
+
45
+ Once your code meets the standard, there should be no errors.
46
+
47
+ ## Naming Conventions
48
+ 1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
49
+ 2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
50
+
resources/contribute_zh.md ADDED
@@ -0,0 +1,45 @@
1
+ # 贡献指南
2
+
3
+ 本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
4
+
5
+ ## 模型算法
6
+
7
+ - 模型量化推理支持 (Int4,Int8等量化工程)
8
+ - 模型多卡推理支持 / 模型推理并发工程
9
+ - 非 CUDA 架构 推理设备支持
10
+
11
+ ## 模型工程 / 模型二次开发
12
+
13
+ - 模型微调示例 / 最佳提示词实践
14
+ - 视频超分/插帧,用于美化视频生成效果。
15
+ - 任何模型周边工具
16
+ - 任何使用CogVideoX开源模型制作的最小完整开源项目
17
+
18
+ ## 代码规范
19
+
20
+ 良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
21
+
22
+ 1. 安装`ruff`工具
23
+
24
+ ```shell
25
+ pip install ruff
26
+ ```
27
+
28
+ 接着,运行`ruff`工具
29
+
30
+ ```shell
31
+ ruff check tools sat inference
32
+ ```
33
+
34
+ 检查代码风格,如果有问题,您可以通过`ruff format`命令自动修复。
35
+
36
+ ```shell
37
+ ruff format tools sat inference
38
+ ```
39
+
40
+ 如果您的代码符合规范,应该不会出现任何的错误。
41
+
42
+ ## 命名规范
43
+
44
+ - 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
45
+ - 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
resources/logo.svg ADDED
resources/videos/1.mp4 ADDED
Binary file (636 kB). View file
 
resources/videos/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e738926f262a28b3b9af6573987905457ea82cdcadb0ec04ad9ab134324f5cc
3
+ size 1683616
resources/videos/3.mp4 ADDED
Binary file (746 kB). View file
 
resources/videos/4.mp4 ADDED
Binary file (233 kB). View file
 
resources/web_demo.png ADDED

Git LFS Details

  • SHA256: 0ac281bdbebe756ea9840cd5c13f04aafa3f05c2a16de1f75a45a6f31079e340
  • Pointer size: 132 Bytes
  • Size of remote file: 4.81 MB
resources/wechat.jpg ADDED
sat/README.md ADDED
@@ -0,0 +1,182 @@
1
+ # SAT CogVideoX-2B
2
+
3
+ This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
4
+ fine-tuning code for SAT weights.
5
+
6
+ This code is the framework used by the team to train the model. It has few comments and requires careful study.
7
+
8
+ ## Inference Model
9
+
10
+ 1. Ensure that you have correctly installed the dependencies required by this folder.
11
+
12
+ ```shell
13
+ pip install -r requirements.txt
14
+ ```
15
+
16
+ 2. Download the model weights
17
+
18
+ First, go to the SAT mirror to download the model files.
19
+
20
+ ```shell
21
+ mkdir CogVideoX-2b-sat
22
+ cd CogVideoX-2b-sat
23
+ wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
24
+ mv 'index.html?dl=1' vae.zip
25
+ unzip vae.zip
26
+ wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
27
+ mv 'index.html?dl=1' transformer.zip
28
+ unzip transformer.zip
29
+ ```
30
+
31
+ After unzipping, the model structure should look like this:
32
+
33
+ ```
34
+ .
35
+ ├── transformer
36
+ │ ├── 1000
37
+ │ │ └── mp_rank_00_model_states.pt
38
+ │ └── latest
39
+ └── vae
40
+ └── 3d-vae.pt
41
+ ```
42
+
43
+ Next, clone the T5 model. It is not trained or fine-tuned, but it is required as the text encoder.
44
+
45
+ ```shell
46
+ git lfs install
47
+ git clone https://huggingface.co/google/t5-v1_1-xxl.git
48
+ ```
49
+
50
+ The **tf_model.h5** file is not needed and can be deleted.
51
+
52
+ 3. Modify the file `configs/cogvideox_2b_infer.yaml`.
53
+
54
+ ```yaml
55
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
56
+
57
+ conditioner_config:
58
+ target: sgm.modules.GeneralConditioner
59
+ params:
60
+ emb_models:
61
+ - is_trainable: false
62
+ input_key: txt
63
+ ucg_rate: 0.1
64
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
65
+ params:
66
+ model_dir: "google/t5-v1_1-xxl" ## T5 model path
67
+ max_length: 226
68
+
69
+ first_stage_config:
70
+ target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
71
+ params:
72
+ cp_size: 1
73
+ ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
74
+ ```
75
+
76
+ + If you use a txt file to provide multiple prompts, refer to `configs/test.txt` and put one prompt per line. If
77
+ you are unsure how to write prompts, you can first use [this code](../inference/convert_demo.py) to call an LLM for
78
+ refinement.
79
+ + If using the command line as input, modify
80
+
81
+ ```yaml
82
+ input_type: cli
83
+ ```
84
+
85
+ so that prompts can be entered from the command line.
86
+
87
+ If you want to change the output video directory, you can modify:
88
+
89
+ ```yaml
90
+ output_dir: outputs/
91
+ ```
92
+
93
+ By default, outputs are saved in the `./outputs/` folder.
94
+
95
+ 4. Run the inference code to start inference
96
+
97
+ ```shell
98
+ bash inference.sh
99
+ ```
100
+
101
+ ## Fine-Tuning the Model
102
+
103
+ ### Preparing the Dataset
104
+
105
+ The dataset format should be as follows:
106
+
107
+ ```
108
+ .
109
+ ├── labels
110
+ │   ├── 1.txt
111
+ │   ├── 2.txt
112
+ │   ├── ...
113
+ └── videos
114
+ ├── 1.mp4
115
+ ├── 2.mp4
116
+ ├── ...
117
+ ```
118
+
119
+ Each txt file should have the same name as its corresponding video file and contain the labels for that video. Each
120
+ video should have a one-to-one correspondence with a label. Typically, a video should not have multiple labels.
121
+
122
+ For style fine-tuning, please prepare at least 50 videos and labels with similar styles to facilitate fitting.
123
+
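Before launching fine-tuning, it can help to verify that every video has a matching label file. A minimal sanity-check sketch (the dataset root is a placeholder; this script is not part of the SAT code):

```python
# Minimal sketch: check the labels/ and videos/ pairing described above (placeholder root path).
from pathlib import Path

root = Path("path/to/dataset")
videos = {p.stem for p in (root / "videos").glob("*.mp4")}
labels = {p.stem for p in (root / "labels").glob("*.txt")}

print("videos without labels:", sorted(videos - labels))
print("labels without videos:", sorted(labels - videos))
```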
124
+ ### Modifying the Configuration File
125
+
126
+ We support both `LoRA` and full-parameter fine-tuning. Please note that both methods fine-tune only the `transformer` part; the `VAE` is not modified, and `T5` is used only as an encoder.
127
+
128
+ Modify `configs/cogvideox_2b_sft.yaml` (for full-parameter fine-tuning) as follows.
129
+
130
+ ```yaml
131
+ # checkpoint_activations: True ## using gradient checkpointing (both checkpoint_activations in the configuration file need to be set to True)
132
+ model_parallel_size: 1 # Model parallel size
133
+ experiment_name: lora-disney # Experiment name (do not change)
134
+ mode: finetune # Mode (do not change)
135
+ load: "{your_CogVideoX-2b-sat_path}/transformer" # Transformer model path
136
+ no_load_rng: True # Whether to load the random seed
137
+ train_iters: 1000 # Number of training iterations
138
+ eval_iters: 1 # Number of evaluation iterations
139
+ eval_interval: 100 # Evaluation interval
140
+ eval_batch_size: 1 # Batch size for evaluation
141
+ save: ckpts # Model save path
142
+ save_interval: 100 # Model save interval
143
+ log_interval: 20 # Log output interval
144
+ train_data: [ "your train data path" ]
145
+ valid_data: [ "your val data path" ] # Training and validation sets can be the same
146
+ split: 1,0,0 # Ratio of training, validation, and test sets
147
+ num_workers: 8 # Number of worker threads for data loading
148
+ ```
149
+
150
+ If you wish to use Lora fine-tuning, you also need to modify:
151
+
152
+ ```yaml
153
+ model:
154
+ scale_factor: 1.15258426
155
+ disable_first_stage_autocast: true
156
+ not_trainable_prefixes: [ 'all' ] ## Uncomment
157
+ log_keys:
158
+ - txt
159
+
160
+ lora_config: ## Uncomment
161
+ target: sat.model.finetune.lora2.LoraMixin
162
+ params:
163
+ r: 256
164
+ ```
165
+
166
+ ### Fine-Tuning and Validation
167
+
168
+ 1. Run the fine-tuning script to start fine-tuning.
169
+
170
+ ```shell
171
+ bash finetune.sh
172
+ ```
173
+
174
+ ### Converting to Huggingface Diffusers Supported Weights
175
+
176
+ The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
177
+
178
+ ```shell
179
+ python ../tools/convert_weight_sat2hf.py
180
+ ```
181
+
182
+ **Note**: This conversion has not yet been tested with LoRA fine-tuned models.
sat/README_zh.md ADDED
@@ -0,0 +1,180 @@
1
+ # SAT CogVideoX-2B
2
+
3
+ 本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。
4
+
5
+ 该代码是团队训练模型时使用的框架。注释较少,需要认真研究。
6
+
7
+ ## 推理模型
8
+
9
+ 1. 确保你已经正确安装本文件夹中的要求的依赖
10
+
11
+ ```shell
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ 2. 下载模型权重
16
+
17
+ 首先,前往 SAT 镜像下载依赖。
18
+
19
+ ```shell
20
+ mkdir CogVideoX-2b-sat
21
+ cd CogVideoX-2b-sat
22
+ wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
23
+ mv 'index.html?dl=1' vae.zip
24
+ unzip vae.zip
25
+ wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
26
+ mv 'index.html?dl=1' transformer.zip
27
+ unzip transformer.zip
28
+ ```
29
+
30
+ 然后,解压文件,模型结构应该如下
31
+
32
+ ```
33
+ .
34
+ ├── transformer
35
+ │   ├── 1000
36
+ │   │   └── mp_rank_00_model_states.pt
37
+ │   └── latest
38
+ └── vae
39
+ └── 3d-vae.pt
40
+ ```
41
+
42
+ 接着,克隆 T5 模型,该模型不用做训练和微调,但是必须使用。
43
+
44
+ ```shell
45
+ git lfs install
46
+ git clone https://huggingface.co/google/t5-v1_1-xxl.git
47
+ ```
48
+
49
+ **我们不需要使用tf_model.h5**文件。该文件可以删除。
50
+
51
+ 3. 修改`configs/cogvideox_2b_infer.yaml`中的文件。
52
+
53
+ ```yaml
54
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
55
+
56
+ conditioner_config:
57
+ target: sgm.modules.GeneralConditioner
58
+ params:
59
+ emb_models:
60
+ - is_trainable: false
61
+ input_key: txt
62
+ ucg_rate: 0.1
63
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
64
+ params:
65
+ model_dir: "google/t5-v1_1-xxl" ## T5 模型路径
66
+ max_length: 226
67
+
68
+ first_stage_config:
69
+ target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
70
+ params:
71
+ cp_size: 1
72
+ ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE 模型路径
73
+
74
+ ```
75
+
76
+ + 如果使用 txt 保存多个提示词,请参考`configs/test.txt`
77
+ 进行修改。每一行一个提示词。如果您不知道如何书写提示词,可以先使用[此代码](../inference/convert_demo.py)调用 LLM进行润色。
78
+ + 如果使用命令行作为输入,请修改
79
+
80
+ ```yaml
81
+ input_type: cli
82
+ ```
83
+
84
+ 这样就可以从命令行输入提示词。
85
+
86
+ 如果你希望修改输出视频的地址,你可以修改:
87
+
88
+ ```yaml
89
+ output_dir: outputs/
90
+ ```
91
+
92
+ 默认保存在`./outputs/`文件夹下。
93
+
94
+ 4. 运行推理代码,即可推理
95
+
96
+ ```shell
97
+ bash inference.sh
98
+ ```
99
+
100
+ ## 微调模型
101
+
102
+ ### 准备数据集
103
+
104
+ 数据集格式应该如下:
105
+
106
+ ```
107
+ .
108
+ ├── labels
109
+ │   ├── 1.txt
110
+ │   ├── 2.txt
111
+ │   ├── ...
112
+ └── videos
113
+ ├── 1.mp4
114
+ ├── 2.mp4
115
+ ├── ...
116
+ ```
117
+
118
+ 每个 txt 与视频同名,为视频的标签。视频与标签应该一一对应。通常情况下,不使用一个视频对应多个标签。
119
+
120
+ 如果为风格微调,请准备至少50条风格相似的视频和标签,以利于拟合。
121
+
122
+ ### 修改配置文件
123
+
124
+ 我们支持 `Lora` 和 全参数微调两种方式。请注意,两种微调方式都仅仅对 `transformer` 部分进行微调。不改动 `VAE` 部分。`T5`仅作为
125
+ Encoder 使用。
126
+ 请按照以下方式修改`configs/cogvideox_2b_sft.yaml`(全量微调)中的文件。
127
+
128
+ ```yaml
129
+ # checkpoint_activations: True ## using gradient checkpointing (配置文件中的两个checkpoint_activations都需要设置为True)
130
+ model_parallel_size: 1 # 模型并行大小
131
+ experiment_name: lora-disney # 实验名称(不要改动)
132
+ mode: finetune # 模式(不要改动)
133
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
134
+ no_load_rng: True # 是否加载随机数种子
135
+ train_iters: 1000 # 训练迭代次数
136
+ eval_iters: 1 # 验证迭代次数
137
+ eval_interval: 100 # 验证间隔
138
+ eval_batch_size: 1 # 验证集 batch size
139
+ save: ckpts # 模型保存路径
140
+ save_interval: 100 # 模型保存间隔
141
+ log_interval: 20 # 日志输出间隔
142
+ train_data: [ "your train data path" ]
143
+ valid_data: [ "your val data path" ] # 训练集和验证集可以相同
144
+ split: 1,0,0 # 训练集,验证集,测试集比例
145
+ num_workers: 8 # 数据加载器的工作线程数
146
+ ```
147
+
148
+ 如果你希望使用 Lora 微调,你还需要修改:
149
+
150
+ ```yaml
151
+ model:
152
+ scale_factor: 1.15258426
153
+ disable_first_stage_autocast: true
154
+ not_trainable_prefixes: [ 'all' ] ## 解除注释
155
+ log_keys:
156
+ - txt
157
+
158
+ lora_config: ## 解除注释
159
+ target: sat.model.finetune.lora2.LoraMixin
160
+ params:
161
+ r: 256
162
+ ```
163
+
164
+ ### 微调和验证
165
+
166
+ 1. 运行微调脚本,即可开始微调。
167
+
168
+ ```shell
169
+ bash finetune.sh
170
+ ```
171
+
172
+ ### 转换到 Huggingface Diffusers 库支持的权重
173
+
174
+ SAT 权重格式与 Huggingface 的权重格式不同,需要转换。请运行
175
+
176
+ ```shell
177
+ python ../tools/convert_weight_sat2hf.py
178
+ ```
179
+
180
+ **注意** 本内容暂未测试 LORA 微调模型。
sat/arguments.py ADDED
@@ -0,0 +1,281 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import json
5
+ import warnings
6
+ import omegaconf
7
+ from omegaconf import OmegaConf
8
+ from sat.helpers import print_rank0
9
+ from sat import mpu
10
+ from sat.arguments import set_random_seed
11
+ from sat.arguments import add_training_args, add_evaluation_args, add_data_args
12
+ import torch.distributed
13
+
14
+
15
+ def add_model_config_args(parser):
16
+ """Model arguments"""
17
+
18
+ group = parser.add_argument_group("model", "model configuration")
19
+ group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
20
+ group.add_argument(
21
+ "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
22
+ )
23
+ group.add_argument("--force-pretrain", action="store_true")
24
+ group.add_argument("--device", type=int, default=-1)
25
+ group.add_argument("--debug", action="store_true")
26
+ group.add_argument("--log-image", type=bool, default=True)
27
+
28
+ return parser
29
+
30
+
31
+ def add_sampling_config_args(parser):
32
+ """Sampling configurations"""
33
+
34
+ group = parser.add_argument_group("sampling", "Sampling Configurations")
35
+ group.add_argument("--output-dir", type=str, default="samples")
36
+ group.add_argument("--input-dir", type=str, default=None)
37
+ group.add_argument("--input-type", type=str, default="cli")
38
+ group.add_argument("--input-file", type=str, default="input.txt")
39
+ group.add_argument("--final-size", type=int, default=2048)
40
+ group.add_argument("--sdedit", action="store_true")
41
+ group.add_argument("--grid-num-rows", type=int, default=1)
42
+ group.add_argument("--force-inference", action="store_true")
43
+ group.add_argument("--lcm_steps", type=int, default=None)
44
+ group.add_argument("--sampling-num-frames", type=int, default=32)
45
+ group.add_argument("--sampling-fps", type=int, default=8)
46
+ group.add_argument("--only-save-latents", type=bool, default=False)
47
+ group.add_argument("--only-log-video-latents", type=bool, default=False)
48
+ group.add_argument("--latent-channels", type=int, default=32)
49
+ group.add_argument("--image2video", action="store_true")
50
+
51
+ return parser
52
+
53
+
54
+ def get_args(args_list=None, parser=None):
55
+ """Parse all the args."""
56
+ if parser is None:
57
+ parser = argparse.ArgumentParser(description="sat")
58
+ else:
59
+ assert isinstance(parser, argparse.ArgumentParser)
60
+ parser = add_model_config_args(parser)
61
+ parser = add_sampling_config_args(parser)
62
+ parser = add_training_args(parser)
63
+ parser = add_evaluation_args(parser)
64
+ parser = add_data_args(parser)
65
+
66
+ import deepspeed
67
+
68
+ parser = deepspeed.add_config_arguments(parser)
69
+
70
+ args = parser.parse_args(args_list)
71
+ args = process_config_to_args(args)
72
+
73
+ if not args.train_data:
74
+ print_rank0("No training data specified", level="WARNING")
75
+
76
+ assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
77
+ if args.train_iters is None and args.epochs is None:
78
+ args.train_iters = 10000 # default 10k iters
79
+ print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
80
+
81
+ args.cuda = torch.cuda.is_available()
82
+
83
+ args.rank = int(os.getenv("RANK", "0"))
84
+ args.world_size = int(os.getenv("WORLD_SIZE", "1"))
85
+ if args.local_rank is None:
86
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
87
+
88
+ if args.device == -1:
89
+ if torch.cuda.device_count() == 0:
90
+ args.device = "cpu"
91
+ elif args.local_rank is not None:
92
+ args.device = args.local_rank
93
+ else:
94
+ args.device = args.rank % torch.cuda.device_count()
95
+
96
+ if args.local_rank != args.device and args.mode != "inference":
97
+ raise ValueError(
98
+ "LOCAL_RANK (default 0) and args.device inconsistent. "
99
+ "This can only happens in inference mode. "
100
+ "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
101
+ )
102
+
103
+ if args.rank == 0:
104
+ print_rank0("using world size: {}".format(args.world_size))
105
+
106
+ if args.train_data_weights is not None:
107
+ assert len(args.train_data_weights) == len(args.train_data)
108
+
109
+ if args.mode != "inference": # training with deepspeed
110
+ args.deepspeed = True
111
+ if args.deepspeed_config is None: # not specified
112
+ deepspeed_config_path = os.path.join(
113
+ os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
114
+ )
115
+ with open(deepspeed_config_path) as file:
116
+ args.deepspeed_config = json.load(file)
117
+ override_deepspeed_config = True
118
+ else:
119
+ override_deepspeed_config = False
120
+
121
+ assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
122
+
123
+ if args.zero_stage > 0 and not args.fp16 and not args.bf16:
124
+ print_rank0("Automatically set fp16=True to use ZeRO.")
125
+ args.fp16 = True
126
+ args.bf16 = False
127
+
128
+ if args.deepspeed:
129
+ if args.checkpoint_activations:
130
+ args.deepspeed_activation_checkpointing = True
131
+ else:
132
+ args.deepspeed_activation_checkpointing = False
133
+ if args.deepspeed_config is not None:
134
+ deepspeed_config = args.deepspeed_config
135
+
136
+ if override_deepspeed_config: # not specify deepspeed_config, use args
137
+ if args.fp16:
138
+ deepspeed_config["fp16"]["enabled"] = True
139
+ elif args.bf16:
140
+ deepspeed_config["bf16"]["enabled"] = True
141
+ deepspeed_config["fp16"]["enabled"] = False
142
+ else:
143
+ deepspeed_config["fp16"]["enabled"] = False
144
+ deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
145
+ deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
146
+ optimizer_params_config = deepspeed_config["optimizer"]["params"]
147
+ optimizer_params_config["lr"] = args.lr
148
+ optimizer_params_config["weight_decay"] = args.weight_decay
149
+ else: # override args with values in deepspeed_config
150
+ if args.rank == 0:
151
+ print_rank0("Will override arguments with manually specified deepspeed_config!")
152
+ if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
153
+ args.fp16 = True
154
+ else:
155
+ args.fp16 = False
156
+ if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
157
+ args.bf16 = True
158
+ else:
159
+ args.bf16 = False
160
+ if "train_micro_batch_size_per_gpu" in deepspeed_config:
161
+ args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
162
+ if "gradient_accumulation_steps" in deepspeed_config:
163
+ args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
164
+ else:
165
+ args.gradient_accumulation_steps = None
166
+ if "optimizer" in deepspeed_config:
167
+ optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
168
+ args.lr = optimizer_params_config.get("lr", args.lr)
169
+ args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
170
+ args.deepspeed_config = deepspeed_config
171
+
172
+ # initialize distributed and random seed because it always seems to be necessary.
173
+ initialize_distributed(args)
174
+ args.seed = args.seed + mpu.get_data_parallel_rank()
175
+ set_random_seed(args.seed)
176
+ return args
177
+
178
+
179
+ def initialize_distributed(args):
180
+ """Initialize torch.distributed."""
181
+ if torch.distributed.is_initialized():
182
+ if mpu.model_parallel_is_initialized():
183
+ if args.model_parallel_size != mpu.get_model_parallel_world_size():
184
+ raise ValueError(
185
+ "model_parallel_size is inconsistent with prior configuration."
186
+ "We currently do not support changing model_parallel_size."
187
+ )
188
+ return False
189
+ else:
190
+ if args.model_parallel_size > 1:
191
+ warnings.warn(
192
+ "model_parallel_size > 1 but torch.distributed is not initialized via SAT."
193
+ "Please carefully make sure the correctness on your own."
194
+ )
195
+ mpu.initialize_model_parallel(args.model_parallel_size)
196
+ return True
197
+ # the automatic assignment of devices has been moved to arguments.py
198
+ if args.device == "cpu":
199
+ pass
200
+ else:
201
+ torch.cuda.set_device(args.device)
202
+ # Call the init process
203
+ init_method = "tcp://"
204
+ args.master_ip = os.getenv("MASTER_ADDR", "localhost")
205
+
206
+ if args.world_size == 1:
207
+ from sat.helpers import get_free_port
208
+
209
+ default_master_port = str(get_free_port())
210
+ else:
211
+ default_master_port = "6000"
212
+ args.master_port = os.getenv("MASTER_PORT", default_master_port)
213
+ init_method += args.master_ip + ":" + args.master_port
214
+ torch.distributed.init_process_group(
215
+ backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
216
+ )
217
+
218
+ # Set the model-parallel / data-parallel communicators.
219
+ mpu.initialize_model_parallel(args.model_parallel_size)
220
+
221
+ # Set vae context parallel group equal to model parallel group
222
+ from sgm.util import set_context_parallel_group, initialize_context_parallel
223
+
224
+ if args.model_parallel_size <= 2:
225
+ set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
226
+ else:
227
+ initialize_context_parallel(2)
228
+ # mpu.initialize_model_parallel(1)
229
+ # Optional DeepSpeed Activation Checkpointing Features
230
+ if args.deepspeed:
231
+ import deepspeed
232
+
233
+ deepspeed.init_distributed(
234
+ dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
235
+ )
236
+ # # It seems that it has no negative influence to configure it even without using checkpointing.
237
+ # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
238
+ else:
239
+ # in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
240
+ try:
241
+ import deepspeed
242
+ from deepspeed.runtime.activation_checkpointing.checkpointing import (
243
+ _CUDA_RNG_STATE_TRACKER,
244
+ _MODEL_PARALLEL_RNG_TRACKER_NAME,
245
+ )
246
+
247
+ _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
248
+ except Exception as e:
249
+ from sat.helpers import print_rank0
250
+
251
+ print_rank0(str(e), level="DEBUG")
252
+
253
+ return True
254
+
255
+
256
+ def process_config_to_args(args):
257
+ """Fetch args from only --base"""
258
+
259
+ configs = [OmegaConf.load(cfg) for cfg in args.base]
260
+ config = OmegaConf.merge(*configs)
261
+
262
+ args_config = config.pop("args", OmegaConf.create())
263
+ for key in args_config:
264
+ if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
265
+ arg = OmegaConf.to_object(args_config[key])
266
+ else:
267
+ arg = args_config[key]
268
+ if hasattr(args, key):
269
+ setattr(args, key, arg)
270
+
271
+ if "model" in config:
272
+ model_config = config.pop("model", OmegaConf.create())
273
+ args.model_config = model_config
274
+ if "deepspeed" in config:
275
+ deepspeed_config = config.pop("deepspeed", OmegaConf.create())
276
+ args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
277
+ if "data" in config:
278
+ data_config = config.pop("data", OmegaConf.create())
279
+ args.data_config = data_config
280
+
281
+ return args
sat/configs/cogvideox_2b_infer.yaml ADDED
@@ -0,0 +1,166 @@
1
+ args:
2
+ latent_channels: 16
3
+ mode: inference
4
+ load: "CogVideoX-2b-sat/transformer"
5
+ batch_size: 1
6
+ input_type: txt
7
+ input_file: test.txt
8
+ sampling_num_frames: 13 # Must be 13, 11 or 9
9
+ sampling_fps: 8
10
+ fp16: True
11
+ output_dir: outputs/
12
+ force_inference: True
13
+
14
+ model:
15
+ scale_factor: 1.15258426
16
+ disable_first_stage_autocast: true
17
+ log_keys:
18
+ - txt
19
+
20
+ denoiser_config:
21
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
22
+ params:
23
+ num_idx: 1000
24
+ quantize_c_noise: False
25
+
26
+ weighting_config:
27
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
28
+ scaling_config:
29
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
30
+ discretization_config:
31
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
32
+ params:
33
+ shift_scale: 3.0
34
+
35
+ network_config:
36
+ target: dit_video_concat.DiffusionTransformer
37
+ params:
38
+ time_embed_dim: 512
39
+ elementwise_affine: True
40
+ num_frames: 49
41
+ time_compressed_rate: 4
42
+ latent_width: 90
43
+ latent_height: 60
44
+ num_layers: 30
45
+ patch_size: 2
46
+ in_channels: 16
47
+ out_channels: 16
48
+ hidden_size: 1920
49
+ adm_in_channels: 256
50
+ num_attention_heads: 30
51
+
52
+ transformer_args:
53
+ vocab_size: 1
54
+ max_sequence_length: 64
55
+ layernorm_order: pre
56
+ skip_init: false
57
+ model_parallel_size: 1
58
+ is_decoder: false
59
+
60
+ modules:
61
+ pos_embed_config:
62
+ target: dit_video_concat.Basic3DPositionEmbeddingMixin
63
+ params:
64
+ text_length: 226
65
+ height_interpolation: 1.875
66
+ width_interpolation: 1.875
67
+
68
+ patch_embed_config:
69
+ target: dit_video_concat.ImagePatchEmbeddingMixin
70
+ params:
71
+ text_hidden_size: 4096
72
+
73
+ adaln_layer_config:
74
+ target: dit_video_concat.AdaLNMixin
75
+ params:
76
+ qk_ln: True
77
+
78
+ final_layer_config:
79
+ target: dit_video_concat.FinalLayerMixin
80
+
81
+ conditioner_config:
82
+ target: sgm.modules.GeneralConditioner
83
+ params:
84
+ emb_models:
85
+ - is_trainable: false
86
+ input_key: txt
87
+ ucg_rate: 0.1
88
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
89
+ params:
90
+ model_dir: "google/t5-v1_1-xxl"
91
+ max_length: 226
92
+
93
+ first_stage_config:
94
+ target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
95
+ params:
96
+ cp_size: 1
97
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
98
+ ignore_keys: [ 'loss' ]
99
+
100
+ loss_config:
101
+ target: torch.nn.Identity
102
+
103
+ regularizer_config:
104
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
105
+
106
+ encoder_config:
107
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
108
+ params:
109
+ double_z: true
110
+ z_channels: 16
111
+ resolution: 256
112
+ in_channels: 3
113
+ out_ch: 3
114
+ ch: 128
115
+ ch_mult: [ 1, 2, 2, 4 ]
116
+ attn_resolutions: [ ]
117
+ num_res_blocks: 3
118
+ dropout: 0.0
119
+ gather_norm: True
120
+
121
+ decoder_config:
122
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
123
+ params:
124
+ double_z: True
125
+ z_channels: 16
126
+ resolution: 256
127
+ in_channels: 3
128
+ out_ch: 3
129
+ ch: 128
130
+ ch_mult: [ 1, 2, 2, 4 ]
131
+ attn_resolutions: [ ]
132
+ num_res_blocks: 3
133
+ dropout: 0.0
134
+ gather_norm: false
135
+
136
+ loss_fn_config:
137
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
138
+ params:
139
+ offset_noise_level: 0
140
+ sigma_sampler_config:
141
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
142
+ params:
143
+ uniform_sampling: True
144
+ num_idx: 1000
145
+ discretization_config:
146
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
147
+ params:
148
+ shift_scale: 3.0
149
+
150
+ sampler_config:
151
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
152
+ params:
153
+ num_steps: 50
154
+ verbose: True
155
+
156
+ discretization_config:
157
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
158
+ params:
159
+ shift_scale: 3.0
160
+
161
+ guider_config:
162
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
163
+ params:
164
+ scale: 6
165
+ exp: 5
166
+ num_steps: 50
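Every component in the inference config above is declared as a `target` (dotted class path) plus `params` block and resolved at runtime by `sgm.util.instantiate_from_config`. A minimal sketch of reading the file and resolving one block, assuming it is run from the repository root with the `args:` / `model:` top-level layout shown above; the standalone `build` helper is a simplified stand-in for the sgm utility, not the repository's own entry point:

import importlib
from omegaconf import OmegaConf

def build(section):
    # Resolve a {target: "module.Class", params: {...}} block into an instance.
    module_name, cls_name = section["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**section.get("params", {}))

cfg = OmegaConf.load("sat/configs/cogvideox_2b_infer.yaml")
print(cfg.model.scale_factor)                # 1.15258426
print(cfg.model.sampler_config.target)       # sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
# sampler = build(cfg.model.sampler_config)  # only works with the sgm package importable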
sat/configs/cogvideox_2b_sft.yaml ADDED
@@ -0,0 +1,225 @@
1
+ args:
2
+ checkpoint_activations: True ## using gradient checkpointing
3
+ model_parallel_size: 1
4
+ experiment_name: lora-disney
5
+ mode: finetune
6
+ load: "CogVideoX-2b-sat/transformer"
7
+ no_load_rng: True
8
+ train_iters: 1000
9
+ eval_iters: 1
10
+ eval_interval: 100
11
+ eval_batch_size: 1
12
+ save: ckpts
13
+ save_interval: 100
14
+ log_interval: 20
15
+ train_data: ["disney"]
16
+ valid_data: ["disney"]
17
+ split: 1,0,0
18
+ num_workers: 8
19
+ force_train: True
20
+ only_log_video_latents: True
21
+
22
+ data:
23
+ target: data_video.SFTDataset
24
+ params:
25
+ video_size: [480, 720]
26
+ fps: 8
27
+ max_num_frames: 49
28
+ skip_frms_num: 3.
29
+
30
+ deepspeed:
31
+ train_micro_batch_size_per_gpu: 1
32
+ gradient_accumulation_steps: 1
33
+ steps_per_print: 50
34
+ gradient_clipping: 0.1
35
+ zero_optimization:
36
+ stage: 2
37
+ cpu_offload: false
38
+ contiguous_gradients: false
39
+ overlap_comm: true
40
+ reduce_scatter: true
41
+ reduce_bucket_size: 1000000000
42
+ allgather_bucket_size: 1000000000
43
+ load_from_fp32_weights: false
44
+ zero_allow_untested_optimizer: true
45
+ bf16:
46
+ enabled: False
47
+ fp16:
48
+ enabled: True
49
+ loss_scale: 0
50
+ loss_scale_window: 400
51
+ hysteresis: 2
52
+ min_loss_scale: 1
53
+ optimizer:
54
+ type: sat.ops.FusedEmaAdam
55
+ params:
56
+ lr: 0.0002
57
+ betas: [0.9, 0.95]
58
+ eps: 1e-8
59
+ weight_decay: 1e-4
60
+ activation_checkpointing:
61
+ partition_activations: false
62
+ contiguous_memory_optimization: false
63
+ wall_clock_breakdown: false
64
+
65
+
66
+ model:
67
+ scale_factor: 1.15258426
68
+ disable_first_stage_autocast: true
69
+ not_trainable_prefixes: ['all'] ## Using Lora
70
+ log_keys:
71
+ - txt
72
+
73
+ denoiser_config:
74
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
75
+ params:
76
+ num_idx: 1000
77
+ quantize_c_noise: False
78
+
79
+ weighting_config:
80
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
81
+ scaling_config:
82
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
83
+ discretization_config:
84
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
85
+ params:
86
+ shift_scale: 3.0
87
+
88
+ network_config:
89
+ target: dit_video_concat.DiffusionTransformer
90
+ params:
91
+ time_embed_dim: 512
92
+ elementwise_affine: True
93
+ num_frames: 49
94
+ time_compressed_rate: 4
95
+ latent_width: 90
96
+ latent_height: 60
97
+ num_layers: 30
98
+ patch_size: 2
99
+ in_channels: 16
100
+ out_channels: 16
101
+ hidden_size: 1920
102
+ adm_in_channels: 256
103
+ num_attention_heads: 30
104
+
105
+ transformer_args:
106
+ checkpoint_activations: True ## using gradient checkpointing
107
+ vocab_size: 1
108
+ max_sequence_length: 64
109
+ layernorm_order: pre
110
+ skip_init: false
111
+ model_parallel_size: 1
112
+ is_decoder: false
113
+
114
+ modules:
115
+ pos_embed_config:
116
+ target: dit_video_concat.Basic3DPositionEmbeddingMixin
117
+ params:
118
+ text_length: 226
119
+ height_interpolation: 1.875
120
+ width_interpolation: 1.875
121
+
122
+ lora_config: ## Using Lora
123
+ target: sat.model.finetune.lora2.LoraMixin
124
+ params:
125
+ r: 128
126
+
127
+ patch_embed_config:
128
+ target: dit_video_concat.ImagePatchEmbeddingMixin
129
+ params:
130
+ text_hidden_size: 4096
131
+
132
+ adaln_layer_config:
133
+ target: dit_video_concat.AdaLNMixin
134
+ params:
135
+ qk_ln: True
136
+
137
+ final_layer_config:
138
+ target: dit_video_concat.FinalLayerMixin
139
+
140
+ conditioner_config:
141
+ target: sgm.modules.GeneralConditioner
142
+ params:
143
+ emb_models:
144
+ - is_trainable: false
145
+ input_key: txt
146
+ ucg_rate: 0.1
147
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
148
+ params:
149
+ model_dir: "google/t5-v1_1-xxl"
150
+ max_length: 226
151
+
152
+ first_stage_config:
153
+ target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
154
+ params:
155
+ cp_size: 1
156
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
157
+ ignore_keys: [ 'loss' ]
158
+
159
+ loss_config:
160
+ target: torch.nn.Identity
161
+
162
+ regularizer_config:
163
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
164
+
165
+ encoder_config:
166
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
167
+ params:
168
+ double_z: true
169
+ z_channels: 16
170
+ resolution: 256
171
+ in_channels: 3
172
+ out_ch: 3
173
+ ch: 128
174
+ ch_mult: [ 1, 2, 2, 4 ]
175
+ attn_resolutions: [ ]
176
+ num_res_blocks: 3
177
+ dropout: 0.0
178
+ gather_norm: True
179
+
180
+ decoder_config:
181
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
182
+ params:
183
+ double_z: True
184
+ z_channels: 16
185
+ resolution: 256
186
+ in_channels: 3
187
+ out_ch: 3
188
+ ch: 128
189
+ ch_mult: [ 1, 2, 2, 4 ]
190
+ attn_resolutions: [ ]
191
+ num_res_blocks: 3
192
+ dropout: 0.0
193
+ gather_norm: false
194
+
195
+ loss_fn_config:
196
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
197
+ params:
198
+ offset_noise_level: 0
199
+ sigma_sampler_config:
200
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
201
+ params:
202
+ uniform_sampling: True
203
+ num_idx: 1000
204
+ discretization_config:
205
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
206
+ params:
207
+ shift_scale: 3.0
208
+
209
+ sampler_config:
210
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
211
+ params:
212
+ num_steps: 50
213
+ verbose: True
214
+
215
+ discretization_config:
216
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
217
+ params:
218
+ shift_scale: 3.0
219
+
220
+ guider_config:
221
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
222
+ params:
223
+ scale: 6
224
+ exp: 5
225
+ num_steps: 50
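This SFT config freezes the whole base model (`not_trainable_prefixes: ['all']`) and injects rank-128 LoRA adapters through `sat.model.finetune.lora2.LoraMixin`; only parameters whose names contain `matrix_A` / `matrix_B` stay trainable (see `disable_untrainable_params` in `diffusion_video.py` below). A generic sketch of the low-rank-adapter idea and of the parameter count that `r: 128` implies for one 1920-wide projection; it illustrates the concept only and is not the `LoraMixin` implementation:

import torch
from torch import nn

class LoRALinear(nn.Module):
    # y = W x + B(A x): a frozen base weight plus a trainable rank-r update.
    def __init__(self, base: nn.Linear, r: int = 128):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                      # mirrors not_trainable_prefixes: ['all']
        self.matrix_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.matrix_B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x):
        return self.base(x) + x @ self.matrix_A.T @ self.matrix_B.T

layer = LoRALinear(nn.Linear(1920, 1920), r=128)         # hidden_size: 1920, r: 128 as above
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)                                         # 2 * 128 * 1920 = 491,520 per adapted projection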
sat/configs/test.txt ADDED
@@ -0,0 +1,3 @@
1
+ In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
2
+ The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
3
+ A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
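`configs/test.txt` is simply one text-to-video prompt per line. A generic way to consume it, independent of the argument handling in the repository's own sampling script:

with open("sat/configs/test.txt", "r", encoding="utf-8") as f:
    prompts = [line.strip() for line in f if line.strip()]
for i, prompt in enumerate(prompts):
    print(f"[{i}] {prompt[:80]}...")  # hand each prompt to the text-to-video sampler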
sat/data_video.py ADDED
@@ -0,0 +1,451 @@
1
+ import io
2
+ import os
3
+ import sys
4
+ from functools import partial
5
+ import math
6
+ import torchvision.transforms as TT
7
+ from sgm.webds import MetaDistributedWebDataset
8
+ import random
9
+ from fractions import Fraction
10
+ from typing import Union, Optional, Dict, Any, Tuple
11
+ from torchvision.io.video import av
12
+ import numpy as np
13
+ import torch
14
+ from torchvision.io import _video_opt
15
+ from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
16
+ from torchvision.transforms.functional import center_crop, resize
17
+ from torchvision.transforms import InterpolationMode
18
+ import decord
19
+ from decord import VideoReader
20
+ from torch.utils.data import Dataset
21
+
22
+
23
+ def read_video(
24
+ filename: str,
25
+ start_pts: Union[float, Fraction] = 0,
26
+ end_pts: Optional[Union[float, Fraction]] = None,
27
+ pts_unit: str = "pts",
28
+ output_format: str = "THWC",
29
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
30
+ """
31
+ Reads a video from a file, returning both the video frames and the audio frames
32
+
33
+ Args:
34
+ filename (str): path to the video file
35
+ start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
36
+ The start presentation time of the video
37
+ end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
38
+ The end presentation time
39
+ pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
40
+ either 'pts' or 'sec'. Defaults to 'pts'.
41
+ output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
42
+
43
+ Returns:
44
+ vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
45
+ aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
46
+ info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
47
+ """
48
+
49
+ output_format = output_format.upper()
50
+ if output_format not in ("THWC", "TCHW"):
51
+ raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
52
+
53
+ _check_av_available()
54
+
55
+ if end_pts is None:
56
+ end_pts = float("inf")
57
+
58
+ if end_pts < start_pts:
59
+ raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
60
+
61
+ info = {}
62
+ audio_frames = []
63
+ audio_timebase = _video_opt.default_timebase
64
+
65
+ with av.open(filename, metadata_errors="ignore") as container:
66
+ if container.streams.audio:
67
+ audio_timebase = container.streams.audio[0].time_base
68
+ if container.streams.video:
69
+ video_frames = _read_from_stream(
70
+ container,
71
+ start_pts,
72
+ end_pts,
73
+ pts_unit,
74
+ container.streams.video[0],
75
+ {"video": 0},
76
+ )
77
+ video_fps = container.streams.video[0].average_rate
78
+ # guard against potentially corrupted files
79
+ if video_fps is not None:
80
+ info["video_fps"] = float(video_fps)
81
+
82
+ if container.streams.audio:
83
+ audio_frames = _read_from_stream(
84
+ container,
85
+ start_pts,
86
+ end_pts,
87
+ pts_unit,
88
+ container.streams.audio[0],
89
+ {"audio": 0},
90
+ )
91
+ info["audio_fps"] = container.streams.audio[0].rate
92
+
93
+ aframes_list = [frame.to_ndarray() for frame in audio_frames]
94
+
95
+ vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
96
+
97
+ if aframes_list:
98
+ aframes = np.concatenate(aframes_list, 1)
99
+ aframes = torch.as_tensor(aframes)
100
+ if pts_unit == "sec":
101
+ start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
102
+ if end_pts != float("inf"):
103
+ end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
104
+ aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
105
+ else:
106
+ aframes = torch.empty((1, 0), dtype=torch.float32)
107
+
108
+ if output_format == "TCHW":
109
+ # [T,H,W,C] --> [T,C,H,W]
110
+ vframes = vframes.permute(0, 3, 1, 2)
111
+
112
+ return vframes, aframes, info
113
+
114
+
115
+ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
116
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
117
+ arr = resize(
118
+ arr,
119
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
120
+ interpolation=InterpolationMode.BICUBIC,
121
+ )
122
+ else:
123
+ arr = resize(
124
+ arr,
125
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
126
+ interpolation=InterpolationMode.BICUBIC,
127
+ )
128
+
129
+ h, w = arr.shape[2], arr.shape[3]
130
+ arr = arr.squeeze(0)
131
+
132
+ delta_h = h - image_size[0]
133
+ delta_w = w - image_size[1]
134
+
135
+ if reshape_mode == "random" or reshape_mode == "none":
136
+ top = np.random.randint(0, delta_h + 1)
137
+ left = np.random.randint(0, delta_w + 1)
138
+ elif reshape_mode == "center":
139
+ top, left = delta_h // 2, delta_w // 2
140
+ else:
141
+ raise NotImplementedError
142
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
143
+ return arr
144
+
145
+
146
+ def pad_last_frame(tensor, num_frames):
147
+ # T, H, W, C
148
+ if tensor.shape[0] < num_frames:
149
+ last_frame = tensor[-int(num_frames - tensor.shape[0]) :]  # dim 0 is T; repeat trailing frames up to num_frames
150
+ padded_tensor = torch.cat([tensor, last_frame], dim=0)
151
+ return padded_tensor
152
+ else:
153
+ return tensor[:num_frames]
154
+
155
+
156
+ def load_video(
157
+ video_data,
158
+ sampling="uniform",
159
+ duration=None,
160
+ num_frames=4,
161
+ wanted_fps=None,
162
+ actual_fps=None,
163
+ skip_frms_num=0.0,
164
+ nb_read_frames=None,
165
+ ):
166
+ decord.bridge.set_bridge("torch")
167
+ vr = VideoReader(uri=video_data, height=-1, width=-1)
168
+ if nb_read_frames is not None:
169
+ ori_vlen = nb_read_frames
170
+ else:
171
+ ori_vlen = min(int(duration * actual_fps) - 1, len(vr))
172
+
173
+ max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
174
+ start = random.randint(skip_frms_num, max_seek + 1)
175
+ end = int(start + num_frames / wanted_fps * actual_fps)
176
+ n_frms = num_frames
177
+
178
+ if sampling == "uniform":
179
+ indices = np.arange(start, end, (end - start) / n_frms).astype(int)
180
+ else:
181
+ raise NotImplementedError
182
+
183
+ # get_batch -> T, H, W, C
184
+ temp_frms = vr.get_batch(np.arange(start, end))
185
+ assert temp_frms is not None
186
+ tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
187
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
188
+
189
+ return pad_last_frame(tensor_frms, num_frames)
190
+
191
+
192
+ import threading
193
+
194
+
195
+ def load_video_with_timeout(*args, **kwargs):
196
+ video_container = {}
197
+
198
+ def target_function():
199
+ video = load_video(*args, **kwargs)
200
+ video_container["video"] = video
201
+
202
+ thread = threading.Thread(target=target_function)
203
+ thread.start()
204
+ timeout = 20
205
+ thread.join(timeout)
206
+
207
+ if thread.is_alive():
208
+ print("Loading video timed out")
209
+ raise TimeoutError
210
+ return video_container.get("video", None).contiguous()
211
+
212
+
213
+ def process_video(
214
+ video_path,
215
+ image_size=None,
216
+ duration=None,
217
+ num_frames=4,
218
+ wanted_fps=None,
219
+ actual_fps=None,
220
+ skip_frms_num=0.0,
221
+ nb_read_frames=None,
222
+ ):
223
+ """
224
+ video_path: str or io.BytesIO
225
+ image_size: .
226
+ duration: known duration of the clip, used to speed up loading by seeking straight to the sampled start. TODO: bypass if unknown.
227
+ num_frames: wanted num_frames.
228
+ wanted_fps: .
229
+ skip_frms_num: ignore the first and last xx frames to avoid scene transitions.
230
+ """
231
+
232
+ video = load_video_with_timeout(
233
+ video_path,
234
+ duration=duration,
235
+ num_frames=num_frames,
236
+ wanted_fps=wanted_fps,
237
+ actual_fps=actual_fps,
238
+ skip_frms_num=skip_frms_num,
239
+ nb_read_frames=nb_read_frames,
240
+ )
241
+
242
+ # --- copy and modify the image process ---
243
+ video = video.permute(0, 3, 1, 2) # [T, C, H, W]
244
+
245
+ # resize
246
+ if image_size is not None:
247
+ video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")
248
+
249
+ return video
250
+
251
+
252
+ def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
253
+ while True:
254
+ r = next(src)
255
+ if "mp4" in r:
256
+ video_data = r["mp4"]
257
+ elif "avi" in r:
258
+ video_data = r["avi"]
259
+ else:
260
+ print("No video data found")
261
+ continue
262
+
263
+ if txt_key not in r:
264
+ txt = ""
265
+ else:
266
+ txt = r[txt_key]
267
+
268
+ if isinstance(txt, bytes):
269
+ txt = txt.decode("utf-8")
270
+ else:
271
+ txt = str(txt)
272
+
273
+ duration = r.get("duration", None)
274
+ if duration is not None:
275
+ duration = float(duration)
276
+ else:
277
+ continue
278
+
279
+ actual_fps = r.get("fps", None)
280
+ if actual_fps is not None:
281
+ actual_fps = float(actual_fps)
282
+ else:
283
+ continue
284
+
285
+ required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
286
+ required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
287
+
288
+ if duration is not None and duration < required_duration:
289
+ continue
290
+
291
+ try:
292
+ frames = process_video(
293
+ io.BytesIO(video_data),
294
+ num_frames=num_frames,
295
+ wanted_fps=fps,
296
+ image_size=image_size,
297
+ duration=duration,
298
+ actual_fps=actual_fps,
299
+ skip_frms_num=skip_frms_num,
300
+ )
301
+ frames = (frames - 127.5) / 127.5
302
+ except Exception as e:
303
+ print(e)
304
+ continue
305
+
306
+ item = {
307
+ "mp4": frames,
308
+ "txt": txt,
309
+ "num_frames": num_frames,
310
+ "fps": fps,
311
+ }
312
+
313
+ yield item
314
+
315
+
316
+ class VideoDataset(MetaDistributedWebDataset):
317
+ def __init__(
318
+ self,
319
+ path,
320
+ image_size,
321
+ num_frames,
322
+ fps,
323
+ skip_frms_num=0.0,
324
+ nshards=sys.maxsize,
325
+ seed=1,
326
+ meta_names=None,
327
+ shuffle_buffer=1000,
328
+ include_dirs=None,
329
+ txt_key="caption",
330
+ **kwargs,
331
+ ):
332
+ if seed == -1:
333
+ seed = random.randint(0, 1000000)
334
+ if meta_names is None:
335
+ meta_names = []
336
+
337
+ if path.startswith(";"):
338
+ path, include_dirs = path.split(";", 1)
339
+ super().__init__(
340
+ path,
341
+ partial(
342
+ process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
343
+ ),
344
+ seed,
345
+ meta_names=meta_names,
346
+ shuffle_buffer=shuffle_buffer,
347
+ nshards=nshards,
348
+ include_dirs=include_dirs,
349
+ )
350
+
351
+ @classmethod
352
+ def create_dataset_function(cls, path, args, **kwargs):
353
+ return cls(path, **kwargs)
354
+
355
+
356
+ class SFTDataset(Dataset):
357
+ def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
358
+ """
359
+ skip_frms_num: ignore the first and last xx frames to avoid scene transitions.
360
+ """
361
+ super(SFTDataset, self).__init__()
362
+
363
+ self.videos_list = []
364
+ self.captions_list = []
365
+ self.num_frames_list = []
366
+ self.fps_list = []
367
+
368
+ decord.bridge.set_bridge("torch")
369
+ for root, dirnames, filenames in os.walk(data_dir):
370
+ for filename in filenames:
371
+ if filename.endswith(".mp4"):
372
+ video_path = os.path.join(root, filename)
373
+ vr = VideoReader(uri=video_path, height=-1, width=-1)
374
+ actual_fps = vr.get_avg_fps()
375
+ ori_vlen = len(vr)
376
+
377
+ if ori_vlen / actual_fps * fps > max_num_frames:
378
+ num_frames = max_num_frames
379
+ start = int(skip_frms_num)
380
+ end = int(start + num_frames / fps * actual_fps)
381
+ indices = np.arange(start, end, (end - start) / num_frames).astype(int)
382
+ temp_frms = vr.get_batch(np.arange(start, end))
383
+ assert temp_frms is not None
384
+ tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
385
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
386
+ else:
387
+ if ori_vlen > max_num_frames:
388
+ num_frames = max_num_frames
389
+ start = int(skip_frms_num)
390
+ end = int(ori_vlen - skip_frms_num)
391
+ indices = np.arange(start, end, (end - start) / num_frames).astype(int)
392
+ temp_frms = vr.get_batch(np.arange(start, end))
393
+ assert temp_frms is not None
394
+ tensor_frms = (
395
+ torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
396
+ )
397
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
398
+ else:
399
+
400
+ def nearest_smaller_4k_plus_1(n):
401
+ remainder = n % 4
402
+ if remainder == 0:
403
+ return n - 3
404
+ else:
405
+ return n - remainder + 1
406
+
407
+ start = int(skip_frms_num)
408
+ end = int(ori_vlen - skip_frms_num)
409
+ num_frames = nearest_smaller_4k_plus_1(
410
+ end - start
411
+ ) # 3D VAE requires the number of frames to be 4k+1
412
+ end = int(start + num_frames)
413
+ temp_frms = vr.get_batch(np.arange(start, end))
414
+ assert temp_frms is not None
415
+ tensor_frms = (
416
+ torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
417
+ )
418
+
419
+ tensor_frms = pad_last_frame(
420
+ tensor_frms, num_frames
421
+ ) # the len of indices may be less than num_frames, due to round error
422
+ tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
423
+ tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
424
+ tensor_frms = (tensor_frms - 127.5) / 127.5
425
+ self.videos_list.append(tensor_frms)
426
+
427
+ # caption
428
+ caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
429
+ if os.path.exists(caption_path):
430
+ caption = open(caption_path, "r").read().splitlines()[0]
431
+ else:
432
+ caption = ""
433
+ self.captions_list.append(caption)
434
+ self.num_frames_list.append(num_frames)
435
+ self.fps_list.append(fps)
436
+
437
+ def __getitem__(self, index):
438
+ item = {
439
+ "mp4": self.videos_list[index],
440
+ "txt": self.captions_list[index],
441
+ "num_frames": self.num_frames_list[index],
442
+ "fps": self.fps_list[index],
443
+ }
444
+ return item
445
+
446
+ def __len__(self):
447
+ return len(self.fps_list)
448
+
449
+ @classmethod
450
+ def create_dataset_function(cls, path, args, **kwargs):
451
+ return cls(data_dir=path, **kwargs)
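`SFTDataset` above walks `data_dir` for `.mp4` files, decodes them with decord, keeps at most `max_num_frames` frames at the requested fps (with a 4k+1 frame count, as the 3D VAE expects), center-crops to `video_size`, normalizes pixels to [-1, 1], and reads a caption from a `.txt` file located by applying `videos` -> `labels` and `.mp4` -> `.txt` to the file name. A minimal usage sketch with the values from `cogvideox_2b_sft.yaml`; the `disney` directory is a hypothetical example path:

from torch.utils.data import DataLoader
from data_video import SFTDataset   # this file; needs the sat/ dependencies (sgm, decord) importable

dataset = SFTDataset(
    data_dir="disney",              # hypothetical folder of .mp4 clips plus caption .txt files
    video_size=[480, 720],
    fps=8,
    max_num_frames=49,
    skip_frms_num=3,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["mp4"].shape)           # [1, T, 3, 480, 720] with T <= 49 and T = 4k + 1
print(batch["txt"][0])              # caption string (empty if no matching .txt was found)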
sat/diffusion_video.py ADDED
@@ -0,0 +1,318 @@
1
+ import math
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, List, Tuple, Union, Optional
4
+ from omegaconf import ListConfig, OmegaConf
5
+ from copy import deepcopy
6
+ import torch.nn.functional as F
7
+
8
+ from sat.helpers import print_rank0
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sgm.modules import UNCONDITIONAL_CONFIG
13
+ from sgm.modules.autoencoding.temporal_ae import VideoDecoder
14
+ from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
15
+ from sgm.util import (
16
+ default,
17
+ disabled_train,
18
+ get_obj_from_str,
19
+ instantiate_from_config,
20
+ log_txt_as_img,
21
+ )
22
+ import gc
23
+ from sat import mpu
24
+ import random
25
+
26
+
27
+ class SATVideoDiffusionEngine(nn.Module):
28
+ def __init__(self, args, **kwargs):
29
+ super().__init__()
30
+
31
+ model_config = args.model_config
32
+ # model args preprocess
33
+ log_keys = model_config.get("log_keys", None)
34
+ input_key = model_config.get("input_key", "mp4")
35
+ network_config = model_config.get("network_config", None)
36
+ network_wrapper = model_config.get("network_wrapper", None)
37
+ denoiser_config = model_config.get("denoiser_config", None)
38
+ sampler_config = model_config.get("sampler_config", None)
39
+ conditioner_config = model_config.get("conditioner_config", None)
40
+ first_stage_config = model_config.get("first_stage_config", None)
41
+ loss_fn_config = model_config.get("loss_fn_config", None)
42
+ scale_factor = model_config.get("scale_factor", 1.0)
43
+ latent_input = model_config.get("latent_input", False)
44
+ disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
45
+ no_cond_log = model_config.get("no_cond_log", False)
46
+ not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
47
+ compile_model = model_config.get("compile_model", False)
48
+ en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
49
+ lr_scale = model_config.get("lr_scale", None)
50
+ lora_train = model_config.get("lora_train", False)
51
+ self.use_pd = model_config.get("use_pd", False) # progressive distillation
52
+
53
+ self.log_keys = log_keys
54
+ self.input_key = input_key
55
+ self.not_trainable_prefixes = not_trainable_prefixes
56
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
57
+ self.lr_scale = lr_scale
58
+ self.lora_train = lora_train
59
+ self.noised_image_input = model_config.get("noised_image_input", False)
60
+ self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
61
+ self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
62
+ if args.fp16:
63
+ dtype = torch.float16
64
+ dtype_str = "fp16"
65
+ elif args.bf16:
66
+ dtype = torch.bfloat16
67
+ dtype_str = "bf16"
68
+ else:
69
+ dtype = torch.float32
70
+ dtype_str = "fp32"
71
+ self.dtype = dtype
72
+ self.dtype_str = dtype_str
73
+
74
+ network_config["params"]["dtype"] = dtype_str
75
+ model = instantiate_from_config(network_config)
76
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
77
+ model, compile_model=compile_model, dtype=dtype
78
+ )
79
+
80
+ self.denoiser = instantiate_from_config(denoiser_config)
81
+ self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
82
+ self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
83
+
84
+ self._init_first_stage(first_stage_config)
85
+
86
+ self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
87
+
88
+ self.latent_input = latent_input
89
+ self.scale_factor = scale_factor
90
+ self.disable_first_stage_autocast = disable_first_stage_autocast
91
+ self.no_cond_log = no_cond_log
92
+ self.device = args.device
93
+
94
+ def disable_untrainable_params(self):
95
+ total_trainable = 0
96
+ for n, p in self.named_parameters():
97
+ if not p.requires_grad:
98
+ continue
99
+ flag = False
100
+ for prefix in self.not_trainable_prefixes:
101
+ if n.startswith(prefix) or prefix == "all":
102
+ flag = True
103
+ break
104
+
105
+ lora_prefix = ["matrix_A", "matrix_B"]
106
+ for prefix in lora_prefix:
107
+ if prefix in n:
108
+ flag = False
109
+ break
110
+
111
+ if flag:
112
+ p.requires_grad_(False)
113
+ else:
114
+ total_trainable += p.numel()
115
+
116
+ print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
117
+
118
+ def reinit(self, parent_model=None):
119
+ # reload the initial params from previous trained modules
120
+ # you can also get access to other mixins through parent_model.get_mixin().
121
+ pass
122
+
123
+ def _init_first_stage(self, config):
124
+ model = instantiate_from_config(config).eval()
125
+ model.train = disabled_train
126
+ for param in model.parameters():
127
+ param.requires_grad = False
128
+ self.first_stage_model = model
129
+
130
+ def forward(self, x, batch):
131
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
132
+ loss_mean = loss.mean()
133
+ loss_dict = {"loss": loss_mean}
134
+ return loss_mean, loss_dict
135
+
136
+ def shared_step(self, batch: Dict) -> Any:
137
+ x = self.get_input(batch)
138
+ if self.lr_scale is not None:
139
+ lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
140
+ lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
141
+ lr_z = self.encode_first_stage(lr_x, batch)
142
+ batch["lr_input"] = lr_z
143
+
144
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
145
+ x = self.encode_first_stage(x, batch)
146
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
147
+
148
+ gc.collect()
149
+ torch.cuda.empty_cache()
150
+ loss, loss_dict = self(x, batch)
151
+ return loss, loss_dict
152
+
153
+ def get_input(self, batch):
154
+ return batch[self.input_key].to(self.dtype)
155
+
156
+ @torch.no_grad()
157
+ def decode_first_stage(self, z):
158
+ z = 1.0 / self.scale_factor * z
159
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
160
+ n_rounds = math.ceil(z.shape[0] / n_samples)
161
+ all_out = []
162
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
163
+ for n in range(n_rounds):
164
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
165
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
166
+ else:
167
+ kwargs = {}
168
+ use_cp = False
169
+ out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
170
+ all_out.append(out)
171
+ out = torch.cat(all_out, dim=0)
172
+ return out
173
+
174
+ @torch.no_grad()
175
+ def encode_first_stage(self, x, batch):
176
+ frame = x.shape[2]
177
+
178
+ if frame > 1 and self.latent_input:
179
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
180
+ return x * self.scale_factor # already encoded
181
+
182
+ use_cp = False
183
+
184
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
185
+ n_rounds = math.ceil(x.shape[0] / n_samples)
186
+ all_out = []
187
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
188
+ for n in range(n_rounds):
189
+ out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
190
+ all_out.append(out)
191
+ z = torch.cat(all_out, dim=0)
192
+ z = self.scale_factor * z
193
+ return z
194
+
195
+ @torch.no_grad()
196
+ def sample(
197
+ self,
198
+ cond: Dict,
199
+ uc: Union[Dict, None] = None,
200
+ batch_size: int = 16,
201
+ shape: Union[None, Tuple, List] = None,
202
+ prefix=None,
203
+ concat_images=None,
204
+ **kwargs,
205
+ ):
206
+ randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
207
+ if hasattr(self, "seeded_noise"):
208
+ randn = self.seeded_noise(randn)
209
+
210
+ if prefix is not None:
211
+ randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
212
+
213
+ # broadcast noise
214
+ mp_size = mpu.get_model_parallel_world_size()
215
+ if mp_size > 1:
216
+ global_rank = torch.distributed.get_rank() // mp_size
217
+ src = global_rank * mp_size
218
+ torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
219
+
220
+ scale = None
221
+ scale_emb = None
222
+
223
+ denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser(
224
+ self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
225
+ )
226
+
227
+ samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
228
+ samples = samples.to(self.dtype)
229
+ return samples
230
+
231
+ @torch.no_grad()
232
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
233
+ """
234
+ Defines heuristics to log different conditionings.
235
+ These can be lists of strings (text-to-image), tensors, ints, ...
236
+ """
237
+ image_h, image_w = batch[self.input_key].shape[3:]
238
+ log = dict()
239
+
240
+ for embedder in self.conditioner.embedders:
241
+ if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
242
+ x = batch[embedder.input_key][:n]
243
+ if isinstance(x, torch.Tensor):
244
+ if x.dim() == 1:
245
+ # class-conditional, convert integer to string
246
+ x = [str(x[i].item()) for i in range(x.shape[0])]
247
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
248
+ elif x.dim() == 2:
249
+ # size and crop cond and the like
250
+ x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
251
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
252
+ else:
253
+ raise NotImplementedError()
254
+ elif isinstance(x, (List, ListConfig)):
255
+ if isinstance(x[0], str):
256
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
257
+ else:
258
+ raise NotImplementedError()
259
+ else:
260
+ raise NotImplementedError()
261
+ log[embedder.input_key] = xc
262
+ return log
263
+
264
+ @torch.no_grad()
265
+ def log_video(
266
+ self,
267
+ batch: Dict,
268
+ N: int = 8,
269
+ ucg_keys: List[str] = None,
270
+ only_log_video_latents=False,
271
+ **kwargs,
272
+ ) -> Dict:
273
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
274
+ if ucg_keys:
275
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
276
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
277
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
278
+ )
279
+ else:
280
+ ucg_keys = conditioner_input_keys
281
+ log = dict()
282
+
283
+ x = self.get_input(batch)
284
+
285
+ c, uc = self.conditioner.get_unconditional_conditioning(
286
+ batch,
287
+ force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
288
+ )
289
+
290
+ sampling_kwargs = {}
291
+
292
+ N = min(x.shape[0], N)
293
+ x = x.to(self.device)[:N]
294
+ if not self.latent_input:
295
+ log["inputs"] = x.to(torch.float32)
296
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
297
+ z = self.encode_first_stage(x, batch)
298
+ if not only_log_video_latents:
299
+ log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
300
+ log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
301
+ z = z.permute(0, 2, 1, 3, 4).contiguous()
302
+
303
+ log.update(self.log_conditionings(batch, N))
304
+
305
+ for k in c:
306
+ if isinstance(c[k], torch.Tensor):
307
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
308
+
309
+ samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
310
+ samples = samples.permute(0, 2, 1, 3, 4).contiguous()
311
+ if only_log_video_latents:
312
+ latents = 1.0 / self.scale_factor * samples
313
+ log["latents"] = latents
314
+ else:
315
+ samples = self.decode_first_stage(samples).to(torch.float32)
316
+ samples = samples.permute(0, 2, 1, 3, 4).contiguous()
317
+ log["samples"] = samples
318
+ return log
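In `log_video` above, `sample()` receives the latent shape as `z.shape[1:]` in the `b t c h w` layout. That geometry can be reconstructed from the 2b configs as a quick consistency check, assuming the 8x spatial compression implied by `ch_mult: [1, 2, 2, 4]` and the `(T - 1) / 4 + 1` temporal compression implied by `time_compressed_rate: 4` (a sanity check, not a specification):

num_frames, time_compressed_rate = 49, 4           # from network_config above
height, width, spatial_compression = 480, 720, 8   # training resolution; three 2x downsamples
latent_c, patch_size, text_length = 16, 2, 226

t_latent = (num_frames - 1) // time_compressed_rate + 1   # 13 compressed frames
h_latent = height // spatial_compression                  # 60 -> latent_height
w_latent = width // spatial_compression                   # 90 -> latent_width
print((t_latent, latent_c, h_latent, w_latent))           # (13, 16, 60, 90): the shape handed to sample()

vision_tokens = t_latent * (h_latent // patch_size) * (w_latent // patch_size)
print(vision_tokens, text_length + vision_tokens)         # 17550 image tokens, 17776 with text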
sat/dit_video_concat.py ADDED
@@ -0,0 +1,858 @@
1
+ from functools import partial
2
+ from einops import rearrange, repeat
3
+ import numpy as np
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from sat.model.base_model import BaseModel, non_conflict
10
+ from sat.model.mixins import BaseMixin
11
+ from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
12
+ from sat.mpu.layers import ColumnParallelLinear
13
+ from sgm.util import instantiate_from_config
14
+
15
+ from sgm.modules.diffusionmodules.openaimodel import Timestep
16
+ from sgm.modules.diffusionmodules.util import (
17
+ linear,
18
+ timestep_embedding,
19
+ )
20
+ from sat.ops.layernorm import LayerNorm, RMSNorm
21
+
22
+
23
+ class ImagePatchEmbeddingMixin(BaseMixin):
24
+ def __init__(
25
+ self,
26
+ in_channels,
27
+ hidden_size,
28
+ patch_size,
29
+ bias=True,
30
+ text_hidden_size=None,
31
+ ):
32
+ super().__init__()
33
+ self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
34
+ if text_hidden_size is not None:
35
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
36
+ else:
37
+ self.text_proj = None
38
+
39
+ def word_embedding_forward(self, input_ids, **kwargs):
40
+ # now is 3d patch
41
+ images = kwargs["images"] # (b,t,c,h,w)
42
+ B, T = images.shape[:2]
43
+ emb = images.view(-1, *images.shape[2:])
44
+ emb = self.proj(emb) # ((b t),d,h/2,w/2)
45
+ emb = emb.view(B, T, *emb.shape[1:])
46
+ emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
47
+ emb = rearrange(emb, "b t n d -> b (t n) d")
48
+
49
+ if self.text_proj is not None:
50
+ text_emb = self.text_proj(kwargs["encoder_outputs"])
51
+ emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
52
+
53
+ emb = emb.contiguous()
54
+ return emb # (b,n_t+t*n_i,d)
55
+
56
+ def reinit(self, parent_model=None):
57
+ w = self.proj.weight.data
58
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
59
+ nn.init.constant_(self.proj.bias, 0)
60
+ del self.transformer.word_embeddings
61
+
62
+
63
+ def get_3d_sincos_pos_embed(
64
+ embed_dim,
65
+ grid_height,
66
+ grid_width,
67
+ t_size,
68
+ cls_token=False,
69
+ height_interpolation=1.0,
70
+ width_interpolation=1.0,
71
+ time_interpolation=1.0,
72
+ ):
73
+ """
74
+ grid_height, grid_width: ints giving the spatial grid size
75
+ t_size: int of the temporal size
76
+ return:
77
+ pos_embed: [t_size, grid_height*grid_width, embed_dim] (the cls_token flag is currently unused)
78
+ """
79
+ assert embed_dim % 4 == 0
80
+ embed_dim_spatial = embed_dim // 4 * 3
81
+ embed_dim_temporal = embed_dim // 4
82
+
83
+ # spatial
84
+ grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
85
+ grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
86
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
87
+ grid = np.stack(grid, axis=0)
88
+
89
+ grid = grid.reshape([2, 1, grid_height, grid_width])
90
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
91
+
92
+ # temporal
93
+ grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
94
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
95
+
96
+ # concatenate in [T, H, W] order
97
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
98
+ pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
99
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
100
+ pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
101
+
102
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
103
+ # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
104
+
105
+ return pos_embed # [T, H*W, D]
106
+
107
+
108
+ def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
109
+ """
110
+ grid_height, grid_width: ints giving the grid height and width
111
+ return:
112
+ pos_embed: [grid_height*grid_width, embed_dim] or [extra_tokens+grid_height*grid_width, embed_dim] (w/ or w/o cls_token)
113
+ """
114
+ grid_h = np.arange(grid_height, dtype=np.float32)
115
+ grid_w = np.arange(grid_width, dtype=np.float32)
116
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
117
+ grid = np.stack(grid, axis=0)
118
+
119
+ grid = grid.reshape([2, 1, grid_height, grid_width])
120
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
121
+ if cls_token and extra_tokens > 0:
122
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
123
+ return pos_embed
124
+
125
+
126
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
127
+ assert embed_dim % 2 == 0
128
+
129
+ # use half of dimensions to encode grid_h
130
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
131
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
132
+
133
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
134
+ return emb
135
+
136
+
137
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
138
+ """
139
+ embed_dim: output dimension for each position
140
+ pos: a list of positions to be encoded: size (M,)
141
+ out: (M, D)
142
+ """
143
+ assert embed_dim % 2 == 0
144
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
145
+ omega /= embed_dim / 2.0
146
+ omega = 1.0 / 10000**omega # (D/2,)
147
+
148
+ pos = pos.reshape(-1) # (M,)
149
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
150
+
151
+ emb_sin = np.sin(out) # (M, D/2)
152
+ emb_cos = np.cos(out) # (M, D/2)
153
+
154
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
155
+ return emb
156
+
157
+
158
+ class Basic3DPositionEmbeddingMixin(BaseMixin):
159
+ def __init__(
160
+ self,
161
+ height,
162
+ width,
163
+ compressed_num_frames,
164
+ hidden_size,
165
+ text_length=0,
166
+ height_interpolation=1.0,
167
+ width_interpolation=1.0,
168
+ time_interpolation=1.0,
169
+ ):
170
+ super().__init__()
171
+ self.height = height
172
+ self.width = width
173
+ self.text_length = text_length
174
+ self.compressed_num_frames = compressed_num_frames
175
+ self.spatial_length = height * width
176
+ self.num_patches = height * width * compressed_num_frames
177
+ self.pos_embedding = nn.Parameter(
178
+ torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
179
+ )
180
+ self.height_interpolation = height_interpolation
181
+ self.width_interpolation = width_interpolation
182
+ self.time_interpolation = time_interpolation
183
+
184
+ def position_embedding_forward(self, position_ids, **kwargs):
185
+ if kwargs["images"].shape[1] == 1:
186
+ return self.pos_embedding[:, : self.text_length + self.spatial_length]
187
+
188
+ return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
189
+
190
+ def reinit(self, parent_model=None):
191
+ del self.transformer.position_embeddings
192
+ pos_embed = get_3d_sincos_pos_embed(
193
+ self.pos_embedding.shape[-1],
194
+ self.height,
195
+ self.width,
196
+ self.compressed_num_frames,
197
+ height_interpolation=self.height_interpolation,
198
+ width_interpolation=self.width_interpolation,
199
+ time_interpolation=self.time_interpolation,
200
+ )
201
+ pos_embed = torch.from_numpy(pos_embed).float()
202
+ pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
203
+ self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
204
+
205
+
206
+ def broadcat(tensors, dim=-1):
207
+ num_tensors = len(tensors)
208
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
209
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
210
+ shape_len = list(shape_lens)[0]
211
+ dim = (dim + shape_len) if dim < 0 else dim
212
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
213
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
214
+ assert all(
215
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
216
+ ), "invalid dimensions for broadcastable concatenation"
217
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
218
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
219
+ expanded_dims.insert(dim, (dim, dims[dim]))
220
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
221
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
222
+ return torch.cat(tensors, dim=dim)
223
+
224
+
225
+ def rotate_half(x):
226
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
227
+ x1, x2 = x.unbind(dim=-1)
228
+ x = torch.stack((-x2, x1), dim=-1)
229
+ return rearrange(x, "... d r -> ... (d r)")
230
+
231
+
232
+ class Rotary3DPositionEmbeddingMixin(BaseMixin):
233
+ def __init__(
234
+ self,
235
+ height,
236
+ width,
237
+ compressed_num_frames,
238
+ hidden_size,
239
+ hidden_size_head,
240
+ text_length,
241
+ theta=10000,
242
+ rot_v=False,
243
+ pnp=False,
244
+ learnable_pos_embed=False,
245
+ ):
246
+ super().__init__()
247
+ self.rot_v = rot_v
248
+
249
+ dim_t = hidden_size_head // 4
250
+ dim_h = hidden_size_head // 8 * 3
251
+ dim_w = hidden_size_head // 8 * 3
252
+
253
+ # 'lang':
254
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
255
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
256
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
257
+
258
+ grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
259
+ grid_h = torch.arange(height, dtype=torch.float32)
260
+ grid_w = torch.arange(width, dtype=torch.float32)
261
+
262
+ freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
263
+ freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
264
+ freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
265
+
266
+ freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
267
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
268
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
269
+
270
+ freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
271
+ # (T H W D)
272
+
273
+ self.pnp = pnp
274
+
275
+ if not self.pnp:
276
+ freqs = rearrange(freqs, "t h w d -> (t h w) d")
277
+
278
+ freqs = freqs.contiguous()
279
+ freqs_sin = freqs.sin()
280
+ freqs_cos = freqs.cos()
281
+ self.register_buffer("freqs_sin", freqs_sin)
282
+ self.register_buffer("freqs_cos", freqs_cos)
283
+
284
+ self.text_length = text_length
285
+ if learnable_pos_embed:
286
+ num_patches = height * width * compressed_num_frames + text_length
287
+ self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
288
+ else:
289
+ self.pos_embedding = None
290
+
291
+ def rotary(self, t, **kwargs):
292
+ if self.pnp:
293
+ t_coords = kwargs["rope_position_ids"][:, :, 0]
294
+ x_coords = kwargs["rope_position_ids"][:, :, 1]
295
+ y_coords = kwargs["rope_position_ids"][:, :, 2]
296
+ mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
297
+ freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
298
+ freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
299
+
300
+ else:
301
+
302
+ def reshape_freq(freqs):
303
+ frame = t.shape[2]
304
+ freqs = freqs[:frame].contiguous()
305
+ freqs = freqs.unsqueeze(0).unsqueeze(0)
306
+ return freqs
307
+
308
+ freqs_cos = reshape_freq(self.freqs_cos)
309
+ freqs_sin = reshape_freq(self.freqs_sin)
310
+
311
+ return t * freqs_cos + rotate_half(t) * freqs_sin
312
+
313
+ def position_embedding_forward(self, position_ids, **kwargs):
314
+ if self.pos_embedding is not None:
315
+ return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
316
+ else:
317
+ return None
318
+
319
+ def attention_fn(
320
+ self,
321
+ query_layer,
322
+ key_layer,
323
+ value_layer,
324
+ attention_mask,
325
+ attention_dropout=None,
326
+ log_attention_weights=None,
327
+ scaling_attention_score=True,
328
+ **kwargs,
329
+ ):
330
+ attention_fn_default = HOOKS_DEFAULT["attention_fn"]
331
+
332
+ if self.pnp:
333
+ query_layer = self.rotary(query_layer, **kwargs)
334
+ key_layer = self.rotary(key_layer, **kwargs)
335
+ if self.rot_v:
336
+ value_layer = self.rotary(value_layer)
337
+ else:
338
+ query_layer = torch.cat(
339
+ (
340
+ query_layer[
341
+ :,
342
+ :,
343
+ : kwargs["text_length"],
344
+ ],
345
+ self.rotary(
346
+ query_layer[
347
+ :,
348
+ :,
349
+ kwargs["text_length"] :,
350
+ ]
351
+ ),
352
+ ),
353
+ dim=2,
354
+ )
355
+ key_layer = torch.cat(
356
+ (
357
+ key_layer[
358
+ :,
359
+ :,
360
+ : kwargs["text_length"],
361
+ ],
362
+ self.rotary(
363
+ key_layer[
364
+ :,
365
+ :,
366
+ kwargs["text_length"] :,
367
+ ]
368
+ ),
369
+ ),
370
+ dim=2,
371
+ )
372
+ if self.rot_v:
373
+ value_layer = torch.cat(
374
+ (
375
+ value_layer[
376
+ :,
377
+ :,
378
+ : kwargs["text_length"],
379
+ ],
380
+ self.rotary(
381
+ value_layer[
382
+ :,
383
+ :,
384
+ kwargs["text_length"] :,
385
+ ]
386
+ ),
387
+ ),
388
+ dim=2,
389
+ )
390
+
391
+ return attention_fn_default(
392
+ query_layer,
393
+ key_layer,
394
+ value_layer,
395
+ attention_mask,
396
+ attention_dropout=attention_dropout,
397
+ log_attention_weights=log_attention_weights,
398
+ scaling_attention_score=scaling_attention_score,
399
+ **kwargs,
400
+ )
401
+
402
+
403
+ def modulate(x, shift, scale):
404
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
405
+
406
+
407
+ def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
408
+ """
409
+ x: (N, T * S, patch_size**2 * C), where S = (H/p) * (W/p) tokens per latent frame
410
+ imgs: (N, T, C, H, W)
411
+ """
412
+ if rope_position_ids is not None:
413
+ raise NotImplementedError  # pix2struct-style unpatchify is not supported for video latents
414
+ # do pix2struct unpatchify
415
+ L = x.shape[1]
416
+ x = x.reshape(shape=(x.shape[0], L, p, p, c))
417
+ x = torch.einsum("nlpqc->ncplq", x)
418
+ imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
419
+ else:
420
+ b = x.shape[0]
421
+ imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
422
+
423
+ return imgs
424
+
425
+
426
+ class FinalLayerMixin(BaseMixin):
427
+ def __init__(
428
+ self,
429
+ hidden_size,
430
+ time_embed_dim,
431
+ patch_size,
432
+ out_channels,
433
+ latent_width,
434
+ latent_height,
435
+ elementwise_affine,
436
+ ):
437
+ super().__init__()
438
+ self.hidden_size = hidden_size
439
+ self.patch_size = patch_size
440
+ self.out_channels = out_channels
441
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
442
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
443
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
444
+
445
+ self.spatial_length = latent_width * latent_height // patch_size**2
446
+ self.latent_width = latent_width
447
+ self.latent_height = latent_height
448
+
449
+ def final_forward(self, logits, **kwargs):
450
+ x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
451
+
452
+ shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
453
+ x = modulate(self.norm_final(x), shift, scale)
454
+ x = self.linear(x)
455
+
456
+ return unpatchify(
457
+ x,
458
+ c=self.out_channels,
459
+ p=self.patch_size,
460
+ w=self.latent_width // self.patch_size,
461
+ h=self.latent_height // self.patch_size,
462
+ rope_position_ids=kwargs.get("rope_position_ids", None),
463
+ **kwargs,
464
+ )
465
+
466
+ def reinit(self, parent_model=None):
467
+ nn.init.xavier_uniform_(self.linear.weight)
468
+ nn.init.constant_(self.linear.bias, 0)
469
+
470
+
471
+ class SwiGLUMixin(BaseMixin):
472
+ def __init__(self, num_layers, in_features, hidden_features, bias=False):
473
+ super().__init__()
474
+ self.w2 = nn.ModuleList(
475
+ [
476
+ ColumnParallelLinear(
477
+ in_features,
478
+ hidden_features,
479
+ gather_output=False,
480
+ bias=bias,
481
+ module=self,
482
+ name="dense_h_to_4h_gate",
483
+ )
484
+ for i in range(num_layers)
485
+ ]
486
+ )
487
+
488
+ def mlp_forward(self, hidden_states, **kw_args):
489
+ x = hidden_states
490
+ origin = self.transformer.layers[kw_args["layer_id"]].mlp
491
+ x1 = origin.dense_h_to_4h(x)
492
+ x2 = self.w2[kw_args["layer_id"]](x)
493
+ hidden = origin.activation_func(x2) * x1
494
+ x = origin.dense_4h_to_h(hidden)
495
+ return x
496
+
497
+
498
+ class AdaLNMixin(BaseMixin):
499
+ def __init__(
500
+ self,
501
+ width,
502
+ height,
503
+ hidden_size,
504
+ num_layers,
505
+ time_embed_dim,
506
+ compressed_num_frames,
507
+ qk_ln=True,
508
+ hidden_size_head=None,
509
+ elementwise_affine=True,
510
+ ):
511
+ super().__init__()
512
+ self.num_layers = num_layers
513
+ self.width = width
514
+ self.height = height
515
+ self.compressed_num_frames = compressed_num_frames
516
+
517
+ self.adaLN_modulations = nn.ModuleList(
518
+ [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
519
+ )
520
+
521
+ self.qk_ln = qk_ln
522
+ if qk_ln:
523
+ self.query_layernorm_list = nn.ModuleList(
524
+ [
525
+ LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
526
+ for _ in range(num_layers)
527
+ ]
528
+ )
529
+ self.key_layernorm_list = nn.ModuleList(
530
+ [
531
+ LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
532
+ for _ in range(num_layers)
533
+ ]
534
+ )
535
+
536
+ def layer_forward(
537
+ self,
538
+ hidden_states,
539
+ mask,
540
+ *args,
541
+ **kwargs,
542
+ ):
543
+ text_length = kwargs["text_length"]
544
+ # hidden_states (b,(n_t+t*n_i),d)
545
+ text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
546
+ img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
547
+ layer = self.transformer.layers[kwargs["layer_id"]]
548
+ adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
549
+
550
+ (
551
+ shift_msa,
552
+ scale_msa,
553
+ gate_msa,
554
+ shift_mlp,
555
+ scale_mlp,
556
+ gate_mlp,
557
+ text_shift_msa,
558
+ text_scale_msa,
559
+ text_gate_msa,
560
+ text_shift_mlp,
561
+ text_scale_mlp,
562
+ text_gate_mlp,
563
+ ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
564
+ gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
565
+ gate_msa.unsqueeze(1),
566
+ gate_mlp.unsqueeze(1),
567
+ text_gate_msa.unsqueeze(1),
568
+ text_gate_mlp.unsqueeze(1),
569
+ )
570
+
571
+ # self full attention (b,(t n),d)
572
+ img_attention_input = layer.input_layernorm(img_hidden_states)
573
+ text_attention_input = layer.input_layernorm(text_hidden_states)
574
+ img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
575
+ text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
576
+
577
+ attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
578
+ attention_output = layer.attention(attention_input, mask, **kwargs)
579
+ text_attention_output = attention_output[:, :text_length] # (b,n,d)
580
+ img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
581
+
582
+ if self.transformer.layernorm_order == "sandwich":
583
+ text_attention_output = layer.third_layernorm(text_attention_output)
584
+ img_attention_output = layer.third_layernorm(img_attention_output)
585
+ img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
586
+ text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
587
+
588
+ # mlp (b,(t n),d)
589
+ img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
590
+ text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
591
+ img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
592
+ text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
593
+ mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1)  # (b,(n_t+t*n_i),d)
594
+ mlp_output = layer.mlp(mlp_input, **kwargs)
595
+ img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
596
+ text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
597
+ if self.transformer.layernorm_order == "sandwich":
598
+ text_mlp_output = layer.fourth_layernorm(text_mlp_output)
599
+ img_mlp_output = layer.fourth_layernorm(img_mlp_output)
600
+
601
+ img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
602
+ text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
603
+
604
+ hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
605
+ return hidden_states
606
+
607
+ def reinit(self, parent_model=None):
608
+ for layer in self.adaLN_modulations:
609
+ nn.init.constant_(layer[-1].weight, 0)
610
+ nn.init.constant_(layer[-1].bias, 0)
611
+
612
+ @non_conflict
613
+ def attention_fn(
614
+ self,
615
+ query_layer,
616
+ key_layer,
617
+ value_layer,
618
+ attention_mask,
619
+ attention_dropout=None,
620
+ log_attention_weights=None,
621
+ scaling_attention_score=True,
622
+ old_impl=attention_fn_default,
623
+ **kwargs,
624
+ ):
625
+ if self.qk_ln:
626
+ query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
627
+ key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
628
+ query_layer = query_layernorm(query_layer)
629
+ key_layer = key_layernorm(key_layer)
630
+
631
+ return old_impl(
632
+ query_layer,
633
+ key_layer,
634
+ value_layer,
635
+ attention_mask,
636
+ attention_dropout=attention_dropout,
637
+ log_attention_weights=log_attention_weights,
638
+ scaling_attention_score=scaling_attention_score,
639
+ **kwargs,
640
+ )
641
+
642
+
643
+ str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
644
+
645
+
646
+ class DiffusionTransformer(BaseModel):
647
+ def __init__(
648
+ self,
649
+ transformer_args,
650
+ num_frames,
651
+ time_compressed_rate,
652
+ latent_width,
653
+ latent_height,
654
+ patch_size,
655
+ in_channels,
656
+ out_channels,
657
+ hidden_size,
658
+ num_layers,
659
+ num_attention_heads,
660
+ elementwise_affine,
661
+ time_embed_dim=None,
662
+ num_classes=None,
663
+ modules={},
664
+ input_time="adaln",
665
+ adm_in_channels=None,
666
+ parallel_output=True,
667
+ height_interpolation=1.0,
668
+ width_interpolation=1.0,
669
+ time_interpolation=1.0,
670
+ use_SwiGLU=False,
671
+ use_RMSNorm=False,
672
+ zero_init_y_embed=False,
673
+ **kwargs,
674
+ ):
675
+ self.latent_width = latent_width
676
+ self.latent_height = latent_height
677
+ self.patch_size = patch_size
678
+ self.num_frames = num_frames
679
+ self.time_compressed_rate = time_compressed_rate
680
+ self.spatial_length = latent_width * latent_height // patch_size**2
681
+ self.in_channels = in_channels
682
+ self.out_channels = out_channels
683
+ self.hidden_size = hidden_size
684
+ self.model_channels = hidden_size
685
+ self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
686
+ self.num_classes = num_classes
687
+ self.adm_in_channels = adm_in_channels
688
+ self.input_time = input_time
689
+ self.num_layers = num_layers
690
+ self.num_attention_heads = num_attention_heads
691
+ self.is_decoder = transformer_args.is_decoder
692
+ self.elementwise_affine = elementwise_affine
693
+ self.height_interpolation = height_interpolation
694
+ self.width_interpolation = width_interpolation
695
+ self.time_interpolation = time_interpolation
696
+ self.inner_hidden_size = hidden_size * 4
697
+ self.zero_init_y_embed = zero_init_y_embed
698
+ try:
699
+ self.dtype = str_to_dtype[kwargs.pop("dtype")]
700
+ except KeyError:
701
+ self.dtype = torch.float32
702
+
703
+ if use_SwiGLU:
704
+ kwargs["activation_func"] = F.silu
705
+ elif "activation_func" not in kwargs:
706
+ approx_gelu = nn.GELU(approximate="tanh")
707
+ kwargs["activation_func"] = approx_gelu
708
+
709
+ if use_RMSNorm:
710
+ kwargs["layernorm"] = RMSNorm
711
+ else:
712
+ kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
713
+
714
+ transformer_args.num_layers = num_layers
715
+ transformer_args.hidden_size = hidden_size
716
+ transformer_args.num_attention_heads = num_attention_heads
717
+ transformer_args.parallel_output = parallel_output
718
+ super().__init__(args=transformer_args, transformer=None, **kwargs)
719
+
720
+ module_configs = modules
721
+ self._build_modules(module_configs)
722
+
723
+ if use_SwiGLU:
724
+ self.add_mixin(
725
+ "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
726
+ )
727
+
728
+ def _build_modules(self, module_configs):
729
+ model_channels = self.hidden_size
730
+ # time_embed_dim = model_channels * 4
731
+ time_embed_dim = self.time_embed_dim
732
+ self.time_embed = nn.Sequential(
733
+ linear(model_channels, time_embed_dim),
734
+ nn.SiLU(),
735
+ linear(time_embed_dim, time_embed_dim),
736
+ )
737
+
738
+ if self.num_classes is not None:
739
+ if isinstance(self.num_classes, int):
740
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
741
+ elif self.num_classes == "continuous":
742
+ print("setting up linear c_adm embedding layer")
743
+ self.label_emb = nn.Linear(1, time_embed_dim)
744
+ elif self.num_classes == "timestep":
745
+ self.label_emb = nn.Sequential(
746
+ Timestep(model_channels),
747
+ nn.Sequential(
748
+ linear(model_channels, time_embed_dim),
749
+ nn.SiLU(),
750
+ linear(time_embed_dim, time_embed_dim),
751
+ ),
752
+ )
753
+ elif self.num_classes == "sequential":
754
+ assert self.adm_in_channels is not None
755
+ self.label_emb = nn.Sequential(
756
+ nn.Sequential(
757
+ linear(self.adm_in_channels, time_embed_dim),
758
+ nn.SiLU(),
759
+ linear(time_embed_dim, time_embed_dim),
760
+ )
761
+ )
762
+ if self.zero_init_y_embed:
763
+ nn.init.constant_(self.label_emb[0][2].weight, 0)
764
+ nn.init.constant_(self.label_emb[0][2].bias, 0)
765
+ else:
766
+ raise ValueError()
767
+
768
+ pos_embed_config = module_configs["pos_embed_config"]
769
+ self.add_mixin(
770
+ "pos_embed",
771
+ instantiate_from_config(
772
+ pos_embed_config,
773
+ height=self.latent_height // self.patch_size,
774
+ width=self.latent_width // self.patch_size,
775
+ compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
776
+ hidden_size=self.hidden_size,
777
+ ),
778
+ reinit=True,
779
+ )
780
+
781
+ patch_embed_config = module_configs["patch_embed_config"]
782
+ self.add_mixin(
783
+ "patch_embed",
784
+ instantiate_from_config(
785
+ patch_embed_config,
786
+ patch_size=self.patch_size,
787
+ hidden_size=self.hidden_size,
788
+ in_channels=self.in_channels,
789
+ ),
790
+ reinit=True,
791
+ )
792
+ if self.input_time == "adaln":
793
+ adaln_layer_config = module_configs["adaln_layer_config"]
794
+ self.add_mixin(
795
+ "adaln_layer",
796
+ instantiate_from_config(
797
+ adaln_layer_config,
798
+ height=self.latent_height // self.patch_size,
799
+ width=self.latent_width // self.patch_size,
800
+ hidden_size=self.hidden_size,
801
+ num_layers=self.num_layers,
802
+ compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
803
+ hidden_size_head=self.hidden_size // self.num_attention_heads,
804
+ time_embed_dim=self.time_embed_dim,
805
+ elementwise_affine=self.elementwise_affine,
806
+ ),
807
+ )
808
+ else:
809
+ raise NotImplementedError
810
+
811
+ final_layer_config = module_configs["final_layer_config"]
812
+ self.add_mixin(
813
+ "final_layer",
814
+ instantiate_from_config(
815
+ final_layer_config,
816
+ hidden_size=self.hidden_size,
817
+ patch_size=self.patch_size,
818
+ out_channels=self.out_channels,
819
+ time_embed_dim=self.time_embed_dim,
820
+ latent_width=self.latent_width,
821
+ latent_height=self.latent_height,
822
+ elementwise_affine=self.elementwise_affine,
823
+ ),
824
+ reinit=True,
825
+ )
826
+
827
+ if "lora_config" in module_configs:
828
+ lora_config = module_configs["lora_config"]
829
+ self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
830
+
831
+ return
832
+
833
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
834
+ b, t, d, h, w = x.shape
835
+ if x.dtype != self.dtype:
836
+ x = x.to(self.dtype)
837
+ assert (y is not None) == (
838
+ self.num_classes is not None
839
+ ), "must specify y if and only if the model is class-conditional"
840
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
841
+ emb = self.time_embed(t_emb)
842
+
843
+ if self.num_classes is not None:
844
+ # assert y.shape[0] == x.shape[0]
845
+ assert x.shape[0] % y.shape[0] == 0
846
+ y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
847
+ emb = emb + self.label_emb(y)
848
+
849
+ kwargs["seq_length"] = t * h * w // (self.patch_size**2)
850
+ kwargs["images"] = x
851
+ kwargs["emb"] = emb
852
+ kwargs["encoder_outputs"] = context
853
+ kwargs["text_length"] = context.shape[1]
854
+
855
+ kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
856
+ output = super().forward(**kwargs)[0]
857
+
858
+ return output
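
For orientation on what the forward pass above attends over: the flattened sequence is the text tokens followed by t * h * w / patch_size**2 vision patch tokens, exactly the seq_length it writes into kwargs. A back-of-the-envelope sketch of that token budget follows; every size below is a placeholder chosen for illustration, not the released CogVideoX configuration.

# Illustrative arithmetic only -- placeholder sizes, not the shipped config.
latent_frames = 13             # (num_frames - 1) // time_compressed_rate + 1
latent_h, latent_w = 60, 90    # hypothetical latent_height, latent_width
patch_size = 2
text_tokens = 226              # hypothetical text_length

vision_tokens = latent_frames * latent_h * latent_w // patch_size**2   # 17550
total_tokens = text_tokens + vision_tokens                             # 17776
print(vision_tokens, total_tokens)
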
sat/finetune.sh ADDED
@@ -0,0 +1,12 @@
1
+ #! /bin/bash
2
+
3
+ echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
4
+
5
+ environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
6
+
7
+ run_cmd="$environs python train_video.py --base configs/cogvideox_2b_sft.yaml --seed $RANDOM"
8
+
9
+ echo ${run_cmd}
10
+ eval ${run_cmd}
11
+
12
+ echo "DONE on `hostname`"
sat/inference.sh ADDED
@@ -0,0 +1,12 @@
1
+ #! /bin/bash
2
+
3
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
4
+
5
+ environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
6
+
7
+ run_cmd="$environs python sample_video.py --base configs/cogvideox_2b_infer.yaml"
8
+
9
+ echo ${run_cmd}
10
+ eval ${run_cmd}
11
+
12
+ echo "DONE on `hostname`"
sat/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ git+https://github.com/spacegoing/SwissArmyTransformer.git
2
+ diffusers>=0.29.2
3
+ omegaconf>=2.3.0
4
+ torch>=2.3.1
5
+ torchvision>=0.19.0
6
+ pytorch_lightning>=2.3.3
7
+ kornia>=0.7.3
8
+ beartype>=0.18.5
9
+ numpy>=2.0.1
10
+ fsspec>=2024.5.0
11
+ safetensors>=0.4.3
12
+ imageio-ffmpeg>=0.5.1
13
+ imageio>=2.34.2
14
+ scipy>=1.14.0
15
+ decord>=0.6.0
16
+ wandb>=0.17.5
17
+ deepspeed>=0.14.4
sat/sample_video.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import math
3
+ import argparse
4
+ from typing import List, Union
5
+ from tqdm import tqdm
6
+ from omegaconf import ListConfig
7
+ import imageio
8
+
9
+ import torch
10
+ import numpy as np
11
+ from einops import rearrange
12
+ import torchvision.transforms as TT
13
+
14
+ from sat.model.base_model import get_model
15
+ from sat.training.model_io import load_checkpoint
16
+ from sat import mpu
17
+
18
+ from diffusion_video import SATVideoDiffusionEngine
19
+ from arguments import get_args
20
+ from torchvision.transforms.functional import center_crop, resize
21
+ from torchvision.transforms import InterpolationMode
22
+
23
+
24
+ def read_from_cli():
25
+ cnt = 0
26
+ try:
27
+ while True:
28
+ x = input("Please input English text (Ctrl-D quit): ")
29
+ yield x.strip(), cnt
30
+ cnt += 1
31
+ except EOFError:
32
+ pass
33
+
34
+
35
+ def read_from_file(p, rank=0, world_size=1):
36
+ with open(p, "r") as fin:
37
+ cnt = -1
38
+ for l in fin:
39
+ cnt += 1
40
+ if cnt % world_size != rank:
41
+ continue
42
+ yield l.strip(), cnt
43
+
44
+
45
+ def get_unique_embedder_keys_from_conditioner(conditioner):
46
+ return list(set([x.input_key for x in conditioner.embedders]))
47
+
48
+
49
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
50
+ batch = {}
51
+ batch_uc = {}
52
+
53
+ for key in keys:
54
+ if key == "txt":
55
+ batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
56
+ batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
57
+ else:
58
+ batch[key] = value_dict[key]
59
+
60
+ if T is not None:
61
+ batch["num_video_frames"] = T
62
+
63
+ for key in batch.keys():
64
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
65
+ batch_uc[key] = torch.clone(batch[key])
66
+ return batch, batch_uc
67
+
68
+
69
+ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
70
+ os.makedirs(save_path, exist_ok=True)
71
+
72
+ for i, vid in enumerate(video_batch):
73
+ gif_frames = []
74
+ for frame in vid:
75
+ frame = rearrange(frame, "c h w -> h w c")
76
+ frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
77
+ gif_frames.append(frame)
78
+ now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
79
+ with imageio.get_writer(now_save_path, fps=fps) as writer:
80
+ for frame in gif_frames:
81
+ writer.append_data(frame)
82
+
83
+
84
+ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
85
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
86
+ arr = resize(
87
+ arr,
88
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
89
+ interpolation=InterpolationMode.BICUBIC,
90
+ )
91
+ else:
92
+ arr = resize(
93
+ arr,
94
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
95
+ interpolation=InterpolationMode.BICUBIC,
96
+ )
97
+
98
+ h, w = arr.shape[2], arr.shape[3]
99
+ arr = arr.squeeze(0)
100
+
101
+ delta_h = h - image_size[0]
102
+ delta_w = w - image_size[1]
103
+
104
+ if reshape_mode == "random" or reshape_mode == "none":
105
+ top = np.random.randint(0, delta_h + 1)
106
+ left = np.random.randint(0, delta_w + 1)
107
+ elif reshape_mode == "center":
108
+ top, left = delta_h // 2, delta_w // 2
109
+ else:
110
+ raise NotImplementedError
111
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
112
+ return arr
113
+
114
+
115
+ def sampling_main(args, model_cls):
116
+ if isinstance(model_cls, type):
117
+ model = get_model(args, model_cls)
118
+ else:
119
+ model = model_cls
120
+
121
+ load_checkpoint(model, args)
122
+ model.eval()
123
+
124
+ if args.input_type == "cli":
125
+ data_iter = read_from_cli()
+ rank, world_size = 0, 1  # fallback so the rank-tagged logging below also works for CLI input
126
+ elif args.input_type == "txt":
127
+ rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
128
+ print("rank and world_size", rank, world_size)
129
+ data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
130
+ else:
131
+ raise NotImplementedError
132
+
133
+ image_size = [480, 720]
134
+
135
+ sample_func = model.sample
136
+ T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
137
+ num_samples = [1]
138
+ force_uc_zero_embeddings = ["txt"]
139
+ device = model.device
140
+ with torch.no_grad():
141
+ for text, cnt in tqdm(data_iter):
142
+ # reload model on GPU
143
+ model.to(device)
144
+ print("rank:", rank, "start to process", text, cnt)
145
+ # TODO: broadcast image2video
146
+ value_dict = {
147
+ "prompt": text,
148
+ "negative_prompt": "",
149
+ "num_frames": torch.tensor(T).unsqueeze(0),
150
+ }
151
+
152
+ batch, batch_uc = get_batch(
153
+ get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
154
+ )
155
+ for key in batch:
156
+ if isinstance(batch[key], torch.Tensor):
157
+ print(key, batch[key].shape)
158
+ elif isinstance(batch[key], list):
159
+ print(key, [len(l) for l in batch[key]])
160
+ else:
161
+ print(key, batch[key])
162
+ c, uc = model.conditioner.get_unconditional_conditioning(
163
+ batch,
164
+ batch_uc=batch_uc,
165
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
166
+ )
167
+
168
+ for k in c:
169
+ if k != "crossattn":
170
+ c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
171
+ for index in range(args.batch_size):
172
+ # reload model on GPU
173
+ model.to(device)
174
+ samples_z = sample_func(
175
+ c,
176
+ uc=uc,
177
+ batch_size=1,
178
+ shape=(T, C, H // F, W // F),
179
+ )
180
+ samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
181
+
182
+ # Unload the model from GPU to save GPU memory
183
+ model.to("cpu")
184
+ torch.cuda.empty_cache()
185
+ first_stage_model = model.first_stage_model
186
+ first_stage_model = first_stage_model.to(device)
187
+
188
+ latent = 1.0 / model.scale_factor * samples_z
189
+
190
+ # Decode latent serial to save GPU memory
191
+ recons = []
192
+ loop_num = (T - 1) // 2
193
+ for i in range(loop_num):
194
+ if i == 0:
195
+ start_frame, end_frame = 0, 3
196
+ else:
197
+ start_frame, end_frame = i * 2 + 1, i * 2 + 3
198
+ if i == loop_num - 1:
199
+ clear_fake_cp_cache = True
200
+ else:
201
+ clear_fake_cp_cache = False
202
+ with torch.no_grad():
203
+ recon = first_stage_model.decode(
204
+ latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
205
+ )
206
+
207
+ recons.append(recon)
208
+
209
+ recon = torch.cat(recons, dim=2).to(torch.float32)
210
+ samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
211
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
212
+
213
+ save_path = os.path.join(
214
+ args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
215
+ )
216
+ if mpu.get_model_parallel_rank() == 0:
217
+ save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
218
+
219
+
220
+ if __name__ == "__main__":
221
+ if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
222
+ os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
223
+ os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
224
+ os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
225
+ py_parser = argparse.ArgumentParser(add_help=False)
226
+ known, args_list = py_parser.parse_known_args()
227
+
228
+ args = get_args(args_list)
229
+ args = argparse.Namespace(**vars(args), **vars(known))
230
+ del args.deepspeed_config
231
+ args.model_config.first_stage_config.params.cp_size = 1
232
+ args.model_config.network_config.params.transformer_args.model_parallel_size = 1
233
+ args.model_config.network_config.params.transformer_args.checkpoint_activations = False
234
+ args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
235
+
236
+ sampling_main(args, model_cls=SATVideoDiffusionEngine)
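
read_from_file() above expects one English prompt per line (the file named by args.input_file) and shards the lines across data-parallel ranks. A minimal, hypothetical way to produce such a file before launching inference.sh; the path and prompts are illustrative, not shipped content.

# Hypothetical helper: writes prompts in the one-per-line format read_from_file() consumes.
prompts = [
    "A golden retriever running through shallow waves at sunset.",
    "A timelapse of clouds drifting over a mountain ridge at dawn.",
]
with open("configs/test.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(prompts) + "\n")
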
sat/sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .models import AutoencodingEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sat/sgm/lr_scheduler.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
32
+ self.last_lr = lr
33
+ return lr
34
+ else:
35
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
36
+ t = min(t, 1.0)
37
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
38
+ self.last_lr = lr
39
+ return lr
40
+
41
+ def __call__(self, n, **kwargs):
42
+ return self.schedule(n, **kwargs)
43
+
44
+
45
+ class LambdaWarmUpCosineScheduler2:
46
+ """
47
+ supports repeated iterations, configurable via lists
48
+ note: use with a base_lr of 1.0.
49
+ """
50
+
51
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
52
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
53
+ self.lr_warm_up_steps = warm_up_steps
54
+ self.f_start = f_start
55
+ self.f_min = f_min
56
+ self.f_max = f_max
57
+ self.cycle_lengths = cycle_lengths
58
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
59
+ self.last_f = 0.0
60
+ self.verbosity_interval = verbosity_interval
61
+
62
+ def find_in_interval(self, n):
63
+ interval = 0
64
+ for cl in self.cum_cycles[1:]:
65
+ if n <= cl:
66
+ return interval
67
+ interval += 1
68
+
69
+ def schedule(self, n, **kwargs):
70
+ cycle = self.find_in_interval(n)
71
+ n = n - self.cum_cycles[cycle]
72
+ if self.verbosity_interval > 0:
73
+ if n % self.verbosity_interval == 0:
74
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
75
+ if n < self.lr_warm_up_steps[cycle]:
76
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
77
+ self.last_f = f
78
+ return f
79
+ else:
80
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
81
+ t = min(t, 1.0)
82
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
83
+ self.last_f = f
84
+ return f
85
+
86
+ def __call__(self, n, **kwargs):
87
+ return self.schedule(n, **kwargs)
88
+
89
+
90
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
91
+ def schedule(self, n, **kwargs):
92
+ cycle = self.find_in_interval(n)
93
+ n = n - self.cum_cycles[cycle]
94
+ if self.verbosity_interval > 0:
95
+ if n % self.verbosity_interval == 0:
96
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
97
+
98
+ if n < self.lr_warm_up_steps[cycle]:
99
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
100
+ self.last_f = f
101
+ return f
102
+ else:
103
+ f = (
104
+ self.f_min[cycle]
105
+ + (self.f_max[cycle] - self.f_min[cycle])
106
+ * (self.cycle_lengths[cycle] - n)
107
+ / (self.cycle_lengths[cycle])
108
+ )
109
+ self.last_f = f
110
+ return f
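
All three schedulers above return a learning-rate multiplier rather than an absolute learning rate, which is what the "use with a base_lr of 1.0" notes mean. Below is a minimal sketch of wiring one into torch.optim.lr_scheduler.LambdaLR; the model, peak learning rate and step counts are placeholders, and LambdaWarmUpCosineScheduler is assumed to be imported from this module.

import torch
import torch.nn as nn

net = nn.Linear(16, 16)   # placeholder model
peak_lr = 1e-4            # placeholder peak learning rate

# The schedule emits a multiplier, so the optimizer carries the peak lr and the
# effective lr at step n is peak_lr * schedule(n).
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=100000
)
optimizer = torch.optim.AdamW(net.parameters(), lr=peak_lr)
lr_schedule = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(10):    # training-loop stub
    optimizer.step()
    lr_schedule.step()
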
sat/sgm/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder import AutoencodingEngine
sat/sgm/models/autoencoder.py ADDED
@@ -0,0 +1,630 @@
1
+ import logging
2
+ import math
3
+ import re
4
+ import random
5
+ from abc import abstractmethod
6
+ from contextlib import contextmanager
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ import torch
12
+ import torch.distributed
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+ from packaging import version
16
+
17
+ from ..modules.autoencoding.regularizers import AbstractRegularizer
18
+ from ..modules.ema import LitEma
19
+ from ..util import (
20
+ default,
21
+ get_nested_attribute,
22
+ get_obj_from_str,
23
+ instantiate_from_config,
24
+ initialize_context_parallel,
25
+ get_context_parallel_group,
26
+ get_context_parallel_group_rank,
27
+ is_context_parallel_initialized,
28
+ )
29
+ from ..modules.cp_enc_dec import _conv_split, _conv_gather
30
+
31
+ logpy = logging.getLogger(__name__)
32
+
33
+
34
+ class AbstractAutoencoder(pl.LightningModule):
35
+ """
36
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
37
+ unCLIP models, etc. Hence, it is fairly general, and specific features
38
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ ema_decay: Union[None, float] = None,
44
+ monitor: Union[None, str] = None,
45
+ input_key: str = "jpg",
46
+ ):
47
+ super().__init__()
48
+
49
+ self.input_key = input_key
50
+ self.use_ema = ema_decay is not None
51
+ if monitor is not None:
52
+ self.monitor = monitor
53
+
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self, decay=ema_decay)
56
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
59
+ self.automatic_optimization = False
60
+
61
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
62
+ if ckpt is None:
63
+ return
64
+ if isinstance(ckpt, str):
65
+ ckpt = {
66
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
67
+ "params": {"ckpt_path": ckpt},
68
+ }
69
+ engine = instantiate_from_config(ckpt)
70
+ engine(self)
71
+
72
+ @abstractmethod
73
+ def get_input(self, batch) -> Any:
74
+ raise NotImplementedError()
75
+
76
+ def on_train_batch_end(self, *args, **kwargs):
77
+ # for EMA computation
78
+ if self.use_ema:
79
+ self.model_ema(self)
80
+
81
+ @contextmanager
82
+ def ema_scope(self, context=None):
83
+ if self.use_ema:
84
+ self.model_ema.store(self.parameters())
85
+ self.model_ema.copy_to(self)
86
+ if context is not None:
87
+ logpy.info(f"{context}: Switched to EMA weights")
88
+ try:
89
+ yield None
90
+ finally:
91
+ if self.use_ema:
92
+ self.model_ema.restore(self.parameters())
93
+ if context is not None:
94
+ logpy.info(f"{context}: Restored training weights")
95
+
96
+ @abstractmethod
97
+ def encode(self, *args, **kwargs) -> torch.Tensor:
98
+ raise NotImplementedError("encode()-method of abstract base class called")
99
+
100
+ @abstractmethod
101
+ def decode(self, *args, **kwargs) -> torch.Tensor:
102
+ raise NotImplementedError("decode()-method of abstract base class called")
103
+
104
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
105
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
106
+ return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
107
+
108
+ def configure_optimizers(self) -> Any:
109
+ raise NotImplementedError()
110
+
111
+
112
+ class AutoencodingEngine(AbstractAutoencoder):
113
+ """
114
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
115
+ (we also restore them explicitly as special cases for legacy reasons).
116
+ Regularizations such as KL or VQ are moved to the regularizer class.
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ *args,
122
+ encoder_config: Dict,
123
+ decoder_config: Dict,
124
+ loss_config: Dict,
125
+ regularizer_config: Dict,
126
+ optimizer_config: Union[Dict, None] = None,
127
+ lr_g_factor: float = 1.0,
128
+ trainable_ae_params: Optional[List[List[str]]] = None,
129
+ ae_optimizer_args: Optional[List[dict]] = None,
130
+ trainable_disc_params: Optional[List[List[str]]] = None,
131
+ disc_optimizer_args: Optional[List[dict]] = None,
132
+ disc_start_iter: int = 0,
133
+ diff_boost_factor: float = 3.0,
134
+ ckpt_engine: Union[None, str, dict] = None,
135
+ ckpt_path: Optional[str] = None,
136
+ additional_decode_keys: Optional[List[str]] = None,
137
+ **kwargs,
138
+ ):
139
+ super().__init__(*args, **kwargs)
140
+ self.automatic_optimization = False # pytorch lightning
141
+
142
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
143
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
144
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
145
+ self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
146
+ self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
147
+ self.diff_boost_factor = diff_boost_factor
148
+ self.disc_start_iter = disc_start_iter
149
+ self.lr_g_factor = lr_g_factor
150
+ self.trainable_ae_params = trainable_ae_params
151
+ if self.trainable_ae_params is not None:
152
+ self.ae_optimizer_args = default(
153
+ ae_optimizer_args,
154
+ [{} for _ in range(len(self.trainable_ae_params))],
155
+ )
156
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
157
+ else:
158
+ self.ae_optimizer_args = [{}]  # makes type consistent
159
+
160
+ self.trainable_disc_params = trainable_disc_params
161
+ if self.trainable_disc_params is not None:
162
+ self.disc_optimizer_args = default(
163
+ disc_optimizer_args,
164
+ [{} for _ in range(len(self.trainable_disc_params))],
165
+ )
166
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
167
+ else:
168
+ self.disc_optimizer_args = [{}]  # makes type consistent
169
+
170
+ if ckpt_path is not None:
171
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
172
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
173
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
174
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
175
+
176
+ def get_input(self, batch: Dict) -> torch.Tensor:
177
+ # assuming unified data format, dataloader returns a dict.
178
+ # image tensors should be scaled to -1 ... 1 and in channels-first
179
+ # format (e.g., bchw instead of bhwc)
180
+ return batch[self.input_key]
181
+
182
+ def get_autoencoder_params(self) -> list:
183
+ params = []
184
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
185
+ params += list(self.loss.get_trainable_autoencoder_parameters())
186
+ if hasattr(self.regularization, "get_trainable_parameters"):
187
+ params += list(self.regularization.get_trainable_parameters())
188
+ params = params + list(self.encoder.parameters())
189
+ params = params + list(self.decoder.parameters())
190
+ return params
191
+
192
+ def get_discriminator_params(self) -> list:
193
+ if hasattr(self.loss, "get_trainable_parameters"):
194
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
195
+ else:
196
+ params = []
197
+ return params
198
+
199
+ def get_last_layer(self):
200
+ return self.decoder.get_last_layer()
201
+
202
+ def encode(
203
+ self,
204
+ x: torch.Tensor,
205
+ return_reg_log: bool = False,
206
+ unregularized: bool = False,
207
+ **kwargs,
208
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
209
+ z = self.encoder(x, **kwargs)
210
+ if unregularized:
211
+ return z, dict()
212
+ z, reg_log = self.regularization(z)
213
+ if return_reg_log:
214
+ return z, reg_log
215
+ return z
216
+
217
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
218
+ x = self.decoder(z, **kwargs)
219
+ return x
220
+
221
+ def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
222
+ z, reg_log = self.encode(x, return_reg_log=True)
223
+ dec = self.decode(z, **additional_decode_kwargs)
224
+ return z, dec, reg_log
225
+
226
+ def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
227
+ x = self.get_input(batch)
228
+ additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
229
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
230
+ if hasattr(self.loss, "forward_keys"):
231
+ extra_info = {
232
+ "z": z,
233
+ "optimizer_idx": optimizer_idx,
234
+ "global_step": self.global_step,
235
+ "last_layer": self.get_last_layer(),
236
+ "split": "train",
237
+ "regularization_log": regularization_log,
238
+ "autoencoder": self,
239
+ }
240
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
241
+ else:
242
+ extra_info = dict()
243
+
244
+ if optimizer_idx == 0:
245
+ # autoencode
246
+ out_loss = self.loss(x, xrec, **extra_info)
247
+ if isinstance(out_loss, tuple):
248
+ aeloss, log_dict_ae = out_loss
249
+ else:
250
+ # simple loss function
251
+ aeloss = out_loss
252
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
253
+
254
+ self.log_dict(
255
+ log_dict_ae,
256
+ prog_bar=False,
257
+ logger=True,
258
+ on_step=True,
259
+ on_epoch=True,
260
+ sync_dist=False,
261
+ )
262
+ self.log(
263
+ "loss",
264
+ aeloss.mean().detach(),
265
+ prog_bar=True,
266
+ logger=False,
267
+ on_epoch=False,
268
+ on_step=True,
269
+ )
270
+ return aeloss
271
+ elif optimizer_idx == 1:
272
+ # discriminator
273
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
274
+ # -> discriminator always needs to return a tuple
275
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
276
+ return discloss
277
+ else:
278
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
279
+
280
+ def training_step(self, batch: dict, batch_idx: int):
281
+ opts = self.optimizers()
282
+ if not isinstance(opts, list):
283
+ # Non-adversarial case
284
+ opts = [opts]
285
+ optimizer_idx = batch_idx % len(opts)
286
+ if self.global_step < self.disc_start_iter:
287
+ optimizer_idx = 0
288
+ opt = opts[optimizer_idx]
289
+ opt.zero_grad()
290
+ with opt.toggle_model():
291
+ loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
292
+ self.manual_backward(loss)
293
+ opt.step()
294
+
295
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
296
+ log_dict = self._validation_step(batch, batch_idx)
297
+ with self.ema_scope():
298
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
299
+ log_dict.update(log_dict_ema)
300
+ return log_dict
301
+
302
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
303
+ x = self.get_input(batch)
304
+
305
+ z, xrec, regularization_log = self(x)
306
+ if hasattr(self.loss, "forward_keys"):
307
+ extra_info = {
308
+ "z": z,
309
+ "optimizer_idx": 0,
310
+ "global_step": self.global_step,
311
+ "last_layer": self.get_last_layer(),
312
+ "split": "val" + postfix,
313
+ "regularization_log": regularization_log,
314
+ "autoencoder": self,
315
+ }
316
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
317
+ else:
318
+ extra_info = dict()
319
+ out_loss = self.loss(x, xrec, **extra_info)
320
+ if isinstance(out_loss, tuple):
321
+ aeloss, log_dict_ae = out_loss
322
+ else:
323
+ # simple loss function
324
+ aeloss = out_loss
325
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
326
+ full_log_dict = log_dict_ae
327
+
328
+ if "optimizer_idx" in extra_info:
329
+ extra_info["optimizer_idx"] = 1
330
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
331
+ full_log_dict.update(log_dict_disc)
332
+ self.log(
333
+ f"val{postfix}/loss/rec",
334
+ log_dict_ae[f"val{postfix}/loss/rec"],
335
+ sync_dist=True,
336
+ )
337
+ self.log_dict(full_log_dict, sync_dist=True)
338
+ return full_log_dict
339
+
340
+ def get_param_groups(
341
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
342
+ ) -> Tuple[List[Dict[str, Any]], int]:
343
+ groups = []
344
+ num_params = 0
345
+ for names, args in zip(parameter_names, optimizer_args):
346
+ params = []
347
+ for pattern_ in names:
348
+ pattern_params = []
349
+ pattern = re.compile(pattern_)
350
+ for p_name, param in self.named_parameters():
351
+ if re.match(pattern, p_name):
352
+ pattern_params.append(param)
353
+ num_params += param.numel()
354
+ if len(pattern_params) == 0:
355
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
356
+ params.extend(pattern_params)
357
+ groups.append({"params": params, **args})
358
+ return groups, num_params
359
+
360
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
361
+ if self.trainable_ae_params is None:
362
+ ae_params = self.get_autoencoder_params()
363
+ else:
364
+ ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
365
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
366
+ if self.trainable_disc_params is None:
367
+ disc_params = self.get_discriminator_params()
368
+ else:
369
+ disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
370
+ logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
371
+ opt_ae = self.instantiate_optimizer_from_config(
372
+ ae_params,
373
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
374
+ self.optimizer_config,
375
+ )
376
+ opts = [opt_ae]
377
+ if len(disc_params) > 0:
378
+ opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
379
+ opts.append(opt_disc)
380
+
381
+ return opts
382
+
383
+ @torch.no_grad()
384
+ def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
385
+ log = dict()
386
+ additional_decode_kwargs = {}
387
+ x = self.get_input(batch)
388
+ additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
389
+
390
+ _, xrec, _ = self(x, **additional_decode_kwargs)
391
+ log["inputs"] = x
392
+ log["reconstructions"] = xrec
393
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
394
+ diff.clamp_(0, 1.0)
395
+ log["diff"] = 2.0 * diff - 1.0
396
+ # diff_boost shows location of small errors, by boosting their
397
+ # brightness.
398
+ log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
399
+ if hasattr(self.loss, "log_images"):
400
+ log.update(self.loss.log_images(x, xrec))
401
+ with self.ema_scope():
402
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
403
+ log["reconstructions_ema"] = xrec_ema
404
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
405
+ diff_ema.clamp_(0, 1.0)
406
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
407
+ log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
408
+ if additional_log_kwargs:
409
+ additional_decode_kwargs.update(additional_log_kwargs)
410
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
411
+ log_str = "reconstructions-" + "-".join(
412
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
413
+ )
414
+ log[log_str] = xrec_add
415
+ return log
416
+
417
+
418
+ class AutoencodingEngineLegacy(AutoencodingEngine):
419
+ def __init__(self, embed_dim: int, **kwargs):
420
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
421
+ ddconfig = kwargs.pop("ddconfig")
422
+ ckpt_path = kwargs.pop("ckpt_path", None)
423
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
424
+ super().__init__(
425
+ encoder_config={
426
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
427
+ "params": ddconfig,
428
+ },
429
+ decoder_config={
430
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
431
+ "params": ddconfig,
432
+ },
433
+ **kwargs,
434
+ )
435
+ self.quant_conv = torch.nn.Conv2d(
436
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
437
+ (1 + ddconfig["double_z"]) * embed_dim,
438
+ 1,
439
+ )
440
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
441
+ self.embed_dim = embed_dim
442
+
443
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
444
+
445
+ def get_autoencoder_params(self) -> list:
446
+ params = super().get_autoencoder_params()
447
+ return params
448
+
449
+ def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
450
+ if self.max_batch_size is None:
451
+ z = self.encoder(x)
452
+ z = self.quant_conv(z)
453
+ else:
454
+ N = x.shape[0]
455
+ bs = self.max_batch_size
456
+ n_batches = int(math.ceil(N / bs))
457
+ z = list()
458
+ for i_batch in range(n_batches):
459
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
460
+ z_batch = self.quant_conv(z_batch)
461
+ z.append(z_batch)
462
+ z = torch.cat(z, 0)
463
+
464
+ z, reg_log = self.regularization(z)
465
+ if return_reg_log:
466
+ return z, reg_log
467
+ return z
468
+
469
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
470
+ if self.max_batch_size is None:
471
+ dec = self.post_quant_conv(z)
472
+ dec = self.decoder(dec, **decoder_kwargs)
473
+ else:
474
+ N = z.shape[0]
475
+ bs = self.max_batch_size
476
+ n_batches = int(math.ceil(N / bs))
477
+ dec = list()
478
+ for i_batch in range(n_batches):
479
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
480
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
481
+ dec.append(dec_batch)
482
+ dec = torch.cat(dec, 0)
483
+
484
+ return dec
485
+
486
+
487
+ class IdentityFirstStage(AbstractAutoencoder):
488
+ def __init__(self, *args, **kwargs):
489
+ super().__init__(*args, **kwargs)
490
+
491
+ def get_input(self, x: Any) -> Any:
492
+ return x
493
+
494
+ def encode(self, x: Any, *args, **kwargs) -> Any:
495
+ return x
496
+
497
+ def decode(self, x: Any, *args, **kwargs) -> Any:
498
+ return x
499
+
500
+
501
+ class VideoAutoencodingEngine(AutoencodingEngine):
502
+ def __init__(
503
+ self,
504
+ ckpt_path: Union[None, str] = None,
505
+ ignore_keys: Union[Tuple, list] = (),
506
+ image_video_weights=[1, 1],
507
+ only_train_decoder=False,
508
+ context_parallel_size=0,
509
+ **kwargs,
510
+ ):
511
+ super().__init__(**kwargs)
512
+ self.context_parallel_size = context_parallel_size
513
+ if ckpt_path is not None:
514
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
515
+
516
+ def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
517
+ return self.log_images(batch, additional_log_kwargs, **kwargs)
518
+
519
+ def get_input(self, batch: dict) -> torch.Tensor:
520
+ if self.context_parallel_size > 0:
521
+ if not is_context_parallel_initialized():
522
+ initialize_context_parallel(self.context_parallel_size)
523
+
524
+ batch = batch[self.input_key]
525
+
526
+ global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
527
+ torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
528
+
529
+ batch = _conv_split(batch, dim=2, kernel_size=1)
530
+ return batch
531
+
532
+ return batch[self.input_key]
533
+
534
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
535
+ if ckpt is None:
536
+ return
537
+ self.init_from_ckpt(ckpt)
538
+
539
+ def init_from_ckpt(self, path, ignore_keys=list()):
540
+ sd = torch.load(path, map_location="cpu")["state_dict"]
541
+ keys = list(sd.keys())
542
+ for k in keys:
543
+ for ik in ignore_keys:
544
+ if k.startswith(ik):
545
+ del sd[k]
546
+ missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
547
+ print("Missing keys: ", missing_keys)
548
+ print("Unexpected keys: ", unexpected_keys)
549
+ print(f"Restored from {path}")
550
+
551
+
552
+ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
553
+ def __init__(
554
+ self,
555
+ cp_size=0,
556
+ *args,
557
+ **kwargs,
558
+ ):
559
+ self.cp_size = cp_size
560
+ return super().__init__(*args, **kwargs)
561
+
562
+ def encode(
563
+ self,
564
+ x: torch.Tensor,
565
+ return_reg_log: bool = False,
566
+ unregularized: bool = False,
567
+ input_cp: bool = False,
568
+ output_cp: bool = False,
569
+ use_cp: bool = True,
570
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
571
+ if self.cp_size <= 1:
572
+ use_cp = False
573
+ if self.cp_size > 0 and use_cp and not input_cp:
574
+ if not is_context_parallel_initialized():
575
+ initialize_context_parallel(self.cp_size)
576
+
577
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
578
+ torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
579
+
580
+ x = _conv_split(x, dim=2, kernel_size=1)
581
+
582
+ if return_reg_log:
583
+ z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
584
+ else:
585
+ z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
586
+
587
+ if self.cp_size > 0 and use_cp and not output_cp:
588
+ z = _conv_gather(z, dim=2, kernel_size=1)
589
+
590
+ if return_reg_log:
591
+ return z, reg_log
592
+ return z
593
+
594
+ def decode(
595
+ self,
596
+ z: torch.Tensor,
597
+ input_cp: bool = False,
598
+ output_cp: bool = False,
599
+ use_cp: bool = True,
600
+ **kwargs,
601
+ ):
602
+ if self.cp_size <= 1:
603
+ use_cp = False
604
+ if self.cp_size > 0 and use_cp and not input_cp:
605
+ if not is_context_parallel_initialized():
606
+ initialize_context_parallel(self.cp_size)
607
+
608
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
609
+ torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
610
+
611
+ z = _conv_split(z, dim=2, kernel_size=1)
612
+
613
+ x = super().decode(z, use_cp=use_cp, **kwargs)
614
+
615
+ if self.cp_size > 0 and use_cp and not output_cp:
616
+ x = _conv_gather(x, dim=2, kernel_size=1)
617
+
618
+ return x
619
+
620
+ def forward(
621
+ self,
622
+ x: torch.Tensor,
623
+ input_cp: bool = False,
624
+ latent_cp: bool = False,
625
+ output_cp: bool = False,
626
+ **additional_decode_kwargs,
627
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
628
+ z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
629
+ dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
630
+ return z, dec, reg_log
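
On the tensor layout these video classes assume: clips travel as (batch, channels, frames, height, width), the frame axis (dim=2) is what _conv_split / _conv_gather shard across context-parallel ranks, and sample_video.py above decodes latents a few frames at a time along that same axis. A shape-only sketch with placeholder sizes follows; the latent channel count and compression factors are hypothetical, and the wrapper itself is not constructed here because that needs the full encoder/decoder/regularizer configs.

import torch

pixels = torch.randn(1, 3, 49, 480, 720)                 # batch, rgb, frames, height, width
latent_frames = (pixels.shape[2] - 1) // 4 + 1           # 13, assuming a hypothetical 4x temporal compression
latent = torch.randn(1, 16, latent_frames, 480 // 8, 720 // 8)   # hypothetical 16 channels, 8x spatial factor

# The chunked decode in sample_video.py slices such a latent along dim=2:
first_chunk = latent[:, :, 0:3].contiguous()
print(pixels.shape, latent.shape, first_chunk.shape)
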
sat/sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
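
UNCONDITIONAL_CONFIG is an ordinary target/params dictionary, so it can be materialized with the instantiate_from_config helper this package exports; a minimal sketch, assuming sat/sgm is importable as sgm and that GeneralConditioner accepts the emb_models list the config implies.

from sgm.util import instantiate_from_config
from sgm.modules import UNCONDITIONAL_CONFIG

# Builds a GeneralConditioner with an empty embedder list, typically used as the
# fallback when no conditioner config is supplied.
conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)
print(type(conditioner).__name__)   # GeneralConditioner
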
sat/sgm/modules/attention.py ADDED
@@ -0,0 +1,572 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from packaging import version
9
+ from torch import nn
10
+
11
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
12
+ SDP_IS_AVAILABLE = True
13
+ from torch.backends.cuda import SDPBackend, sdp_kernel
14
+
15
+ BACKEND_MAP = {
16
+ SDPBackend.MATH: {
17
+ "enable_math": True,
18
+ "enable_flash": False,
19
+ "enable_mem_efficient": False,
20
+ },
21
+ SDPBackend.FLASH_ATTENTION: {
22
+ "enable_math": False,
23
+ "enable_flash": True,
24
+ "enable_mem_efficient": False,
25
+ },
26
+ SDPBackend.EFFICIENT_ATTENTION: {
27
+ "enable_math": False,
28
+ "enable_flash": False,
29
+ "enable_mem_efficient": True,
30
+ },
31
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
32
+ }
33
+ else:
34
+ from contextlib import nullcontext
35
+
36
+ SDP_IS_AVAILABLE = False
37
+ sdp_kernel = nullcontext
38
+ BACKEND_MAP = {}
39
+ print(
40
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
41
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
42
+ )
43
+
44
+ try:
45
+ import xformers
46
+ import xformers.ops
47
+
48
+ XFORMERS_IS_AVAILABLE = True
49
+ except ImportError:
50
+ XFORMERS_IS_AVAILABLE = False
51
+ print("no module 'xformers'. Processing without...")
52
+
53
+ from .diffusionmodules.util import checkpoint
54
+
55
+
56
+ def exists(val):
57
+ return val is not None
58
+
59
+
60
+ def uniq(arr):
61
+ return {el: True for el in arr}.keys()
62
+
63
+
64
+ def default(val, d):
65
+ if exists(val):
66
+ return val
67
+ return d() if isfunction(d) else d
68
+
69
+
70
+ def max_neg_value(t):
71
+ return -torch.finfo(t.dtype).max
72
+
73
+
74
+ def init_(tensor):
75
+ dim = tensor.shape[-1]
76
+ std = 1 / math.sqrt(dim)
77
+ tensor.uniform_(-std, std)
78
+ return tensor
79
+
80
+
81
+ # feedforward
82
+ class GEGLU(nn.Module):
83
+ def __init__(self, dim_in, dim_out):
84
+ super().__init__()
85
+ self.proj = nn.Linear(dim_in, dim_out * 2)
86
+
87
+ def forward(self, x):
88
+ x, gate = self.proj(x).chunk(2, dim=-1)
89
+ return x * F.gelu(gate)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
94
+ super().__init__()
95
+ inner_dim = int(dim * mult)
96
+ dim_out = default(dim_out, dim)
97
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
98
+
99
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
100
+
101
+ def forward(self, x):
102
+ return self.net(x)
103
+
104
+
105
+ def zero_module(module):
106
+ """
107
+ Zero out the parameters of a module and return it.
108
+ """
109
+ for p in module.parameters():
110
+ p.detach().zero_()
111
+ return module
112
+
113
+
114
+ def Normalize(in_channels):
115
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
116
+
117
+
118
+ class LinearAttention(nn.Module):
119
+ def __init__(self, dim, heads=4, dim_head=32):
120
+ super().__init__()
121
+ self.heads = heads
122
+ hidden_dim = dim_head * heads
123
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
124
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
125
+
126
+ def forward(self, x):
127
+ b, c, h, w = x.shape
128
+ qkv = self.to_qkv(x)
129
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
130
+ k = k.softmax(dim=-1)
131
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
132
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
133
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
134
+ return self.to_out(out)
135
+
136
+
137
+ class SpatialSelfAttention(nn.Module):
138
+ def __init__(self, in_channels):
139
+ super().__init__()
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = Normalize(in_channels)
143
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
144
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
145
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
146
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
147
+
148
+ def forward(self, x):
149
+ h_ = x
150
+ h_ = self.norm(h_)
151
+ q = self.q(h_)
152
+ k = self.k(h_)
153
+ v = self.v(h_)
154
+
155
+ # compute attention
156
+ b, c, h, w = q.shape
157
+ q = rearrange(q, "b c h w -> b (h w) c")
158
+ k = rearrange(k, "b c h w -> b c (h w)")
159
+ w_ = torch.einsum("bij,bjk->bik", q, k)
160
+
161
+ w_ = w_ * (int(c) ** (-0.5))
162
+ w_ = torch.nn.functional.softmax(w_, dim=2)
163
+
164
+ # attend to values
165
+ v = rearrange(v, "b c h w -> b c (h w)")
166
+ w_ = rearrange(w_, "b i j -> b j i")
167
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
168
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
169
+ h_ = self.proj_out(h_)
170
+
171
+ return x + h_
172
+
173
+
174
+ class CrossAttention(nn.Module):
175
+ def __init__(
176
+ self,
177
+ query_dim,
178
+ context_dim=None,
179
+ heads=8,
180
+ dim_head=64,
181
+ dropout=0.0,
182
+ backend=None,
183
+ ):
184
+ super().__init__()
185
+ inner_dim = dim_head * heads
186
+ context_dim = default(context_dim, query_dim)
187
+
188
+ self.scale = dim_head**-0.5
189
+ self.heads = heads
190
+
191
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
192
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
193
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
194
+
195
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
196
+ self.backend = backend
197
+
198
+ def forward(
199
+ self,
200
+ x,
201
+ context=None,
202
+ mask=None,
203
+ additional_tokens=None,
204
+ n_times_crossframe_attn_in_self=0,
205
+ ):
206
+ h = self.heads
207
+
208
+ if additional_tokens is not None:
209
+ # get the number of masked tokens at the beginning of the output sequence
210
+ n_tokens_to_mask = additional_tokens.shape[1]
211
+ # add additional token
212
+ x = torch.cat([additional_tokens, x], dim=1)
213
+
214
+ q = self.to_q(x)
215
+ context = default(context, x)
216
+ k = self.to_k(context)
217
+ v = self.to_v(context)
218
+
219
+ if n_times_crossframe_attn_in_self:
220
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
221
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
222
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
223
+ k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
224
+ v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
225
+
226
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
227
+
228
+ ## old
229
+ """
230
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
231
+ del q, k
232
+
233
+ if exists(mask):
234
+ mask = rearrange(mask, 'b ... -> b (...)')
235
+ max_neg_value = -torch.finfo(sim.dtype).max
236
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
237
+ sim.masked_fill_(~mask, max_neg_value)
238
+
239
+ # attention, what we cannot get enough of
240
+ sim = sim.softmax(dim=-1)
241
+
242
+ out = einsum('b i j, b j d -> b i d', sim, v)
243
+ """
244
+ ## new
245
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
246
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
247
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
248
+
249
+ del q, k, v
250
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
251
+
252
+ if additional_tokens is not None:
253
+ # remove additional token
254
+ out = out[:, n_tokens_to_mask:]
255
+ return self.to_out(out)
256
+
257
+
258
+ class MemoryEfficientCrossAttention(nn.Module):
259
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
260
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
261
+ super().__init__()
262
+ print(
263
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
264
+ f"{heads} heads with a dimension of {dim_head}."
265
+ )
266
+ inner_dim = dim_head * heads
267
+ context_dim = default(context_dim, query_dim)
268
+
269
+ self.heads = heads
270
+ self.dim_head = dim_head
271
+
272
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275
+
276
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
277
+ self.attention_op: Optional[Any] = None
278
+
279
+ def forward(
280
+ self,
281
+ x,
282
+ context=None,
283
+ mask=None,
284
+ additional_tokens=None,
285
+ n_times_crossframe_attn_in_self=0,
286
+ ):
287
+ if additional_tokens is not None:
288
+ # get the number of masked tokens at the beginning of the output sequence
289
+ n_tokens_to_mask = additional_tokens.shape[1]
290
+ # add additional token
291
+ x = torch.cat([additional_tokens, x], dim=1)
292
+ q = self.to_q(x)
293
+ context = default(context, x)
294
+ k = self.to_k(context)
295
+ v = self.to_v(context)
296
+
297
+ if n_times_crossframe_attn_in_self:
298
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
299
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
300
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
301
+ k = repeat(
302
+ k[::n_times_crossframe_attn_in_self],
303
+ "b ... -> (b n) ...",
304
+ n=n_times_crossframe_attn_in_self,
305
+ )
306
+ v = repeat(
307
+ v[::n_times_crossframe_attn_in_self],
308
+ "b ... -> (b n) ...",
309
+ n=n_times_crossframe_attn_in_self,
310
+ )
311
+
312
+ b, _, _ = q.shape
313
+ q, k, v = map(
314
+ lambda t: t.unsqueeze(3)
315
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
316
+ .permute(0, 2, 1, 3)
317
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
318
+ .contiguous(),
319
+ (q, k, v),
320
+ )
321
+
322
+ # actually compute the attention, what we cannot get enough of
323
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
324
+
325
+ # TODO: Use this directly in the attention operation, as a bias
326
+ if exists(mask):
327
+ raise NotImplementedError
328
+ out = (
329
+ out.unsqueeze(0)
330
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
331
+ .permute(0, 2, 1, 3)
332
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
333
+ )
334
+ if additional_tokens is not None:
335
+ # remove additional token
336
+ out = out[:, n_tokens_to_mask:]
337
+ return self.to_out(out)
338
+
339
+
340
+ class BasicTransformerBlock(nn.Module):
341
+ ATTENTION_MODES = {
342
+ "softmax": CrossAttention, # vanilla attention
343
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
344
+ }
345
+
346
+ def __init__(
347
+ self,
348
+ dim,
349
+ n_heads,
350
+ d_head,
351
+ dropout=0.0,
352
+ context_dim=None,
353
+ gated_ff=True,
354
+ checkpoint=True,
355
+ disable_self_attn=False,
356
+ attn_mode="softmax",
357
+ sdp_backend=None,
358
+ ):
359
+ super().__init__()
360
+ assert attn_mode in self.ATTENTION_MODES
361
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
362
+ print(
363
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
364
+ f"This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
365
+ )
366
+ attn_mode = "softmax"
367
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
368
+ print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
369
+ if not XFORMERS_IS_AVAILABLE:
370
+ assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
371
+ else:
372
+ print("Falling back to xformers efficient attention.")
373
+ attn_mode = "softmax-xformers"
374
+ attn_cls = self.ATTENTION_MODES[attn_mode]
375
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
376
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
377
+ else:
378
+ assert sdp_backend is None
379
+ self.disable_self_attn = disable_self_attn
380
+ self.attn1 = attn_cls(
381
+ query_dim=dim,
382
+ heads=n_heads,
383
+ dim_head=d_head,
384
+ dropout=dropout,
385
+ context_dim=context_dim if self.disable_self_attn else None,
386
+ backend=sdp_backend,
387
+ ) # is a self-attention if not self.disable_self_attn
388
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
389
+ self.attn2 = attn_cls(
390
+ query_dim=dim,
391
+ context_dim=context_dim,
392
+ heads=n_heads,
393
+ dim_head=d_head,
394
+ dropout=dropout,
395
+ backend=sdp_backend,
396
+ ) # is self-attn if context is none
397
+ self.norm1 = nn.LayerNorm(dim)
398
+ self.norm2 = nn.LayerNorm(dim)
399
+ self.norm3 = nn.LayerNorm(dim)
400
+ self.checkpoint = checkpoint
401
+ if self.checkpoint:
402
+ print(f"{self.__class__.__name__} is using checkpointing")
403
+
404
+ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
405
+ kwargs = {"x": x}
406
+
407
+ if context is not None:
408
+ kwargs.update({"context": context})
409
+
410
+ if additional_tokens is not None:
411
+ kwargs.update({"additional_tokens": additional_tokens})
412
+
413
+ if n_times_crossframe_attn_in_self:
414
+ kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
415
+
416
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
417
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
418
+
419
+ def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
420
+ x = (
421
+ self.attn1(
422
+ self.norm1(x),
423
+ context=context if self.disable_self_attn else None,
424
+ additional_tokens=additional_tokens,
425
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
426
+ )
427
+ + x
428
+ )
429
+ x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
430
+ x = self.ff(self.norm3(x)) + x
431
+ return x
432
+
433
+
434
+ class BasicTransformerSingleLayerBlock(nn.Module):
435
+ ATTENTION_MODES = {
436
+ "softmax": CrossAttention, # vanilla attention
437
+ "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
438
+ # (TODO: might depend on head_dim, check; falls back to semi-optimized kernels for dim != [16, 32, 64, 128])
439
+ }
440
+
441
+ def __init__(
442
+ self,
443
+ dim,
444
+ n_heads,
445
+ d_head,
446
+ dropout=0.0,
447
+ context_dim=None,
448
+ gated_ff=True,
449
+ checkpoint=True,
450
+ attn_mode="softmax",
451
+ ):
452
+ super().__init__()
453
+ assert attn_mode in self.ATTENTION_MODES
454
+ attn_cls = self.ATTENTION_MODES[attn_mode]
455
+ self.attn1 = attn_cls(
456
+ query_dim=dim,
457
+ heads=n_heads,
458
+ dim_head=d_head,
459
+ dropout=dropout,
460
+ context_dim=context_dim,
461
+ )
462
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
463
+ self.norm1 = nn.LayerNorm(dim)
464
+ self.norm2 = nn.LayerNorm(dim)
465
+ self.checkpoint = checkpoint
466
+
467
+ def forward(self, x, context=None):
468
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
469
+
470
+ def _forward(self, x, context=None):
471
+ x = self.attn1(self.norm1(x), context=context) + x
472
+ x = self.ff(self.norm2(x)) + x
473
+ return x
474
+
475
+
476
+ class SpatialTransformer(nn.Module):
477
+ """
478
+ Transformer block for image-like data.
479
+ First, project the input (aka embedding)
480
+ and reshape it to (b, t, d).
481
+ Then apply standard transformer blocks.
482
+ Finally, reshape back to an image.
483
+ NEW: use_linear for more efficiency instead of the 1x1 convs
484
+ """
485
+
486
+ def __init__(
487
+ self,
488
+ in_channels,
489
+ n_heads,
490
+ d_head,
491
+ depth=1,
492
+ dropout=0.0,
493
+ context_dim=None,
494
+ disable_self_attn=False,
495
+ use_linear=False,
496
+ attn_type="softmax",
497
+ use_checkpoint=True,
498
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
499
+ sdp_backend=None,
500
+ ):
501
+ super().__init__()
502
+ print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
503
+ from omegaconf import ListConfig
504
+
505
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
506
+ context_dim = [context_dim]
507
+ if exists(context_dim) and isinstance(context_dim, list):
508
+ if depth != len(context_dim):
509
+ print(
510
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
511
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
512
+ )
513
+ # depth does not match context dims.
514
+ assert all(
515
+ map(lambda x: x == context_dim[0], context_dim)
516
+ ), "need homogeneous context_dim to match depth automatically"
517
+ context_dim = depth * [context_dim[0]]
518
+ elif context_dim is None:
519
+ context_dim = [None] * depth
520
+ self.in_channels = in_channels
521
+ inner_dim = n_heads * d_head
522
+ self.norm = Normalize(in_channels)
523
+ if not use_linear:
524
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
525
+ else:
526
+ self.proj_in = nn.Linear(in_channels, inner_dim)
527
+
528
+ self.transformer_blocks = nn.ModuleList(
529
+ [
530
+ BasicTransformerBlock(
531
+ inner_dim,
532
+ n_heads,
533
+ d_head,
534
+ dropout=dropout,
535
+ context_dim=context_dim[d],
536
+ disable_self_attn=disable_self_attn,
537
+ attn_mode=attn_type,
538
+ checkpoint=use_checkpoint,
539
+ sdp_backend=sdp_backend,
540
+ )
541
+ for d in range(depth)
542
+ ]
543
+ )
544
+ if not use_linear:
545
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
546
+ else:
547
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
548
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
549
+ self.use_linear = use_linear
550
+
551
+ def forward(self, x, context=None):
552
+ # note: if no context is given, cross-attention defaults to self-attention
553
+ if not isinstance(context, list):
554
+ context = [context]
555
+ b, c, h, w = x.shape
556
+ x_in = x
557
+ x = self.norm(x)
558
+ if not self.use_linear:
559
+ x = self.proj_in(x)
560
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
561
+ if self.use_linear:
562
+ x = self.proj_in(x)
563
+ for i, block in enumerate(self.transformer_blocks):
564
+ if i > 0 and len(context) == 1:
565
+ i = 0 # use same context for each block
566
+ x = block(x, context=context[i])
567
+ if self.use_linear:
568
+ x = self.proj_out(x)
569
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
570
+ if not self.use_linear:
571
+ x = self.proj_out(x)
572
+ return x + x_in
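
For readers skimming this diff: the attention modules above all follow the same flatten-attend-reshape pattern. Below is a minimal, self-contained sketch of that pattern using F.scaled_dot_product_attention; the tensor sizes (batch 2, 64 channels, 8x8 spatial, 4 heads) are illustrative assumptions, not CogVideoX settings.

# Minimal sketch of the reshape-then-attend pattern used by SpatialTransformer / CrossAttention above.
# All sizes are illustrative assumptions; this is not the CogVideoX configuration.
import torch
import torch.nn.functional as F
from einops import rearrange

b, c, h, w, heads = 2, 64, 8, 8, 4
x = torch.randn(b, c, h, w)

# image -> tokens: (b, c, h, w) -> (b, h*w, c), as in SpatialTransformer.forward
tokens = rearrange(x, "b c h w -> b (h w) c")

# self-attention: q = k = v = tokens; split into heads like CrossAttention.forward
q = k = v = rearrange(tokens, "b n (h d) -> b h n d", h=heads)
out = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(d)) v
out = rearrange(out, "b h n d -> b n (h d)")

# tokens -> image: (b, h*w, c) -> (b, c, h, w)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
print(out.shape)  # torch.Size([2, 64, 8, 8])
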
sat/sgm/modules/autoencoding/__init__.py ADDED
File without changes
sat/sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ __all__ = [
2
+ "GeneralLPIPSWithDiscriminator",
3
+ "LatentLPIPS",
4
+ ]
5
+
6
+ from .discriminator_loss import GeneralLPIPSWithDiscriminator
7
+ from .lpips import LatentLPIPS
8
+ from .video_loss import VideoAutoencoderLoss
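
For orientation, a hedged sketch of how a loss from this package is typically wired up: sgm-style configs reference the class by its dotted target path and pass constructor arguments as params, mirroring the discriminator_config pattern visible in discriminator_loss.py below. The parameter values here are placeholders, not the values used by the CogVideoX training configs.

# Sketch only: instantiating GeneralLPIPSWithDiscriminator through a {"target": ..., "params": ...}
# dict, the same pattern the default discriminator_config uses. Values are placeholders.
from sgm.util import instantiate_from_config

loss_config = {
    "target": "sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator",
    "params": {
        "disc_start": 50001,       # placeholder step at which the discriminator kicks in
        "perceptual_weight": 1.0,
        "disc_weight": 0.5,
    },
}
loss_fn = instantiate_from_config(loss_config)
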
sat/sgm/modules/autoencoding/losses/discriminator_loss.py ADDED
@@ -0,0 +1,301 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from einops import rearrange
8
+ from matplotlib import colormaps
9
+ from matplotlib import pyplot as plt
10
+
11
+ from ....util import default, instantiate_from_config
12
+ from ..lpips.loss.lpips import LPIPS
13
+ from ..lpips.model.model import weights_init
14
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
15
+
16
+
17
+ class GeneralLPIPSWithDiscriminator(nn.Module):
18
+ def __init__(
19
+ self,
20
+ disc_start: int,
21
+ logvar_init: float = 0.0,
22
+ disc_num_layers: int = 3,
23
+ disc_in_channels: int = 3,
24
+ disc_factor: float = 1.0,
25
+ disc_weight: float = 1.0,
26
+ perceptual_weight: float = 1.0,
27
+ disc_loss: str = "hinge",
28
+ scale_input_to_tgt_size: bool = False,
29
+ dims: int = 2,
30
+ learn_logvar: bool = False,
31
+ regularization_weights: Union[None, Dict[str, float]] = None,
32
+ additional_log_keys: Optional[List[str]] = None,
33
+ discriminator_config: Optional[Dict] = None,
34
+ ):
35
+ super().__init__()
36
+ self.dims = dims
37
+ if self.dims > 2:
38
+ print(
39
+ f"running with dims={dims}. This means that for perceptual loss "
40
+ f"calculation, the LPIPS loss will be applied to each frame "
41
+ f"independently."
42
+ )
43
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
44
+ assert disc_loss in ["hinge", "vanilla"]
45
+ self.perceptual_loss = LPIPS().eval()
46
+ self.perceptual_weight = perceptual_weight
47
+ # output log variance
48
+ self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
49
+ self.learn_logvar = learn_logvar
50
+
51
+ discriminator_config = default(
52
+ discriminator_config,
53
+ {
54
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
55
+ "params": {
56
+ "input_nc": disc_in_channels,
57
+ "n_layers": disc_num_layers,
58
+ "use_actnorm": False,
59
+ },
60
+ },
61
+ )
62
+
63
+ self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
64
+ self.discriminator_iter_start = disc_start
65
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
66
+ self.disc_factor = disc_factor
67
+ self.discriminator_weight = disc_weight
68
+ self.regularization_weights = default(regularization_weights, {})
69
+
70
+ self.forward_keys = [
71
+ "optimizer_idx",
72
+ "global_step",
73
+ "last_layer",
74
+ "split",
75
+ "regularization_log",
76
+ ]
77
+
78
+ self.additional_log_keys = set(default(additional_log_keys, []))
79
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
80
+
81
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
82
+ return self.discriminator.parameters()
83
+
84
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
85
+ if self.learn_logvar:
86
+ yield self.logvar
87
+ yield from ()
88
+
89
+ @torch.no_grad()
90
+ def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
91
+ # calc logits of real/fake
92
+ logits_real = self.discriminator(inputs.contiguous().detach())
93
+ if len(logits_real.shape) < 4:
94
+ # Non patch-discriminator
95
+ return dict()
96
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
97
+ # -> (b, 1, h, w)
98
+
99
+ # parameters for colormapping
100
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
101
+ cmap = colormaps["PiYG"] # diverging colormap
102
+
103
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
104
+ """(b, 1, ...) -> (b, 3, ...)"""
105
+ logits = (logits + high) / (2 * high)
106
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
107
+ # -> (b, 1, ..., 3)
108
+ logits = torch.from_numpy(logits_np).to(logits.device)
109
+ return rearrange(logits, "b 1 ... c -> b c ...")
110
+
111
+ logits_real = torch.nn.functional.interpolate(
112
+ logits_real,
113
+ size=inputs.shape[-2:],
114
+ mode="nearest",
115
+ antialias=False,
116
+ )
117
+ logits_fake = torch.nn.functional.interpolate(
118
+ logits_fake,
119
+ size=reconstructions.shape[-2:],
120
+ mode="nearest",
121
+ antialias=False,
122
+ )
123
+
124
+ # alpha value of logits for overlay
125
+ alpha_real = torch.abs(logits_real) / high
126
+ alpha_fake = torch.abs(logits_fake) / high
127
+ # -> (b, 1, h, w) in range [0, 0.5]
128
+ # the alpha values of the lines don't really matter, since the values are the same
129
+ # for both images and logits anyway
130
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
131
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
132
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
133
+ # -> (1, h, w)
134
+ # blend logits and images together
135
+
136
+ # prepare logits for plotting
137
+ logits_real = to_colormap(logits_real)
138
+ logits_fake = to_colormap(logits_fake)
139
+ # resize logits
140
+ # -> (b, 3, h, w)
141
+
142
+ # make some grids
143
+ # add all logits to one plot
144
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
145
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
146
+ # I just love how torchvision calls the number of columns `nrow`
147
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
148
+ # -> (3, h, w)
149
+
150
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
151
+ grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
152
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
153
+ # -> (3, h, w) in range [0, 1]
154
+
155
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
156
+
157
+ # Create labeled colorbar
158
+ dpi = 100
159
+ height = 128 / dpi
160
+ width = grid_logits.shape[2] / dpi
161
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
162
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
163
+ plt.colorbar(
164
+ img,
165
+ cax=ax,
166
+ orientation="horizontal",
167
+ fraction=0.9,
168
+ aspect=width / height,
169
+ pad=0.0,
170
+ )
171
+ img.set_visible(False)
172
+ fig.tight_layout()
173
+ fig.canvas.draw()
174
+ # manually convert figure to numpy
175
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
176
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
177
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
178
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
179
+
180
+ # Add colorbar to plot
181
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
182
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
183
+ return {
184
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
185
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
186
+ }
187
+
188
+ def calculate_adaptive_weight(
189
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
190
+ ) -> torch.Tensor:
191
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
192
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
193
+
194
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
195
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
196
+ d_weight = d_weight * self.discriminator_weight
197
+ return d_weight
198
+
199
+ def forward(
200
+ self,
201
+ inputs: torch.Tensor,
202
+ reconstructions: torch.Tensor,
203
+ *, # added because I changed the order here
204
+ regularization_log: Dict[str, torch.Tensor],
205
+ optimizer_idx: int,
206
+ global_step: int,
207
+ last_layer: torch.Tensor,
208
+ split: str = "train",
209
+ weights: Union[None, float, torch.Tensor] = None,
210
+ ) -> Tuple[torch.Tensor, dict]:
211
+ if self.scale_input_to_tgt_size:
212
+ inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
213
+
214
+ if self.dims > 2:
215
+ inputs, reconstructions = map(
216
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
217
+ (inputs, reconstructions),
218
+ )
219
+
220
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
221
+ if self.perceptual_weight > 0:
222
+ frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
223
+
224
+ from sgm.modules.autoencoding.losses.video_loss import pick_video_frame
225
+
226
+ input_frames = pick_video_frame(inputs, frame_indices)
227
+ recon_frames = pick_video_frame(reconstructions, frame_indices)
228
+
229
+ p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
230
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
231
+
232
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
233
+
234
+ # now the GAN part
235
+ if optimizer_idx == 0:
236
+ # generator update
237
+ if global_step >= self.discriminator_iter_start or not self.training:
238
+ logits_fake = self.discriminator(reconstructions.contiguous())
239
+ g_loss = -torch.mean(logits_fake)
240
+ if self.training:
241
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
242
+ else:
243
+ d_weight = torch.tensor(1.0)
244
+ else:
245
+ d_weight = torch.tensor(0.0)
246
+ g_loss = torch.tensor(0.0, requires_grad=True)
247
+
248
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
249
+ log = dict()
250
+ for k in regularization_log:
251
+ if k in self.regularization_weights:
252
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
253
+ if k in self.additional_log_keys:
254
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
255
+
256
+ log.update(
257
+ {
258
+ f"{split}/loss/total": loss.clone().detach().mean(),
259
+ f"{split}/loss/nll": nll_loss.detach().mean(),
260
+ f"{split}/loss/rec": rec_loss.detach().mean(),
261
+ f"{split}/loss/percep": p_loss.detach().mean(),
263
+ f"{split}/loss/g": g_loss.detach().mean(),
264
+ f"{split}/scalars/logvar": self.logvar.detach(),
265
+ f"{split}/scalars/d_weight": d_weight.detach(),
266
+ }
267
+ )
268
+
269
+ return loss, log
270
+ elif optimizer_idx == 1:
271
+ # second pass for discriminator update
272
+ logits_real = self.discriminator(inputs.contiguous().detach())
273
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
274
+
275
+ if global_step >= self.discriminator_iter_start or not self.training:
276
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
277
+ else:
278
+ d_loss = torch.tensor(0.0, requires_grad=True)
279
+
280
+ log = {
281
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
282
+ f"{split}/logits/real": logits_real.detach().mean(),
283
+ f"{split}/logits/fake": logits_fake.detach().mean(),
284
+ }
285
+ return d_loss, log
286
+ else:
287
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
288
+
289
+ def get_nll_loss(
290
+ self,
291
+ rec_loss: torch.Tensor,
292
+ weights: Optional[Union[float, torch.Tensor]] = None,
293
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
294
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
295
+ weighted_nll_loss = nll_loss
296
+ if weights is not None:
297
+ weighted_nll_loss = weights * nll_loss
298
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
299
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
300
+
301
+ return nll_loss, weighted_nll_loss
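
A small standalone illustration of calculate_adaptive_weight above: the GAN term is rescaled so that its gradient on the decoder's last layer roughly matches the norm of the reconstruction gradient. The toy last_layer and losses below are arbitrary stand-ins; only the weighting formula mirrors the module.

# Toy illustration of the gradient-balancing rule in calculate_adaptive_weight.
# `last_layer` stands in for the decoder's final conv weight; values are arbitrary.
import torch

last_layer = torch.randn(8, 8, requires_grad=True)
x = torch.randn(4, 8)

nll_loss = (x @ last_layer).abs().mean()   # reconstruction-style term
g_loss = -(x @ last_layer).mean()          # generator-style term

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

# same formula as in the module: ratio of gradient norms, clamped, scaled by disc_weight
disc_weight = 0.5  # placeholder for self.discriminator_weight
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() * disc_weight

total = nll_loss + d_weight * g_loss  # mirrors weighted_nll_loss + d_weight * disc_factor * g_loss
print(float(d_weight))
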
sat/sgm/modules/autoencoding/losses/lpips.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ....util import default, instantiate_from_config
5
+ from ..lpips.loss.lpips import LPIPS
6
+
7
+
8
+ class LatentLPIPS(nn.Module):
9
+ def __init__(
10
+ self,
11
+ decoder_config,
12
+ perceptual_weight=1.0,
13
+ latent_weight=1.0,
14
+ scale_input_to_tgt_size=False,
15
+ scale_tgt_to_input_size=False,
16
+ perceptual_weight_on_inputs=0.0,
17
+ ):
18
+ super().__init__()
19
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
20
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
21
+ self.init_decoder(decoder_config)
22
+ self.perceptual_loss = LPIPS().eval()
23
+ self.perceptual_weight = perceptual_weight
24
+ self.latent_weight = latent_weight
25
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26
+
27
+ def init_decoder(self, config):
28
+ self.decoder = instantiate_from_config(config)
29
+ if hasattr(self.decoder, "encoder"):
30
+ del self.decoder.encoder
31
+
32
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33
+ log = dict()
34
+ loss = (latent_inputs - latent_predictions) ** 2
35
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36
+ image_reconstructions = None
37
+ if self.perceptual_weight > 0.0:
38
+ image_reconstructions = self.decoder.decode(latent_predictions)
39
+ image_targets = self.decoder.decode(latent_inputs)
40
+ perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
41
+ loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
42
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
43
+
44
+ if self.perceptual_weight_on_inputs > 0.0:
45
+ image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
46
+ if self.scale_input_to_tgt_size:
47
+ image_inputs = torch.nn.functional.interpolate(
48
+ image_inputs,
49
+ image_reconstructions.shape[2:],
50
+ mode="bicubic",
51
+ antialias=True,
52
+ )
53
+ elif self.scale_tgt_to_input_size:
54
+ image_reconstructions = torch.nn.functional.interpolate(
55
+ image_reconstructions,
56
+ image_inputs.shape[2:],
57
+ mode="bicubic",
58
+ antialias=True,
59
+ )
60
+
61
+ perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
62
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
63
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
64
+ return loss, log
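
To summarize what LatentLPIPS.forward computes, here is a schematic sketch with dummy stand-ins for the decoder and the LPIPS network (the real class builds them via instantiate_from_config and LPIPS); only the weighting structure mirrors the class above.

# Schematic only: the loss structure of LatentLPIPS with dummy stand-ins for the
# decoder and the perceptual network.
import torch

def dummy_decode(z):             # stands in for self.decoder.decode
    return z.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2)

def dummy_perceptual(a, b):      # stands in for LPIPS(); returns a per-sample distance
    return (a - b).pow(2).flatten(1).mean(dim=1)

latent_weight, perceptual_weight = 1.0, 1.0
z_in = torch.randn(2, 4, 8, 8)
z_pred = z_in + 0.1 * torch.randn_like(z_in)

latent_l2 = (z_in - z_pred).pow(2)                                  # latent-space reconstruction term
p_loss = dummy_perceptual(dummy_decode(z_in), dummy_decode(z_pred)) # perceptual term on decoded images

loss = latent_weight * latent_l2.mean() + perceptual_weight * p_loss.mean()
print(float(loss))
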