bubbliiiing committed on
Commit
00db68b
1 Parent(s): e9ff055

Create Code

.gitignore ADDED
@@ -0,0 +1,167 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ models*
3
+ output*
4
+ logs*
5
+ taming*
6
+ samples*
7
+ datasets*
8
+ asset*
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/#use-with-ide
117
+ .pdm.toml
118
+
119
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120
+ __pypackages__/
121
+
122
+ # Celery stuff
123
+ celerybeat-schedule
124
+ celerybeat.pid
125
+
126
+ # SageMath parsed files
127
+ *.sage.py
128
+
129
+ # Environments
130
+ .env
131
+ .venv
132
+ env/
133
+ venv/
134
+ ENV/
135
+ env.bak/
136
+ venv.bak/
137
+
138
+ # Spyder project settings
139
+ .spyderproject
140
+ .spyproject
141
+
142
+ # Rope project settings
143
+ .ropeproject
144
+
145
+ # mkdocs documentation
146
+ /site
147
+
148
+ # mypy
149
+ .mypy_cache/
150
+ .dmypy.json
151
+ dmypy.json
152
+
153
+ # Pyre type checker
154
+ .pyre/
155
+
156
+ # pytype static type analyzer
157
+ .pytype/
158
+
159
+ # Cython debug symbols
160
+ cython_debug/
161
+
162
+ # PyCharm
163
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
166
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -11,3 +11,10 @@ license: other
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # License
16
+ This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
17
+
18
+ The CogVideoX-2B model (including its corresponding Transformers module and VAE module) is released under the [Apache 2.0 License](LICENSE).
19
+
20
+ The CogVideoX-5B model (Transformers module) is released under the [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE).
app.py ADDED
@@ -0,0 +1,46 @@
1
+ import time
2
+ import torch
3
+
4
+ from cogvideox.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
5
+ from cogvideox.ui.ui import ui_modelscope, ui_eas, ui
6
+
7
+ if __name__ == "__main__":
8
+ # Choose the ui mode
9
+ ui_mode = "eas"
10
+
11
+ # Low GPU memory mode; use this when GPU memory is under 16GB
12
+ low_gpu_memory_mode = False
13
+ # Use torch.float16 if GPU does not support torch.bfloat16
14
+ # Some graphics cards, such as the V100 and 2080 Ti, do not support torch.bfloat16
15
+ weight_dtype = torch.bfloat16
16
+
17
+ # Server ip
18
+ server_name = "0.0.0.0"
19
+ server_port = 7860
20
+
21
+ # The params below are used when ui_mode = "modelscope"
22
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-5b-InP"
23
+ savedir_sample = "samples"
24
+
25
+ if ui_mode == "modelscope":
26
+ demo, controller = ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
27
+ elif ui_mode == "eas":
28
+ demo, controller = ui_eas(model_name, savedir_sample)
29
+ else:
30
+ demo, controller = ui(low_gpu_memory_mode, weight_dtype)
31
+
32
+ # launch gradio
33
+ app, _, _ = demo.queue(status_update_rate=1).launch(
34
+ server_name=server_name,
35
+ server_port=server_port,
36
+ prevent_thread_lock=True
37
+ )
38
+
39
+ # launch api
40
+ infer_forward_api(None, app, controller)
41
+ update_diffusion_transformer_api(None, app, controller)
42
+ update_edition_api(None, app, controller)
43
+
44
+ # keep the Python process alive so the server keeps serving requests
45
+ while True:
46
+ time.sleep(5)
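A side note on the hard-coded `weight_dtype` above: a minimal sketch (assuming only the standard PyTorch API, e.g. `torch.cuda.is_bf16_supported()`) that picks the fallback described in the comments automatically:

```python
import torch

# Hedged sketch: prefer bfloat16 on GPUs that support it, otherwise fall back
# to float16 (e.g. V100 or 2080 Ti, which lack bfloat16 support).
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    weight_dtype = torch.bfloat16
else:
    weight_dtype = torch.float16
```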
asset/1.png ADDED
asset/2.png ADDED
asset/3.png ADDED
asset/4.png ADDED
asset/5.png ADDED
cogvideox/__init__.py ADDED
File without changes
cogvideox/api/api.py ADDED
@@ -0,0 +1,149 @@
1
+ import io
2
+ import gc
3
+ import base64
4
+ import torch
5
+ import gradio as gr
6
+ import tempfile
7
+ import hashlib
8
+ import os
9
+
10
+ from fastapi import FastAPI
11
+ from io import BytesIO
12
+ from PIL import Image
13
+
14
+ # Function to encode a file to Base64
15
+ def encode_file_to_base64(file_path):
16
+ with open(file_path, "rb") as file:
17
+ # Encode the data to Base64
18
+ file_base64 = base64.b64encode(file.read())
19
+ return file_base64
20
+
21
+ def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
22
+ @app.post("/cogvideox_fun/update_edition")
23
+ def _update_edition_api(
24
+ datas: dict,
25
+ ):
26
+ edition = datas.get('edition', 'v2')
27
+
28
+ try:
29
+ controller.update_edition(
30
+ edition
31
+ )
32
+ comment = "Success"
33
+ except Exception as e:
34
+ torch.cuda.empty_cache()
35
+ comment = f"Error. error information is {str(e)}"
36
+
37
+ return {"message": comment}
38
+
39
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
40
+ @app.post("/cogvideox_fun/update_diffusion_transformer")
41
+ def _update_diffusion_transformer_api(
42
+ datas: dict,
43
+ ):
44
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
45
+
46
+ try:
47
+ controller.update_diffusion_transformer(
48
+ diffusion_transformer_path
49
+ )
50
+ comment = "Success"
51
+ except Exception as e:
52
+ torch.cuda.empty_cache()
53
+ comment = f"Error. error information is {str(e)}"
54
+
55
+ return {"message": comment}
56
+
57
+ def save_base64_video(base64_string):
58
+ video_data = base64.b64decode(base64_string)
59
+
60
+ md5_hash = hashlib.md5(video_data).hexdigest()
61
+ filename = f"{md5_hash}.mp4"
62
+
63
+ temp_dir = tempfile.gettempdir()
64
+ file_path = os.path.join(temp_dir, filename)
65
+
66
+ with open(file_path, 'wb') as video_file:
67
+ video_file.write(video_data)
68
+
69
+ return file_path
70
+
71
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
72
+ @app.post("/cogvideox_fun/infer_forward")
73
+ def _infer_forward_api(
74
+ datas: dict,
75
+ ):
76
+ base_model_path = datas.get('base_model_path', 'none')
77
+ lora_model_path = datas.get('lora_model_path', 'none')
78
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
79
+ prompt_textbox = datas.get('prompt_textbox', None)
80
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ')
81
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
82
+ sample_step_slider = datas.get('sample_step_slider', 30)
83
+ resize_method = datas.get('resize_method', "Generate by")
84
+ width_slider = datas.get('width_slider', 672)
85
+ height_slider = datas.get('height_slider', 384)
86
+ base_resolution = datas.get('base_resolution', 512)
87
+ is_image = datas.get('is_image', False)
88
+ generation_method = datas.get('generation_method', False)
89
+ length_slider = datas.get('length_slider', 49)
90
+ overlap_video_length = datas.get('overlap_video_length', 4)
91
+ partial_video_length = datas.get('partial_video_length', 72)
92
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
93
+ start_image = datas.get('start_image', None)
94
+ end_image = datas.get('end_image', None)
95
+ validation_video = datas.get('validation_video', None)
96
+ denoise_strength = datas.get('denoise_strength', 0.70)
97
+ seed_textbox = datas.get("seed_textbox", 43)
98
+
99
+ generation_method = "Image Generation" if is_image else generation_method
100
+
101
+ if start_image is not None:
102
+ start_image = base64.b64decode(start_image)
103
+ start_image = [Image.open(BytesIO(start_image))]
104
+
105
+ if end_image is not None:
106
+ end_image = base64.b64decode(end_image)
107
+ end_image = [Image.open(BytesIO(end_image))]
108
+
109
+ if validation_video is not None:
110
+ validation_video = save_base64_video(validation_video)
111
+
112
+ try:
113
+ save_sample_path, comment = controller.generate(
114
+ "",
115
+ base_model_path,
116
+ lora_model_path,
117
+ lora_alpha_slider,
118
+ prompt_textbox,
119
+ negative_prompt_textbox,
120
+ sampler_dropdown,
121
+ sample_step_slider,
122
+ resize_method,
123
+ width_slider,
124
+ height_slider,
125
+ base_resolution,
126
+ generation_method,
127
+ length_slider,
128
+ overlap_video_length,
129
+ partial_video_length,
130
+ cfg_scale_slider,
131
+ start_image,
132
+ end_image,
133
+ validation_video,
134
+ denoise_strength,
135
+ seed_textbox,
136
+ is_api = True,
137
+ )
138
+ except Exception as e:
139
+ gc.collect()
140
+ torch.cuda.empty_cache()
141
+ torch.cuda.ipc_collect()
142
+ save_sample_path = ""
143
+ comment = f"Error. error information is {str(e)}"
144
+ return {"message": comment}
145
+
146
+ if save_sample_path != "":
147
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
148
+ else:
149
+ return {"message": comment, "save_sample_path": save_sample_path}
cogvideox/api/post_infer.py ADDED
@@ -0,0 +1,89 @@
1
+ import base64
2
+ import json
3
+ import sys
4
+ import time
5
+ from datetime import datetime
6
+ from io import BytesIO
7
+
8
+ import cv2
9
+ import requests
10
+ import base64
11
+
12
+
13
+ def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
14
+ datas = json.dumps({
15
+ "diffusion_transformer_path": diffusion_transformer_path
16
+ })
17
+ r = requests.post(f'{url}/cogvideox_fun/update_diffusion_transformer', data=datas, timeout=1500)
18
+ data = r.content.decode('utf-8')
19
+ return data
20
+
21
+ def post_update_edition(edition, url='http://0.0.0.0:7860'):
22
+ datas = json.dumps({
23
+ "edition": edition
24
+ })
25
+ r = requests.post(f'{url}/cogvideox_fun/update_edition', data=datas, timeout=1500)
26
+ data = r.content.decode('utf-8')
27
+ return data
28
+
29
+ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
30
+ datas = json.dumps({
31
+ "base_model_path": "none",
32
+ "motion_module_path": "none",
33
+ "lora_model_path": "none",
34
+ "lora_alpha_slider": 0.55,
35
+ "prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
36
+ "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. ",
37
+ "sampler_dropdown": "Euler",
38
+ "sample_step_slider": 50,
39
+ "width_slider": 672,
40
+ "height_slider": 384,
41
+ "generation_method": "Video Generation",
42
+ "length_slider": length_slider,
43
+ "cfg_scale_slider": 6,
44
+ "seed_textbox": 43,
45
+ })
46
+ r = requests.post(f'{url}/cogvideox_fun/infer_forward', data=datas, timeout=1500)
47
+ data = r.content.decode('utf-8')
48
+ return data
49
+
50
+ if __name__ == '__main__':
51
+ # record the start time
52
+ now_date = datetime.now()
53
+ time_start = time.time()
54
+
55
+ # -------------------------- #
56
+ # Step 1: update the diffusion transformer
57
+ # -------------------------- #
58
+ diffusion_transformer_path = "models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
59
+ outputs = post_diffusion_transformer(diffusion_transformer_path)
60
+ print('Output of update_diffusion_transformer: ', outputs)
61
+
62
+ # -------------------------- #
63
+ # Step 2: infer
64
+ # -------------------------- #
65
+ # "Video Generation" and "Image Generation"
66
+ generation_method = "Video Generation"
67
+ length_slider = 49
68
+ outputs = post_infer(generation_method, length_slider)
69
+
70
+ # Get decoded data
71
+ outputs = json.loads(outputs)
72
+ base64_encoding = outputs["base64_encoding"]
73
+ decoded_data = base64.b64decode(base64_encoding)
74
+
75
+ is_image = True if generation_method == "Image Generation" else False
76
+ if is_image or length_slider == 1:
77
+ file_path = "1.png"
78
+ else:
79
+ file_path = "1.mp4"
80
+ with open(file_path, "wb") as file:
81
+ file.write(decoded_data)
82
+
83
+ # End of record time
84
+ # The time difference below is the total execution time of the request, in seconds
85
+ time_end = time.time()
86
+ time_sum = time_end - time_start
87
+ print('# --------------------------------------------------------- #')
88
+ print(f'# Total expenditure: {time_sum}s')
89
+ print('# --------------------------------------------------------- #')
cogvideox/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # get the video dimensions
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast float to int
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast float to int
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e)
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e)
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # get the video dimensions
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast float to int
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast float to int
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e)
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
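A hedged usage sketch for the samplers above. The toy dataset and its metadata records are illustrative; the only assumptions are that each item is a dict exposing `width`/`height`, which is what `AspectRatioBatchImageSampler` inspects, and that the snippet runs in the same module as the classes above.

```python
from torch.utils.data import Dataset


class ToyMetaDataset(Dataset):
    """Hypothetical dataset whose items carry width/height metadata."""

    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]


records = [{"width": 1024, "height": 576, "file_path": f"{i}.jpg", "text": "demo"} for i in range(16)]
dataset = ToyMetaDataset(records)

batch_sampler = AspectRatioBatchImageSampler(
    sampler=RandomSampler(dataset),  # the RandomSampler defined in this file
    dataset=dataset,
    batch_size=4,
    aspect_ratios=ASPECT_RATIO_512,
)
for batch_indices in batch_sampler:
    # each yielded batch holds indices whose items fall into the same aspect-ratio bucket
    print(batch_indices)
```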
cogvideox/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ if __name__ == "__main__":
69
+ dataset = CC15M(
70
+ json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71
+ resolution=512,
72
+ )
73
+
74
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75
+ for idx, batch in enumerate(dataloader):
76
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/data/dataset_image_video.py ADDED
@@ -0,0 +1,324 @@
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from threading import Thread
8
+
9
+ import albumentations
10
+ import cv2
11
+ import gc
12
+ import numpy as np
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+
16
+ from func_timeout import func_timeout, FunctionTimedOut
17
+ from decord import VideoReader
18
+ from PIL import Image
19
+ from torch.utils.data import BatchSampler, Sampler
20
+ from torch.utils.data.dataset import Dataset
21
+ from contextlib import contextmanager
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ if f != 1:
29
+ mask_index = np.random.choice([0, 1, 2, 3, 4], p = [0.05, 0.3, 0.3, 0.3, 0.05]) # np.random.randint(0, 5)
30
+ else:
31
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8]) # np.random.randint(0, 2)
32
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
33
+
34
+ if mask_index == 0:
35
+ center_x = torch.randint(0, w, (1,)).item()
36
+ center_y = torch.randint(0, h, (1,)).item()
37
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the masked block
38
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the masked block
39
+
40
+ start_x = max(center_x - block_size_x // 2, 0)
41
+ end_x = min(center_x + block_size_x // 2, w)
42
+ start_y = max(center_y - block_size_y // 2, 0)
43
+ end_y = min(center_y + block_size_y // 2, h)
44
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
45
+ elif mask_index == 1:
46
+ mask[:, :, :, :] = 1
47
+ elif mask_index == 2:
48
+ mask_frame_index = np.random.randint(1, 5)
49
+ mask[mask_frame_index:, :, :, :] = 1
50
+ elif mask_index == 3:
51
+ mask_frame_index = np.random.randint(1, 5)
52
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
53
+ elif mask_index == 4:
54
+ center_x = torch.randint(0, w, (1,)).item()
55
+ center_y = torch.randint(0, h, (1,)).item()
56
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the masked block
57
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the masked block
58
+
59
+ start_x = max(center_x - block_size_x // 2, 0)
60
+ end_x = min(center_x + block_size_x // 2, w)
61
+ start_y = max(center_y - block_size_y // 2, 0)
62
+ end_y = min(center_y + block_size_y // 2, h)
63
+
64
+ mask_frame_before = np.random.randint(0, f // 2)
65
+ mask_frame_after = np.random.randint(f // 2, f)
66
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
67
+ else:
68
+ raise ValueError(f"The mask_index {mask_index} is not define")
69
+ return mask
70
+
71
+ class ImageVideoSampler(BatchSampler):
72
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
73
+
74
+ Args:
75
+ sampler (Sampler): Base sampler.
76
+ dataset (Dataset): Dataset providing data information.
77
+ batch_size (int): Size of mini-batch.
78
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
79
+ its size would be less than ``batch_size``.
80
+ aspect_ratios (dict): The predefined aspect ratios.
81
+ """
82
+
83
+ def __init__(self,
84
+ sampler: Sampler,
85
+ dataset: Dataset,
86
+ batch_size: int,
87
+ drop_last: bool = False
88
+ ) -> None:
89
+ if not isinstance(sampler, Sampler):
90
+ raise TypeError('sampler should be an instance of ``Sampler``, '
91
+ f'but got {sampler}')
92
+ if not isinstance(batch_size, int) or batch_size <= 0:
93
+ raise ValueError('batch_size should be a positive integer value, '
94
+ f'but got batch_size={batch_size}')
95
+ self.sampler = sampler
96
+ self.dataset = dataset
97
+ self.batch_size = batch_size
98
+ self.drop_last = drop_last
99
+
100
+ # buckets for each aspect ratio
101
+ self.bucket = {'image':[], 'video':[]}
102
+
103
+ def __iter__(self):
104
+ for idx in self.sampler:
105
+ content_type = self.dataset.dataset[idx].get('type', 'image')
106
+ self.bucket[content_type].append(idx)
107
+
108
+ # yield a batch of indices in the same aspect ratio group
109
+ if len(self.bucket['video']) == self.batch_size:
110
+ bucket = self.bucket['video']
111
+ yield bucket[:]
112
+ del bucket[:]
113
+ elif len(self.bucket['image']) == self.batch_size:
114
+ bucket = self.bucket['image']
115
+ yield bucket[:]
116
+ del bucket[:]
117
+
118
+ @contextmanager
119
+ def VideoReader_contextmanager(*args, **kwargs):
120
+ vr = VideoReader(*args, **kwargs)
121
+ try:
122
+ yield vr
123
+ finally:
124
+ del vr
125
+ gc.collect()
126
+
127
+ def get_video_reader_batch(video_reader, batch_index):
128
+ frames = video_reader.get_batch(batch_index).asnumpy()
129
+ return frames
130
+
131
+ def resize_frame(frame, target_short_side):
132
+ h, w, _ = frame.shape
133
+ if h < w:
134
+ if target_short_side > h:
135
+ return frame
136
+ new_h = target_short_side
137
+ new_w = int(target_short_side * w / h)
138
+ else:
139
+ if target_short_side > w:
140
+ return frame
141
+ new_w = target_short_side
142
+ new_h = int(target_short_side * h / w)
143
+
144
+ resized_frame = cv2.resize(frame, (new_w, new_h))
145
+ return resized_frame
146
+
147
+ class ImageVideoDataset(Dataset):
148
+ def __init__(
149
+ self,
150
+ ann_path, data_root=None,
151
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
152
+ image_sample_size=512,
153
+ video_repeat=0,
154
+ text_drop_ratio=-1,
155
+ enable_bucket=False,
156
+ video_length_drop_start=0.1,
157
+ video_length_drop_end=0.9,
158
+ enable_inpaint=False,
159
+ ):
160
+ # Loading annotations from files
161
+ print(f"loading annotations from {ann_path} ...")
162
+ if ann_path.endswith('.csv'):
163
+ with open(ann_path, 'r') as csvfile:
164
+ dataset = list(csv.DictReader(csvfile))
165
+ elif ann_path.endswith('.json'):
166
+ dataset = json.load(open(ann_path))
167
+
168
+ self.data_root = data_root
169
+
170
+ # It's used to balance num of images and videos.
171
+ self.dataset = []
172
+ for data in dataset:
173
+ if data.get('type', 'image') != 'video':
174
+ self.dataset.append(data)
175
+ if video_repeat > 0:
176
+ for _ in range(video_repeat):
177
+ for data in dataset:
178
+ if data.get('type', 'image') == 'video':
179
+ self.dataset.append(data)
180
+ del dataset
181
+
182
+ self.length = len(self.dataset)
183
+ print(f"data scale: {self.length}")
184
+ # TODO: enable bucket training
185
+ self.enable_bucket = enable_bucket
186
+ self.text_drop_ratio = text_drop_ratio
187
+ self.enable_inpaint = enable_inpaint
188
+
189
+ self.video_length_drop_start = video_length_drop_start
190
+ self.video_length_drop_end = video_length_drop_end
191
+
192
+ # Video params
193
+ self.video_sample_stride = video_sample_stride
194
+ self.video_sample_n_frames = video_sample_n_frames
195
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
196
+ self.video_transforms = transforms.Compose(
197
+ [
198
+ transforms.Resize(min(self.video_sample_size)),
199
+ transforms.CenterCrop(self.video_sample_size),
200
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
201
+ ]
202
+ )
203
+
204
+ # Image params
205
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
206
+ self.image_transforms = transforms.Compose([
207
+ transforms.Resize(min(self.image_sample_size)),
208
+ transforms.CenterCrop(self.image_sample_size),
209
+ transforms.ToTensor(),
210
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
211
+ ])
212
+
213
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
214
+
215
+ def get_batch(self, idx):
216
+ data_info = self.dataset[idx % len(self.dataset)]
217
+
218
+ if data_info.get('type', 'image')=='video':
219
+ video_id, text = data_info['file_path'], data_info['text']
220
+
221
+ if self.data_root is None:
222
+ video_dir = video_id
223
+ else:
224
+ video_dir = os.path.join(self.data_root, video_id)
225
+
226
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
227
+ min_sample_n_frames = min(
228
+ self.video_sample_n_frames,
229
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
230
+ )
231
+ if min_sample_n_frames == 0:
232
+ raise ValueError(f"No Frames in video.")
233
+
234
+ video_length = int(self.video_length_drop_end * len(video_reader))
235
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
236
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
237
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
238
+
239
+ try:
240
+ sample_args = (video_reader, batch_index)
241
+ pixel_values = func_timeout(
242
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
243
+ )
244
+ resized_frames = []
245
+ for i in range(len(pixel_values)):
246
+ frame = pixel_values[i]
247
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
248
+ resized_frames.append(resized_frame)
249
+ pixel_values = np.array(resized_frames)
250
+ except FunctionTimedOut:
251
+ raise ValueError(f"Read {idx} timeout.")
252
+ except Exception as e:
253
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
254
+
255
+ if not self.enable_bucket:
256
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
257
+ pixel_values = pixel_values / 255.
258
+ del video_reader
259
+ else:
260
+ pixel_values = pixel_values
261
+
262
+ if not self.enable_bucket:
263
+ pixel_values = self.video_transforms(pixel_values)
264
+
265
+ # Random use no text generation
266
+ if random.random() < self.text_drop_ratio:
267
+ text = ''
268
+ return pixel_values, text, 'video'
269
+ else:
270
+ image_path, text = data_info['file_path'], data_info['text']
271
+ if self.data_root is not None:
272
+ image_path = os.path.join(self.data_root, image_path)
273
+ image = Image.open(image_path).convert('RGB')
274
+ if not self.enable_bucket:
275
+ image = self.image_transforms(image).unsqueeze(0)
276
+ else:
277
+ image = np.expand_dims(np.array(image), 0)
278
+ if random.random() < self.text_drop_ratio:
279
+ text = ''
280
+ return image, text, 'image'
281
+
282
+ def __len__(self):
283
+ return self.length
284
+
285
+ def __getitem__(self, idx):
286
+ data_info = self.dataset[idx % len(self.dataset)]
287
+ data_type = data_info.get('type', 'image')
288
+ while True:
289
+ sample = {}
290
+ try:
291
+ data_info_local = self.dataset[idx % len(self.dataset)]
292
+ data_type_local = data_info_local.get('type', 'image')
293
+ if data_type_local != data_type:
294
+ raise ValueError("data_type_local != data_type")
295
+
296
+ pixel_values, name, data_type = self.get_batch(idx)
297
+ sample["pixel_values"] = pixel_values
298
+ sample["text"] = name
299
+ sample["data_type"] = data_type
300
+ sample["idx"] = idx
301
+
302
+ if len(sample) > 0:
303
+ break
304
+ except Exception as e:
305
+ print(e, self.dataset[idx % len(self.dataset)])
306
+ idx = random.randint(0, self.length-1)
307
+
308
+ if self.enable_inpaint and not self.enable_bucket:
309
+ mask = get_random_mask(pixel_values.size())
310
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
311
+ sample["mask_pixel_values"] = mask_pixel_values
312
+ sample["mask"] = mask
313
+
314
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
315
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
316
+ sample["clip_pixel_values"] = clip_pixel_values
317
+
318
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
319
+ if (mask == 1).all():
320
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
321
+ sample["ref_pixel_values"] = ref_pixel_values
322
+
323
+ return sample
324
+
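The inpaint branch of `__getitem__` above derives several auxiliary tensors from the sampled clip. The following sketch (illustration only, not part of the commit) restates that relationship for a single already-normalized clip; the shapes and the hand-built mask are assumptions standing in for `get_random_mask`.

import torch

# Illustrative clip: 16 frames, 3 channels, 64x64, already normalized to [-1, 1].
pixel_values = torch.rand(16, 3, 64, 64) * 2 - 1

# Stand-in for get_random_mask: mask every frame after the first (mode 0 of that helper).
mask = torch.zeros(16, 1, 64, 64)
mask[1:] = 1

# Masked clip: kept pixels are unchanged, masked pixels are filled with -1 (normalized black).
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask

# CLIP conditioning image: the first frame mapped back from [-1, 1] to [0, 255], in HWC layout.
clip_pixel_values = (pixel_values[0].permute(1, 2, 0).contiguous() * 0.5 + 0.5) * 255

# Reference frame: the first frame, or all -1 when every pixel is masked.
ref_pixel_values = pixel_values[0].unsqueeze(0)
if (mask == 1).all():
    ref_pixel_values = torch.ones_like(ref_pixel_values) * -1

print(mask_pixel_values.shape, clip_pixel_values.shape, ref_pixel_values.shape)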
cogvideox/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ mask_index = np.random.randint(0, 4)
29
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
30
+ if mask_index == 0:
31
+ mask[1:, :, :, :] = 1
32
+ elif mask_index == 1:
33
+ mask_frame_index = 1
34
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
35
+ elif mask_index == 2:
36
+ center_x = torch.randint(0, w, (1,)).item()
37
+ center_y = torch.randint(0, h, (1,)).item()
38
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the masked block
39
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the masked block
40
+
41
+ start_x = max(center_x - block_size_x // 2, 0)
42
+ end_x = min(center_x + block_size_x // 2, w)
43
+ start_y = max(center_y - block_size_y // 2, 0)
44
+ end_y = min(center_y + block_size_y // 2, h)
45
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
46
+ elif mask_index == 3:
47
+ center_x = torch.randint(0, w, (1,)).item()
48
+ center_y = torch.randint(0, h, (1,)).item()
49
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the masked block
50
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the masked block
51
+
52
+ start_x = max(center_x - block_size_x // 2, 0)
53
+ end_x = min(center_x + block_size_x // 2, w)
54
+ start_y = max(center_y - block_size_y // 2, 0)
55
+ end_y = min(center_y + block_size_y // 2, h)
56
+
57
+ mask_frame_before = np.random.randint(0, f // 2)
58
+ mask_frame_after = np.random.randint(f // 2, f)
59
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
60
+ else:
61
+ raise ValueError(f"The mask_index {mask_index} is not defined.")
62
+ return mask
63
+
64
+
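A quick usage sketch of `get_random_mask` (illustration only, not part of the commit): it draws one of the four mask modes above and always returns a binary `(f, 1, h, w)` tensor, so it broadcasts cleanly against a `(f, c, h, w)` clip. The clip shape below is an arbitrary assumption.

import torch

clip_shape = (16, 3, 64, 64)
mask = get_random_mask(clip_shape)

assert mask.shape == (16, 1, 64, 64)
assert set(mask.unique().tolist()) <= {0, 1}  # strictly binary

# Broadcasting against the clip zeroes out the masked region.
clip = torch.rand(clip_shape)
masked_clip = clip * (1 - mask)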
65
+ @contextmanager
66
+ def VideoReader_contextmanager(*args, **kwargs):
67
+ vr = VideoReader(*args, **kwargs)
68
+ try:
69
+ yield vr
70
+ finally:
71
+ del vr
72
+ gc.collect()
73
+
74
+
75
+ def get_video_reader_batch(video_reader, batch_index):
76
+ frames = video_reader.get_batch(batch_index).asnumpy()
77
+ return frames
78
+
79
+
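For reference, a minimal sketch (not part of the commit) of how the two helpers above are meant to be combined: frames are fetched inside the context manager, and the decord read is wrapped in `func_timeout` so a stuck video fails after `VIDEO_READER_TIMEOUT` seconds instead of hanging the dataloader. "example.mp4" is a placeholder path.

import numpy as np

with VideoReader_contextmanager("example.mp4", num_threads=2) as video_reader:
    batch_index = np.arange(0, min(len(video_reader), 16))
    try:
        frames = func_timeout(
            VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(video_reader, batch_index)
        )  # (num_frames, H, W, 3) uint8 array
    except FunctionTimedOut:
        frames = None  # the caller decides whether to retry or skip this sample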
80
+ class WebVid10M(Dataset):
81
+ def __init__(
82
+ self,
83
+ csv_path, video_folder,
84
+ sample_size=256, sample_stride=4, sample_n_frames=16,
85
+ enable_bucket=False, enable_inpaint=False, is_image=False,
86
+ ):
87
+ print(f"loading annotations from {csv_path} ...")
88
+ with open(csv_path, 'r') as csvfile:
89
+ self.dataset = list(csv.DictReader(csvfile))
90
+ self.length = len(self.dataset)
91
+ print(f"data scale: {self.length}")
92
+
93
+ self.video_folder = video_folder
94
+ self.sample_stride = sample_stride
95
+ self.sample_n_frames = sample_n_frames
96
+ self.enable_bucket = enable_bucket
97
+ self.enable_inpaint = enable_inpaint
98
+ self.is_image = is_image
99
+
100
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
101
+ self.pixel_transforms = transforms.Compose([
102
+ transforms.Resize(sample_size[0]),
103
+ transforms.CenterCrop(sample_size),
104
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
105
+ ])
106
+
107
+ def get_batch(self, idx):
108
+ video_dict = self.dataset[idx]
109
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
110
+
111
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
112
+ video_reader = VideoReader(video_dir)
113
+ video_length = len(video_reader)
114
+
115
+ if not self.is_image:
116
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
117
+ start_idx = random.randint(0, video_length - clip_length)
118
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
119
+ else:
120
+ batch_index = [random.randint(0, video_length - 1)]
121
+
122
+ if not self.enable_bucket:
123
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
124
+ pixel_values = pixel_values / 255.
125
+ del video_reader
126
+ else:
127
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
128
+
129
+ if self.is_image:
130
+ pixel_values = pixel_values[0]
131
+ return pixel_values, name
132
+
133
+ def __len__(self):
134
+ return self.length
135
+
136
+ def __getitem__(self, idx):
137
+ while True:
138
+ try:
139
+ pixel_values, name = self.get_batch(idx)
140
+ break
141
+
142
+ except Exception as e:
143
+ print("Error info:", e)
144
+ idx = random.randint(0, self.length-1)
145
+
146
+ if not self.enable_bucket:
147
+ pixel_values = self.pixel_transforms(pixel_values)
148
+ if self.enable_inpaint:
149
+ mask = get_random_mask(pixel_values.size())
150
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
151
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
152
+ else:
153
+ sample = dict(pixel_values=pixel_values, text=name)
154
+ return sample
155
+
156
+
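The frame-index arithmetic in `get_batch` is worth spelling out. The snippet below (illustration only) reproduces it for a hypothetical 100-frame video with the default `sample_stride=4` and `sample_n_frames=16`.

import random
import numpy as np

video_length, sample_n_frames, sample_stride = 100, 16, 4

# Longest window that still fits 16 frames at stride 4: (16 - 1) * 4 + 1 = 61 frames.
clip_length = min(video_length, (sample_n_frames - 1) * sample_stride + 1)

# Random start so the clip stays inside the video, then 16 evenly spaced indices.
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_n_frames, dtype=int)

assert len(batch_index) == sample_n_frames
assert batch_index[-1] < video_length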
157
+ class VideoDataset(Dataset):
158
+ def __init__(
159
+ self,
160
+ json_path, video_folder=None,
161
+ sample_size=256, sample_stride=4, sample_n_frames=16,
162
+ enable_bucket=False, enable_inpaint=False
163
+ ):
164
+ print(f"loading annotations from {json_path} ...")
165
+ self.dataset = json.load(open(json_path, 'r'))
166
+ self.length = len(self.dataset)
167
+ print(f"data scale: {self.length}")
168
+
169
+ self.video_folder = video_folder
170
+ self.sample_stride = sample_stride
171
+ self.sample_n_frames = sample_n_frames
172
+ self.enable_bucket = enable_bucket
173
+ self.enable_inpaint = enable_inpaint
174
+
175
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
176
+ self.pixel_transforms = transforms.Compose(
177
+ [
178
+ transforms.Resize(sample_size[0]),
179
+ transforms.CenterCrop(sample_size),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
181
+ ]
182
+ )
183
+
184
+ def get_batch(self, idx):
185
+ video_dict = self.dataset[idx]
186
+ video_id, name = video_dict['file_path'], video_dict['text']
187
+
188
+ if self.video_folder is None:
189
+ video_dir = video_id
190
+ else:
191
+ video_dir = os.path.join(self.video_folder, video_id)
192
+
193
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
194
+ video_length = len(video_reader)
195
+
196
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
197
+ start_idx = random.randint(0, video_length - clip_length)
198
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
199
+
200
+ try:
201
+ sample_args = (video_reader, batch_index)
202
+ pixel_values = func_timeout(
203
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
204
+ )
205
+ except FunctionTimedOut:
206
+ raise ValueError(f"Read {idx} timeout.")
207
+ except Exception as e:
208
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
209
+
210
+ if not self.enable_bucket:
211
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
212
+ pixel_values = pixel_values / 255.
213
+ del video_reader
214
+ else:
215
+ pixel_values = pixel_values
216
+
217
+ return pixel_values, name
218
+
219
+ def __len__(self):
220
+ return self.length
221
+
222
+ def __getitem__(self, idx):
223
+ while True:
224
+ try:
225
+ pixel_values, name = self.get_batch(idx)
226
+ break
227
+
228
+ except Exception as e:
229
+ print("Error info:", e)
230
+ idx = random.randint(0, self.length-1)
231
+
232
+ if not self.enable_bucket:
233
+ pixel_values = self.pixel_transforms(pixel_values)
234
+ if self.enable_inpaint:
235
+ mask = get_random_mask(pixel_values.size())
236
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
237
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
238
+ else:
239
+ sample = dict(pixel_values=pixel_values, text=name)
240
+ return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/models/autoencoder_magvit.py ADDED
@@ -0,0 +1,1296 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
25
+ from diffusers.utils import logging
26
+ from diffusers.utils.accelerate_utils import apply_forward_hook
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.downsampling import CogVideoXDownsample3D
29
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.models.upsampling import CogVideoXUpsample3D
32
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class CogVideoXSafeConv3d(nn.Conv3d):
39
+ r"""
40
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
41
+ """
42
+
43
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
44
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
45
+
46
+ # Set to 2GB, suitable for CuDNN
47
+ if memory_count > 2:
48
+ kernel_size = self.kernel_size[0]
49
+ part_num = int(memory_count / 2) + 1
50
+ input_chunks = torch.chunk(input, part_num, dim=2)
51
+
52
+ if kernel_size > 1:
53
+ input_chunks = [input_chunks[0]] + [
54
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
55
+ for i in range(1, len(input_chunks))
56
+ ]
57
+
58
+ output_chunks = []
59
+ for input_chunk in input_chunks:
60
+ output_chunks.append(super().forward(input_chunk))
61
+ output = torch.cat(output_chunks, dim=2)
62
+ return output
63
+ else:
64
+ return super().forward(input)
65
+
66
+
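To make the 2 GB threshold in `CogVideoXSafeConv3d.forward` concrete, here is the same memory estimate written out for a single hypothetical activation shape (illustration only; the shape is an assumption).

import torch

# Hypothetical half-precision activation: (batch, channels, frames, height, width).
shape = (1, 128, 49, 480, 720)

# Same estimate as the forward pass above: ~2 bytes per element, expressed in GiB.
memory_count = torch.prod(torch.tensor(shape)).item() * 2 / 1024**3  # ~4.04 GiB

if memory_count > 2:
    part_num = int(memory_count / 2) + 1  # number of temporal chunks to convolve separately
    print(f"{memory_count:.2f} GiB -> split into {part_num} chunks along dim=2")  # 3 chunks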
67
+ class CogVideoXCausalConv3d(nn.Module):
68
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
69
+
70
+ Args:
71
+ in_channels (`int`): Number of channels in the input tensor.
72
+ out_channels (`int`): Number of output channels produced by the convolution.
73
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
74
+ stride (`int`, defaults to `1`): Stride of the convolution.
75
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
76
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ in_channels: int,
82
+ out_channels: int,
83
+ kernel_size: Union[int, Tuple[int, int, int]],
84
+ stride: int = 1,
85
+ dilation: int = 1,
86
+ pad_mode: str = "constant",
87
+ ):
88
+ super().__init__()
89
+
90
+ if isinstance(kernel_size, int):
91
+ kernel_size = (kernel_size,) * 3
92
+
93
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
94
+
95
+ self.pad_mode = pad_mode
96
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
97
+ height_pad = height_kernel_size // 2
98
+ width_pad = width_kernel_size // 2
99
+
100
+ self.height_pad = height_pad
101
+ self.width_pad = width_pad
102
+ self.time_pad = time_pad
103
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
104
+
105
+ self.temporal_dim = 2
106
+ self.time_kernel_size = time_kernel_size
107
+
108
+ stride = (stride, 1, 1)
109
+ dilation = (dilation, 1, 1)
110
+ self.conv = CogVideoXSafeConv3d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=kernel_size,
114
+ stride=stride,
115
+ dilation=dilation,
116
+ )
117
+
118
+ self.conv_cache = None
119
+
120
+ def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
121
+ kernel_size = self.time_kernel_size
122
+ if kernel_size > 1:
123
+ cached_inputs = (
124
+ [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
125
+ )
126
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
127
+ return inputs
128
+
129
+ def _clear_fake_context_parallel_cache(self):
130
+ del self.conv_cache
131
+ self.conv_cache = None
132
+
133
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
134
+ inputs = self.fake_context_parallel_forward(inputs)
135
+
136
+ self._clear_fake_context_parallel_cache()
137
+ # Note: we could move these to the CPU for a lower maximum memory usage, but it's only a few
138
+ # hundred megabytes and so let's not do it for now
139
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
140
+
141
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
142
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
143
+
144
+ output = self.conv(inputs)
145
+ return output
146
+
147
+
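A short worked example of the causal padding above (not part of the commit): with the defaults `stride=1` and `dilation=1` and a kernel size of 3, the layer pads `kernel_size - 1 = 2` frames on the temporal front only, so the output at frame t never depends on frames later than t.

time_kernel_size, stride, dilation = 3, 1, 1

time_pad = dilation * (time_kernel_size - 1) + (1 - stride)  # = 2, front-only temporal pad
height_pad = 3 // 2                                          # = 1, symmetric spatial pad
width_pad = 3 // 2                                           # = 1

# F.pad ordering is (w_left, w_right, h_top, h_bottom, t_front, t_back): nothing is padded at the back.
time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
print(time_causal_padding)  # (1, 1, 1, 1, 2, 0)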
148
+ class CogVideoXSpatialNorm3D(nn.Module):
149
+ r"""
150
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
151
+ to 3D-video like data.
152
+
153
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
154
+
155
+ Args:
156
+ f_channels (`int`):
157
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
158
+ zq_channels (`int`):
159
+ The number of channels for the quantized vector as described in the paper.
160
+ groups (`int`):
161
+ Number of groups to separate the channels into for group normalization.
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ f_channels: int,
167
+ zq_channels: int,
168
+ groups: int = 32,
169
+ ):
170
+ super().__init__()
171
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
172
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
173
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
174
+
175
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
176
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
177
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
178
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
179
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
180
+ z_first = F.interpolate(z_first, size=f_first_size)
181
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
182
+ zq = torch.cat([z_first, z_rest], dim=2)
183
+ else:
184
+ zq = F.interpolate(zq, size=f.shape[-3:])
185
+
186
+ norm_f = self.norm_layer(f)
187
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
188
+ return new_f
189
+
190
+
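The layer above follows the MoVQ-style rule `new_f = GroupNorm(f) * conv_y(zq) + conv_b(zq)`, with `zq` interpolated to the feature map's size first. A minimal shape-level sketch (illustration only; the small channel counts are assumptions chosen for speed):

import torch

norm = CogVideoXSpatialNorm3D(f_channels=32, zq_channels=16, groups=32)

f = torch.randn(1, 32, 1, 8, 8)   # decoder feature map
zq = torch.randn(1, 16, 1, 4, 4)  # latent code, spatially smaller than f

# zq is resized to f's size, then modulates the group-normalized features.
new_f = norm(f, zq)
assert new_f.shape == f.shape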
191
+ class CogVideoXResnetBlock3D(nn.Module):
192
+ r"""
193
+ A 3D ResNet block used in the CogVideoX model.
194
+
195
+ Args:
196
+ in_channels (`int`):
197
+ Number of input channels.
198
+ out_channels (`int`, *optional*):
199
+ Number of output channels. If None, defaults to `in_channels`.
200
+ dropout (`float`, defaults to `0.0`):
201
+ Dropout rate.
202
+ temb_channels (`int`, defaults to `512`):
203
+ Number of time embedding channels.
204
+ groups (`int`, defaults to `32`):
205
+ Number of groups to separate the channels into for group normalization.
206
+ eps (`float`, defaults to `1e-6`):
207
+ Epsilon value for normalization layers.
208
+ non_linearity (`str`, defaults to `"swish"`):
209
+ Activation function to use.
210
+ conv_shortcut (bool, defaults to `False`):
211
+ Whether or not to use a convolution shortcut.
212
+ spatial_norm_dim (`int`, *optional*):
213
+ The dimension to use for spatial norm if it is to be used instead of group norm.
214
+ pad_mode (str, defaults to `"first"`):
215
+ Padding mode.
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ in_channels: int,
221
+ out_channels: Optional[int] = None,
222
+ dropout: float = 0.0,
223
+ temb_channels: int = 512,
224
+ groups: int = 32,
225
+ eps: float = 1e-6,
226
+ non_linearity: str = "swish",
227
+ conv_shortcut: bool = False,
228
+ spatial_norm_dim: Optional[int] = None,
229
+ pad_mode: str = "first",
230
+ ):
231
+ super().__init__()
232
+
233
+ out_channels = out_channels or in_channels
234
+
235
+ self.in_channels = in_channels
236
+ self.out_channels = out_channels
237
+ self.nonlinearity = get_activation(non_linearity)
238
+ self.use_conv_shortcut = conv_shortcut
239
+
240
+ if spatial_norm_dim is None:
241
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
242
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
243
+ else:
244
+ self.norm1 = CogVideoXSpatialNorm3D(
245
+ f_channels=in_channels,
246
+ zq_channels=spatial_norm_dim,
247
+ groups=groups,
248
+ )
249
+ self.norm2 = CogVideoXSpatialNorm3D(
250
+ f_channels=out_channels,
251
+ zq_channels=spatial_norm_dim,
252
+ groups=groups,
253
+ )
254
+
255
+ self.conv1 = CogVideoXCausalConv3d(
256
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
257
+ )
258
+
259
+ if temb_channels > 0:
260
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
261
+
262
+ self.dropout = nn.Dropout(dropout)
263
+ self.conv2 = CogVideoXCausalConv3d(
264
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
265
+ )
266
+
267
+ if self.in_channels != self.out_channels:
268
+ if self.use_conv_shortcut:
269
+ self.conv_shortcut = CogVideoXCausalConv3d(
270
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
271
+ )
272
+ else:
273
+ self.conv_shortcut = CogVideoXSafeConv3d(
274
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
275
+ )
276
+
277
+ def forward(
278
+ self,
279
+ inputs: torch.Tensor,
280
+ temb: Optional[torch.Tensor] = None,
281
+ zq: Optional[torch.Tensor] = None,
282
+ ) -> torch.Tensor:
283
+ hidden_states = inputs
284
+
285
+ if zq is not None:
286
+ hidden_states = self.norm1(hidden_states, zq)
287
+ else:
288
+ hidden_states = self.norm1(hidden_states)
289
+
290
+ hidden_states = self.nonlinearity(hidden_states)
291
+ hidden_states = self.conv1(hidden_states)
292
+
293
+ if temb is not None:
294
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
295
+
296
+ if zq is not None:
297
+ hidden_states = self.norm2(hidden_states, zq)
298
+ else:
299
+ hidden_states = self.norm2(hidden_states)
300
+
301
+ hidden_states = self.nonlinearity(hidden_states)
302
+ hidden_states = self.dropout(hidden_states)
303
+ hidden_states = self.conv2(hidden_states)
304
+
305
+ if self.in_channels != self.out_channels:
306
+ inputs = self.conv_shortcut(inputs)
307
+
308
+ hidden_states = hidden_states + inputs
309
+ return hidden_states
310
+
311
+
312
+ class CogVideoXDownBlock3D(nn.Module):
313
+ r"""
314
+ A downsampling block used in the CogVideoX model.
315
+
316
+ Args:
317
+ in_channels (`int`):
318
+ Number of input channels.
319
+ out_channels (`int`, *optional*):
320
+ Number of output channels. If None, defaults to `in_channels`.
321
+ temb_channels (`int`, defaults to `512`):
322
+ Number of time embedding channels.
323
+ num_layers (`int`, defaults to `1`):
324
+ Number of resnet layers.
325
+ dropout (`float`, defaults to `0.0`):
326
+ Dropout rate.
327
+ resnet_eps (`float`, defaults to `1e-6`):
328
+ Epsilon value for normalization layers.
329
+ resnet_act_fn (`str`, defaults to `"swish"`):
330
+ Activation function to use.
331
+ resnet_groups (`int`, defaults to `32`):
332
+ Number of groups to separate the channels into for group normalization.
333
+ add_downsample (`bool`, defaults to `True`):
334
+ Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
335
+ compress_time (`bool`, defaults to `False`):
336
+ Whether or not to downsample across the temporal dimension.
337
+ pad_mode (str, defaults to `"first"`):
338
+ Padding mode.
339
+ """
340
+
341
+ _supports_gradient_checkpointing = True
342
+
343
+ def __init__(
344
+ self,
345
+ in_channels: int,
346
+ out_channels: int,
347
+ temb_channels: int,
348
+ dropout: float = 0.0,
349
+ num_layers: int = 1,
350
+ resnet_eps: float = 1e-6,
351
+ resnet_act_fn: str = "swish",
352
+ resnet_groups: int = 32,
353
+ add_downsample: bool = True,
354
+ downsample_padding: int = 0,
355
+ compress_time: bool = False,
356
+ pad_mode: str = "first",
357
+ ):
358
+ super().__init__()
359
+
360
+ resnets = []
361
+ for i in range(num_layers):
362
+ in_channel = in_channels if i == 0 else out_channels
363
+ resnets.append(
364
+ CogVideoXResnetBlock3D(
365
+ in_channels=in_channel,
366
+ out_channels=out_channels,
367
+ dropout=dropout,
368
+ temb_channels=temb_channels,
369
+ groups=resnet_groups,
370
+ eps=resnet_eps,
371
+ non_linearity=resnet_act_fn,
372
+ pad_mode=pad_mode,
373
+ )
374
+ )
375
+
376
+ self.resnets = nn.ModuleList(resnets)
377
+ self.downsamplers = None
378
+
379
+ if add_downsample:
380
+ self.downsamplers = nn.ModuleList(
381
+ [
382
+ CogVideoXDownsample3D(
383
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
384
+ )
385
+ ]
386
+ )
387
+
388
+ self.gradient_checkpointing = False
389
+
390
+ def forward(
391
+ self,
392
+ hidden_states: torch.Tensor,
393
+ temb: Optional[torch.Tensor] = None,
394
+ zq: Optional[torch.Tensor] = None,
395
+ ) -> torch.Tensor:
396
+ for resnet in self.resnets:
397
+ if self.training and self.gradient_checkpointing:
398
+
399
+ def create_custom_forward(module):
400
+ def create_forward(*inputs):
401
+ return module(*inputs)
402
+
403
+ return create_forward
404
+
405
+ hidden_states = torch.utils.checkpoint.checkpoint(
406
+ create_custom_forward(resnet), hidden_states, temb, zq
407
+ )
408
+ else:
409
+ hidden_states = resnet(hidden_states, temb, zq)
410
+
411
+ if self.downsamplers is not None:
412
+ for downsampler in self.downsamplers:
413
+ hidden_states = downsampler(hidden_states)
414
+
415
+ return hidden_states
416
+
417
+
418
+ class CogVideoXMidBlock3D(nn.Module):
419
+ r"""
420
+ A middle block used in the CogVideoX model.
421
+
422
+ Args:
423
+ in_channels (`int`):
424
+ Number of input channels.
425
+ temb_channels (`int`, defaults to `512`):
426
+ Number of time embedding channels.
427
+ dropout (`float`, defaults to `0.0`):
428
+ Dropout rate.
429
+ num_layers (`int`, defaults to `1`):
430
+ Number of resnet layers.
431
+ resnet_eps (`float`, defaults to `1e-6`):
432
+ Epsilon value for normalization layers.
433
+ resnet_act_fn (`str`, defaults to `"swish"`):
434
+ Activation function to use.
435
+ resnet_groups (`int`, defaults to `32`):
436
+ Number of groups to separate the channels into for group normalization.
437
+ spatial_norm_dim (`int`, *optional*):
438
+ The dimension to use for spatial norm if it is to be used instead of group norm.
439
+ pad_mode (str, defaults to `"first"`):
440
+ Padding mode.
441
+ """
442
+
443
+ _supports_gradient_checkpointing = True
444
+
445
+ def __init__(
446
+ self,
447
+ in_channels: int,
448
+ temb_channels: int,
449
+ dropout: float = 0.0,
450
+ num_layers: int = 1,
451
+ resnet_eps: float = 1e-6,
452
+ resnet_act_fn: str = "swish",
453
+ resnet_groups: int = 32,
454
+ spatial_norm_dim: Optional[int] = None,
455
+ pad_mode: str = "first",
456
+ ):
457
+ super().__init__()
458
+
459
+ resnets = []
460
+ for _ in range(num_layers):
461
+ resnets.append(
462
+ CogVideoXResnetBlock3D(
463
+ in_channels=in_channels,
464
+ out_channels=in_channels,
465
+ dropout=dropout,
466
+ temb_channels=temb_channels,
467
+ groups=resnet_groups,
468
+ eps=resnet_eps,
469
+ spatial_norm_dim=spatial_norm_dim,
470
+ non_linearity=resnet_act_fn,
471
+ pad_mode=pad_mode,
472
+ )
473
+ )
474
+ self.resnets = nn.ModuleList(resnets)
475
+
476
+ self.gradient_checkpointing = False
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ temb: Optional[torch.Tensor] = None,
482
+ zq: Optional[torch.Tensor] = None,
483
+ ) -> torch.Tensor:
484
+ for resnet in self.resnets:
485
+ if self.training and self.gradient_checkpointing:
486
+
487
+ def create_custom_forward(module):
488
+ def create_forward(*inputs):
489
+ return module(*inputs)
490
+
491
+ return create_forward
492
+
493
+ hidden_states = torch.utils.checkpoint.checkpoint(
494
+ create_custom_forward(resnet), hidden_states, temb, zq
495
+ )
496
+ else:
497
+ hidden_states = resnet(hidden_states, temb, zq)
498
+
499
+ return hidden_states
500
+
501
+
502
+ class CogVideoXUpBlock3D(nn.Module):
503
+ r"""
504
+ An upsampling block used in the CogVideoX model.
505
+
506
+ Args:
507
+ in_channels (`int`):
508
+ Number of input channels.
509
+ out_channels (`int`, *optional*):
510
+ Number of output channels. If None, defaults to `in_channels`.
511
+ temb_channels (`int`, defaults to `512`):
512
+ Number of time embedding channels.
513
+ dropout (`float`, defaults to `0.0`):
514
+ Dropout rate.
515
+ num_layers (`int`, defaults to `1`):
516
+ Number of resnet layers.
517
+ resnet_eps (`float`, defaults to `1e-6`):
518
+ Epsilon value for normalization layers.
519
+ resnet_act_fn (`str`, defaults to `"swish"`):
520
+ Activation function to use.
521
+ resnet_groups (`int`, defaults to `32`):
522
+ Number of groups to separate the channels into for group normalization.
523
+ spatial_norm_dim (`int`, defaults to `16`):
524
+ The dimension to use for spatial norm if it is to be used instead of group norm.
525
+ add_upsample (`bool`, defaults to `True`):
526
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
527
+ compress_time (`bool`, defaults to `False`):
528
+ Whether or not to upsample across the temporal dimension.
529
+ pad_mode (str, defaults to `"first"`):
530
+ Padding mode.
531
+ """
532
+
533
+ def __init__(
534
+ self,
535
+ in_channels: int,
536
+ out_channels: int,
537
+ temb_channels: int,
538
+ dropout: float = 0.0,
539
+ num_layers: int = 1,
540
+ resnet_eps: float = 1e-6,
541
+ resnet_act_fn: str = "swish",
542
+ resnet_groups: int = 32,
543
+ spatial_norm_dim: int = 16,
544
+ add_upsample: bool = True,
545
+ upsample_padding: int = 1,
546
+ compress_time: bool = False,
547
+ pad_mode: str = "first",
548
+ ):
549
+ super().__init__()
550
+
551
+ resnets = []
552
+ for i in range(num_layers):
553
+ in_channel = in_channels if i == 0 else out_channels
554
+ resnets.append(
555
+ CogVideoXResnetBlock3D(
556
+ in_channels=in_channel,
557
+ out_channels=out_channels,
558
+ dropout=dropout,
559
+ temb_channels=temb_channels,
560
+ groups=resnet_groups,
561
+ eps=resnet_eps,
562
+ non_linearity=resnet_act_fn,
563
+ spatial_norm_dim=spatial_norm_dim,
564
+ pad_mode=pad_mode,
565
+ )
566
+ )
567
+
568
+ self.resnets = nn.ModuleList(resnets)
569
+ self.upsamplers = None
570
+
571
+ if add_upsample:
572
+ self.upsamplers = nn.ModuleList(
573
+ [
574
+ CogVideoXUpsample3D(
575
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
576
+ )
577
+ ]
578
+ )
579
+
580
+ self.gradient_checkpointing = False
581
+
582
+ def forward(
583
+ self,
584
+ hidden_states: torch.Tensor,
585
+ temb: Optional[torch.Tensor] = None,
586
+ zq: Optional[torch.Tensor] = None,
587
+ ) -> torch.Tensor:
588
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
589
+ for resnet in self.resnets:
590
+ if self.training and self.gradient_checkpointing:
591
+
592
+ def create_custom_forward(module):
593
+ def create_forward(*inputs):
594
+ return module(*inputs)
595
+
596
+ return create_forward
597
+
598
+ hidden_states = torch.utils.checkpoint.checkpoint(
599
+ create_custom_forward(resnet), hidden_states, temb, zq
600
+ )
601
+ else:
602
+ hidden_states = resnet(hidden_states, temb, zq)
603
+
604
+ if self.upsamplers is not None:
605
+ for upsampler in self.upsamplers:
606
+ hidden_states = upsampler(hidden_states)
607
+
608
+ return hidden_states
609
+
610
+
611
+ class CogVideoXEncoder3D(nn.Module):
612
+ r"""
613
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
614
+
615
+ Args:
616
+ in_channels (`int`, *optional*, defaults to 3):
617
+ The number of input channels.
618
+ out_channels (`int`, *optional*, defaults to 16):
619
+ The number of output channels.
620
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` blocks):
621
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
622
+ options.
623
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
624
+ The number of output channels for each block.
625
+ act_fn (`str`, *optional*, defaults to `"silu"`):
626
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
627
+ layers_per_block (`int`, *optional*, defaults to 3):
628
+ The number of layers per block.
629
+ norm_num_groups (`int`, *optional*, defaults to 32):
630
+ The number of groups for normalization.
631
+ """
632
+
633
+ _supports_gradient_checkpointing = True
634
+
635
+ def __init__(
636
+ self,
637
+ in_channels: int = 3,
638
+ out_channels: int = 16,
639
+ down_block_types: Tuple[str, ...] = (
640
+ "CogVideoXDownBlock3D",
641
+ "CogVideoXDownBlock3D",
642
+ "CogVideoXDownBlock3D",
643
+ "CogVideoXDownBlock3D",
644
+ ),
645
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
646
+ layers_per_block: int = 3,
647
+ act_fn: str = "silu",
648
+ norm_eps: float = 1e-6,
649
+ norm_num_groups: int = 32,
650
+ dropout: float = 0.0,
651
+ pad_mode: str = "first",
652
+ temporal_compression_ratio: float = 4,
653
+ ):
654
+ super().__init__()
655
+
656
+ # log2 of temporal_compress_times
657
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
658
+
659
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
660
+ self.down_blocks = nn.ModuleList([])
661
+
662
+ # down blocks
663
+ output_channel = block_out_channels[0]
664
+ for i, down_block_type in enumerate(down_block_types):
665
+ input_channel = output_channel
666
+ output_channel = block_out_channels[i]
667
+ is_final_block = i == len(block_out_channels) - 1
668
+ compress_time = i < temporal_compress_level
669
+
670
+ if down_block_type == "CogVideoXDownBlock3D":
671
+ down_block = CogVideoXDownBlock3D(
672
+ in_channels=input_channel,
673
+ out_channels=output_channel,
674
+ temb_channels=0,
675
+ dropout=dropout,
676
+ num_layers=layers_per_block,
677
+ resnet_eps=norm_eps,
678
+ resnet_act_fn=act_fn,
679
+ resnet_groups=norm_num_groups,
680
+ add_downsample=not is_final_block,
681
+ compress_time=compress_time,
682
+ )
683
+ else:
684
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
685
+
686
+ self.down_blocks.append(down_block)
687
+
688
+ # mid block
689
+ self.mid_block = CogVideoXMidBlock3D(
690
+ in_channels=block_out_channels[-1],
691
+ temb_channels=0,
692
+ dropout=dropout,
693
+ num_layers=2,
694
+ resnet_eps=norm_eps,
695
+ resnet_act_fn=act_fn,
696
+ resnet_groups=norm_num_groups,
697
+ pad_mode=pad_mode,
698
+ )
699
+
700
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
701
+ self.conv_act = nn.SiLU()
702
+ self.conv_out = CogVideoXCausalConv3d(
703
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
704
+ )
705
+
706
+ self.gradient_checkpointing = False
707
+
708
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
709
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
710
+ hidden_states = self.conv_in(sample)
711
+
712
+ if self.training and self.gradient_checkpointing:
713
+
714
+ def create_custom_forward(module):
715
+ def custom_forward(*inputs):
716
+ return module(*inputs)
717
+
718
+ return custom_forward
719
+
720
+ # 1. Down
721
+ for down_block in self.down_blocks:
722
+ hidden_states = torch.utils.checkpoint.checkpoint(
723
+ create_custom_forward(down_block), hidden_states, temb, None
724
+ )
725
+
726
+ # 2. Mid
727
+ hidden_states = torch.utils.checkpoint.checkpoint(
728
+ create_custom_forward(self.mid_block), hidden_states, temb, None
729
+ )
730
+ else:
731
+ # 1. Down
732
+ for down_block in self.down_blocks:
733
+ hidden_states = down_block(hidden_states, temb, None)
734
+
735
+ # 2. Mid
736
+ hidden_states = self.mid_block(hidden_states, temb, None)
737
+
738
+ # 3. Post-process
739
+ hidden_states = self.norm_out(hidden_states)
740
+ hidden_states = self.conv_act(hidden_states)
741
+ hidden_states = self.conv_out(hidden_states)
742
+ return hidden_states
743
+
744
+
745
+ class CogVideoXDecoder3D(nn.Module):
746
+ r"""
747
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
748
+ sample.
749
+
750
+ Args:
751
+ in_channels (`int`, *optional*, defaults to 16):
752
+ The number of input channels.
753
+ out_channels (`int`, *optional*, defaults to 3):
754
+ The number of output channels.
755
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` blocks):
756
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
757
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
758
+ The number of output channels for each block.
759
+ act_fn (`str`, *optional*, defaults to `"silu"`):
760
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
761
+ layers_per_block (`int`, *optional*, defaults to 3):
762
+ The number of layers per block.
763
+ norm_num_groups (`int`, *optional*, defaults to 32):
764
+ The number of groups for normalization.
765
+ """
766
+
767
+ _supports_gradient_checkpointing = True
768
+
769
+ def __init__(
770
+ self,
771
+ in_channels: int = 16,
772
+ out_channels: int = 3,
773
+ up_block_types: Tuple[str, ...] = (
774
+ "CogVideoXUpBlock3D",
775
+ "CogVideoXUpBlock3D",
776
+ "CogVideoXUpBlock3D",
777
+ "CogVideoXUpBlock3D",
778
+ ),
779
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
780
+ layers_per_block: int = 3,
781
+ act_fn: str = "silu",
782
+ norm_eps: float = 1e-6,
783
+ norm_num_groups: int = 32,
784
+ dropout: float = 0.0,
785
+ pad_mode: str = "first",
786
+ temporal_compression_ratio: float = 4,
787
+ ):
788
+ super().__init__()
789
+
790
+ reversed_block_out_channels = list(reversed(block_out_channels))
791
+
792
+ self.conv_in = CogVideoXCausalConv3d(
793
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
794
+ )
795
+
796
+ # mid block
797
+ self.mid_block = CogVideoXMidBlock3D(
798
+ in_channels=reversed_block_out_channels[0],
799
+ temb_channels=0,
800
+ num_layers=2,
801
+ resnet_eps=norm_eps,
802
+ resnet_act_fn=act_fn,
803
+ resnet_groups=norm_num_groups,
804
+ spatial_norm_dim=in_channels,
805
+ pad_mode=pad_mode,
806
+ )
807
+
808
+ # up blocks
809
+ self.up_blocks = nn.ModuleList([])
810
+
811
+ output_channel = reversed_block_out_channels[0]
812
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
813
+
814
+ for i, up_block_type in enumerate(up_block_types):
815
+ prev_output_channel = output_channel
816
+ output_channel = reversed_block_out_channels[i]
817
+ is_final_block = i == len(block_out_channels) - 1
818
+ compress_time = i < temporal_compress_level
819
+
820
+ if up_block_type == "CogVideoXUpBlock3D":
821
+ up_block = CogVideoXUpBlock3D(
822
+ in_channels=prev_output_channel,
823
+ out_channels=output_channel,
824
+ temb_channels=0,
825
+ dropout=dropout,
826
+ num_layers=layers_per_block + 1,
827
+ resnet_eps=norm_eps,
828
+ resnet_act_fn=act_fn,
829
+ resnet_groups=norm_num_groups,
830
+ spatial_norm_dim=in_channels,
831
+ add_upsample=not is_final_block,
832
+ compress_time=compress_time,
833
+ pad_mode=pad_mode,
834
+ )
835
+ prev_output_channel = output_channel
836
+ else:
837
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
838
+
839
+ self.up_blocks.append(up_block)
840
+
841
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
842
+ self.conv_act = nn.SiLU()
843
+ self.conv_out = CogVideoXCausalConv3d(
844
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
845
+ )
846
+
847
+ self.gradient_checkpointing = False
848
+
849
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
850
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
851
+ hidden_states = self.conv_in(sample)
852
+
853
+ if self.training and self.gradient_checkpointing:
854
+
855
+ def create_custom_forward(module):
856
+ def custom_forward(*inputs):
857
+ return module(*inputs)
858
+
859
+ return custom_forward
860
+
861
+ # 1. Mid
862
+ hidden_states = torch.utils.checkpoint.checkpoint(
863
+ create_custom_forward(self.mid_block), hidden_states, temb, sample
864
+ )
865
+
866
+ # 2. Up
867
+ for up_block in self.up_blocks:
868
+ hidden_states = torch.utils.checkpoint.checkpoint(
869
+ create_custom_forward(up_block), hidden_states, temb, sample
870
+ )
871
+ else:
872
+ # 1. Mid
873
+ hidden_states = self.mid_block(hidden_states, temb, sample)
874
+
875
+ # 2. Up
876
+ for up_block in self.up_blocks:
877
+ hidden_states = up_block(hidden_states, temb, sample)
878
+
879
+ # 3. Post-process
880
+ hidden_states = self.norm_out(hidden_states, sample)
881
+ hidden_states = self.conv_act(hidden_states)
882
+ hidden_states = self.conv_out(hidden_states)
883
+ return hidden_states
884
+
885
+
886
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
887
+ r"""
888
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
889
+ [CogVideoX](https://github.com/THUDM/CogVideo).
890
+
891
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
892
+ for all models (such as downloading or saving).
893
+
894
+ Parameters:
895
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
896
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
897
+ down_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` blocks):
898
+ Tuple of downsample block types.
899
+ up_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` blocks):
900
+ Tuple of upsample block types.
901
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
902
+ Tuple of block output channels.
903
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
904
+ sample_height (`int`, *optional*, defaults to `480`) and sample_width (`int`, *optional*, defaults to `720`): Sample input size.
905
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
906
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
907
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
908
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
909
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
910
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
911
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
912
+ force_upcast (`bool`, *optional*, defaults to `True`):
913
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
914
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
915
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
916
+ """
917
+
918
+ _supports_gradient_checkpointing = True
919
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
920
+
921
+ @register_to_config
922
+ def __init__(
923
+ self,
924
+ in_channels: int = 3,
925
+ out_channels: int = 3,
926
+ down_block_types: Tuple[str] = (
927
+ "CogVideoXDownBlock3D",
928
+ "CogVideoXDownBlock3D",
929
+ "CogVideoXDownBlock3D",
930
+ "CogVideoXDownBlock3D",
931
+ ),
932
+ up_block_types: Tuple[str] = (
933
+ "CogVideoXUpBlock3D",
934
+ "CogVideoXUpBlock3D",
935
+ "CogVideoXUpBlock3D",
936
+ "CogVideoXUpBlock3D",
937
+ ),
938
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
939
+ latent_channels: int = 16,
940
+ layers_per_block: int = 3,
941
+ act_fn: str = "silu",
942
+ norm_eps: float = 1e-6,
943
+ norm_num_groups: int = 32,
944
+ temporal_compression_ratio: float = 4,
945
+ sample_height: int = 480,
946
+ sample_width: int = 720,
947
+ scaling_factor: float = 1.15258426,
948
+ shift_factor: Optional[float] = None,
949
+ latents_mean: Optional[Tuple[float]] = None,
950
+ latents_std: Optional[Tuple[float]] = None,
951
+ force_upcast: bool = True,
952
+ use_quant_conv: bool = False,
953
+ use_post_quant_conv: bool = False,
954
+ ):
955
+ super().__init__()
956
+
957
+ self.encoder = CogVideoXEncoder3D(
958
+ in_channels=in_channels,
959
+ out_channels=latent_channels,
960
+ down_block_types=down_block_types,
961
+ block_out_channels=block_out_channels,
962
+ layers_per_block=layers_per_block,
963
+ act_fn=act_fn,
964
+ norm_eps=norm_eps,
965
+ norm_num_groups=norm_num_groups,
966
+ temporal_compression_ratio=temporal_compression_ratio,
967
+ )
968
+ self.decoder = CogVideoXDecoder3D(
969
+ in_channels=latent_channels,
970
+ out_channels=out_channels,
971
+ up_block_types=up_block_types,
972
+ block_out_channels=block_out_channels,
973
+ layers_per_block=layers_per_block,
974
+ act_fn=act_fn,
975
+ norm_eps=norm_eps,
976
+ norm_num_groups=norm_num_groups,
977
+ temporal_compression_ratio=temporal_compression_ratio,
978
+ )
979
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
980
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
981
+
982
+ self.use_slicing = False
983
+ self.use_tiling = False
984
+
985
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
986
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
987
+ # If you decode X latent frames together, the number of output frames is:
988
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
989
+ #
990
+ # Example with num_latent_frames_batch_size = 2:
991
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
992
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
993
+ # => 6 * 8 = 48 frames
994
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
995
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
996
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
997
+ # => 1 * 9 + 5 * 8 = 49 frames
998
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
999
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1000
+ # number of temporal frames.
1001
+ self.num_latent_frames_batch_size = 2
1002
+
1003
+ # We make the minimum tile height and width half of the generally supported sample resolution
1004
+ self.tile_sample_min_height = sample_height // 2
1005
+ self.tile_sample_min_width = sample_width // 2
1006
+ self.tile_latent_min_height = int(
1007
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1008
+ )
1009
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1010
+
1011
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1012
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1013
+ # and so the tiling implementation has only been tested on those specific resolutions.
1014
+ self.tile_overlap_factor_height = 1 / 6
1015
+ self.tile_overlap_factor_width = 1 / 5
1016
+
1017
+ def _set_gradient_checkpointing(self, module, value=False):
1018
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1019
+ module.gradient_checkpointing = value
1020
+
1021
+ def _clear_fake_context_parallel_cache(self):
1022
+ for name, module in self.named_modules():
1023
+ if isinstance(module, CogVideoXCausalConv3d):
1024
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
1025
+ module._clear_fake_context_parallel_cache()
1026
+
1027
+ def enable_tiling(
1028
+ self,
1029
+ tile_sample_min_height: Optional[int] = None,
1030
+ tile_sample_min_width: Optional[int] = None,
1031
+ tile_overlap_factor_height: Optional[float] = None,
1032
+ tile_overlap_factor_width: Optional[float] = None,
1033
+ ) -> None:
1034
+ r"""
1035
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1036
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1037
+ processing larger images.
1038
+
1039
+ Args:
1040
+ tile_sample_min_height (`int`, *optional*):
1041
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1042
+ tile_sample_min_width (`int`, *optional*):
1043
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1044
+ tile_overlap_factor_height (`int`, *optional*):
1045
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1046
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1047
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1048
+ tile_overlap_factor_width (`int`, *optional*):
1049
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1050
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1051
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1052
+ """
1053
+ self.use_tiling = True
1054
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1055
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1056
+ self.tile_latent_min_height = int(
1057
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1058
+ )
1059
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1060
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1061
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1062
+
1063
+ def disable_tiling(self) -> None:
1064
+ r"""
1065
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1066
+ decoding in one step.
1067
+ """
1068
+ self.use_tiling = False
1069
+
1070
+ def enable_slicing(self) -> None:
1071
+ r"""
1072
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1073
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1074
+ """
1075
+ self.use_slicing = True
1076
+
1077
+ def disable_slicing(self) -> None:
1078
+ r"""
1079
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1080
+ decoding in one step.
1081
+ """
1082
+ self.use_slicing = False
1083
+
1084
+ @apply_forward_hook
1085
+ def encode(
1086
+ self, x: torch.Tensor, return_dict: bool = True
1087
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1088
+ """
1089
+ Encode a batch of images into latents.
1090
+
1091
+ Args:
1092
+ x (`torch.Tensor`): Input batch of images.
1093
+ return_dict (`bool`, *optional*, defaults to `True`):
1094
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1095
+
1096
+ Returns:
1097
+ The latent representations of the encoded images. If `return_dict` is True, a
1098
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1099
+ """
1100
+ batch_size, num_channels, num_frames, height, width = x.shape
1101
+ if num_frames == 1:
1102
+ h = self.encoder(x)
1103
+ if self.quant_conv is not None:
1104
+ h = self.quant_conv(h)
1105
+ posterior = DiagonalGaussianDistribution(h)
1106
+ else:
1107
+ frame_batch_size = 4
1108
+ h = []
1109
+ for i in range(num_frames // frame_batch_size):
1110
+ remaining_frames = num_frames % frame_batch_size
1111
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1112
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1113
+ z_intermediate = x[:, :, start_frame:end_frame]
1114
+ z_intermediate = self.encoder(z_intermediate)
1115
+ if self.quant_conv is not None:
1116
+ z_intermediate = self.quant_conv(z_intermediate)
1117
+ h.append(z_intermediate)
1118
+ self._clear_fake_context_parallel_cache()
1119
+ h = torch.cat(h, dim=2)
1120
+ posterior = DiagonalGaussianDistribution(h)
1121
+ if not return_dict:
1122
+ return (posterior,)
1123
+ return AutoencoderKLOutput(latent_dist=posterior)
1124
+
1125
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1126
+ batch_size, num_channels, num_frames, height, width = z.shape
1127
+
1128
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1129
+ return self.tiled_decode(z, return_dict=return_dict)
1130
+
1131
+ if num_frames == 1:
1132
+ dec = []
1133
+ z_intermediate = z
1134
+ if self.post_quant_conv is not None:
1135
+ z_intermediate = self.post_quant_conv(z_intermediate)
1136
+ z_intermediate = self.decoder(z_intermediate)
1137
+ dec.append(z_intermediate)
1138
+ else:
1139
+ frame_batch_size = self.num_latent_frames_batch_size
1140
+ dec = []
1141
+ for i in range(num_frames // frame_batch_size):
1142
+ remaining_frames = num_frames % frame_batch_size
1143
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1144
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1145
+ z_intermediate = z[:, :, start_frame:end_frame]
1146
+ if self.post_quant_conv is not None:
1147
+ z_intermediate = self.post_quant_conv(z_intermediate)
1148
+ z_intermediate = self.decoder(z_intermediate)
1149
+ dec.append(z_intermediate)
1150
+
1151
+ self._clear_fake_context_parallel_cache()
1152
+ dec = torch.cat(dec, dim=2)
1153
+
1154
+ if not return_dict:
1155
+ return (dec,)
1156
+
1157
+ return DecoderOutput(sample=dec)
1158
+
1159
+ @apply_forward_hook
1160
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1161
+ """
1162
+ Decode a batch of images.
1163
+
1164
+ Args:
1165
+ z (`torch.Tensor`): Input batch of latent vectors.
1166
+ return_dict (`bool`, *optional*, defaults to `True`):
1167
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1168
+
1169
+ Returns:
1170
+ [`~models.vae.DecoderOutput`] or `tuple`:
1171
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1172
+ returned.
1173
+ """
1174
+ if self.use_slicing and z.shape[0] > 1:
1175
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1176
+ decoded = torch.cat(decoded_slices)
1177
+ else:
1178
+ decoded = self._decode(z).sample
1179
+
1180
+ if not return_dict:
1181
+ return (decoded,)
1182
+ return DecoderOutput(sample=decoded)
1183
+
1184
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1185
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1186
+ for y in range(blend_extent):
1187
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1188
+ y / blend_extent
1189
+ )
1190
+ return b
1191
+
1192
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1193
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1194
+ for x in range(blend_extent):
1195
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1196
+ x / blend_extent
1197
+ )
1198
+ return b
1199
+
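The two helpers above apply a simple linear crossfade over the overlap region. A sketch of the per-pixel weights for a blend extent of 4:

# Illustrative only: the 1-D weighting applied by blend_h / blend_v across the overlap.
blend_extent = 4
weights_b = [x / blend_extent for x in range(blend_extent)]   # 0.00, 0.25, 0.50, 0.75
weights_a = [1 - w for w in weights_b]                        # 1.00, 0.75, 0.50, 0.25
# Column x of the right (or bottom) tile becomes a_edge * weights_a[x] + b[x] * weights_b[x],
# so the seam fades linearly from tile `a` into tile `b` over `blend_extent` pixels.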
1200
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1201
+ r"""
1202
+ Decode a batch of images using a tiled decoder.
1203
+
1204
+ Args:
1205
+ z (`torch.Tensor`): Input batch of latent vectors.
1206
+ return_dict (`bool`, *optional*, defaults to `True`):
1207
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1208
+
1209
+ Returns:
1210
+ [`~models.vae.DecoderOutput`] or `tuple`:
1211
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1212
+ returned.
1213
+ """
1214
+ # Rough memory assessment:
1215
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1216
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1217
+ # - Assume fp16 (2 bytes per value).
1218
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1219
+ #
1220
+ # Memory assessment when using tiling:
1221
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1222
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1223
+
1224
+ batch_size, num_channels, num_frames, height, width = z.shape
1225
+
1226
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1227
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1228
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1229
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1230
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1231
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1232
+ frame_batch_size = self.num_latent_frames_batch_size
1233
+
1234
+ # Split z into overlapping tiles and decode them separately.
1235
+ # The tiles have an overlap to avoid seams between tiles.
1236
+ rows = []
1237
+ for i in range(0, height, overlap_height):
1238
+ row = []
1239
+ for j in range(0, width, overlap_width):
1240
+ time = []
1241
+ for k in range(num_frames // frame_batch_size):
1242
+ remaining_frames = num_frames % frame_batch_size
1243
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1244
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1245
+ tile = z[
1246
+ :,
1247
+ :,
1248
+ start_frame:end_frame,
1249
+ i : i + self.tile_latent_min_height,
1250
+ j : j + self.tile_latent_min_width,
1251
+ ]
1252
+ if self.post_quant_conv is not None:
1253
+ tile = self.post_quant_conv(tile)
1254
+ tile = self.decoder(tile)
1255
+ time.append(tile)
1256
+ self._clear_fake_context_parallel_cache()
1257
+ row.append(torch.cat(time, dim=2))
1258
+ rows.append(row)
1259
+
1260
+ result_rows = []
1261
+ for i, row in enumerate(rows):
1262
+ result_row = []
1263
+ for j, tile in enumerate(row):
1264
+ # blend the above tile and the left tile
1265
+ # to the current tile and add the current tile to the result row
1266
+ if i > 0:
1267
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1268
+ if j > 0:
1269
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1270
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1271
+ result_rows.append(torch.cat(result_row, dim=4))
1272
+
1273
+ dec = torch.cat(result_rows, dim=3)
1274
+
1275
+ if not return_dict:
1276
+ return (dec,)
1277
+
1278
+ return DecoderOutput(sample=dec)
1279
+
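A worked pass through the tile bookkeeping above, using assumed values (480x720 sample tiles, overlap factors 1/6 and 1/5, spatial compression 8; none of these defaults are visible in this excerpt, they are illustrative only):

tile_sample_min_height, tile_sample_min_width = 480, 720
tile_overlap_factor_height, tile_overlap_factor_width = 1 / 6, 1 / 5
tile_latent_min_height, tile_latent_min_width = 480 // 8, 720 // 8                # 60, 90

overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))   # 50: latent stride between tile rows
overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))      # 72: latent stride between tile columns
blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)    # 80 pixels blended vertically
blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)       # 144 pixels blended horizontally
row_limit_height = tile_sample_min_height - blend_extent_height                   # 400 pixels kept per tile row
row_limit_width = tile_sample_min_width - blend_extent_width                      # 576 pixels kept per tile column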
1280
+ def forward(
1281
+ self,
1282
+ sample: torch.Tensor,
1283
+ sample_posterior: bool = False,
1284
+ return_dict: bool = True,
1285
+ generator: Optional[torch.Generator] = None,
1286
+ ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
1287
+ x = sample
1288
+ posterior = self.encode(x).latent_dist
1289
+ if sample_posterior:
1290
+ z = posterior.sample(generator=generator)
1291
+ else:
1292
+ z = posterior.mode()
1293
+ dec = self.decode(z)
1294
+ if not return_dict:
1295
+ return (dec,)
1296
+ return dec
cogvideox/models/transformer3d.py ADDED
@@ -0,0 +1,567 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import os
19
+ import json
20
+ import torch
21
+ import glob
22
+ import torch.nn.functional as F
23
+ from torch import nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import is_torch_version, logging
27
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
28
+ from diffusers.models.attention import Attention, FeedForward
29
+ from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
30
+ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @maybe_allow_in_graph
40
+ class CogVideoXBlock(nn.Module):
41
+ r"""
42
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
43
+
44
+ Parameters:
45
+ dim (`int`):
46
+ The number of channels in the input and output.
47
+ num_attention_heads (`int`):
48
+ The number of heads to use for multi-head attention.
49
+ attention_head_dim (`int`):
50
+ The number of channels in each head.
51
+ time_embed_dim (`int`):
52
+ The number of channels in timestep embedding.
53
+ dropout (`float`, defaults to `0.0`):
54
+ The dropout probability to use.
55
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
56
+ Activation function to be used in feed-forward.
57
+ attention_bias (`bool`, defaults to `False`):
58
+ Whether or not to use bias in attention projection layers.
59
+ qk_norm (`bool`, defaults to `True`):
60
+ Whether or not to use normalization after query and key projections in Attention.
61
+ norm_elementwise_affine (`bool`, defaults to `True`):
62
+ Whether to use learnable elementwise affine parameters for normalization.
63
+ norm_eps (`float`, defaults to `1e-5`):
64
+ Epsilon value for normalization layers.
65
+ final_dropout (`bool`, defaults to `True`):
66
+ Whether to apply a final dropout after the last feed-forward layer.
67
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
68
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
69
+ ff_bias (`bool`, defaults to `True`):
70
+ Whether or not to use bias in Feed-forward layer.
71
+ attention_out_bias (`bool`, defaults to `True`):
72
+ Whether or not to use bias in Attention output projection layer.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_attention_heads: int,
79
+ attention_head_dim: int,
80
+ time_embed_dim: int,
81
+ dropout: float = 0.0,
82
+ activation_fn: str = "gelu-approximate",
83
+ attention_bias: bool = False,
84
+ qk_norm: bool = True,
85
+ norm_elementwise_affine: bool = True,
86
+ norm_eps: float = 1e-5,
87
+ final_dropout: bool = True,
88
+ ff_inner_dim: Optional[int] = None,
89
+ ff_bias: bool = True,
90
+ attention_out_bias: bool = True,
91
+ ):
92
+ super().__init__()
93
+
94
+ # 1. Self Attention
95
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
96
+
97
+ self.attn1 = Attention(
98
+ query_dim=dim,
99
+ dim_head=attention_head_dim,
100
+ heads=num_attention_heads,
101
+ qk_norm="layer_norm" if qk_norm else None,
102
+ eps=1e-6,
103
+ bias=attention_bias,
104
+ out_bias=attention_out_bias,
105
+ processor=CogVideoXAttnProcessor2_0(),
106
+ )
107
+
108
+ # 2. Feed Forward
109
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
110
+
111
+ self.ff = FeedForward(
112
+ dim,
113
+ dropout=dropout,
114
+ activation_fn=activation_fn,
115
+ final_dropout=final_dropout,
116
+ inner_dim=ff_inner_dim,
117
+ bias=ff_bias,
118
+ )
119
+
120
+ def forward(
121
+ self,
122
+ hidden_states: torch.Tensor,
123
+ encoder_hidden_states: torch.Tensor,
124
+ temb: torch.Tensor,
125
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
126
+ ) -> torch.Tensor:
127
+ text_seq_length = encoder_hidden_states.size(1)
128
+
129
+ # norm & modulate
130
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
131
+ hidden_states, encoder_hidden_states, temb
132
+ )
133
+
134
+ # attention
135
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
136
+ hidden_states=norm_hidden_states,
137
+ encoder_hidden_states=norm_encoder_hidden_states,
138
+ image_rotary_emb=image_rotary_emb,
139
+ )
140
+
141
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
142
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
143
+
144
+ # norm & modulate
145
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
146
+ hidden_states, encoder_hidden_states, temb
147
+ )
148
+
149
+ # feed-forward
150
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
151
+ ff_output = self.ff(norm_hidden_states)
152
+
153
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
154
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
155
+
156
+ return hidden_states, encoder_hidden_states
157
+
158
+
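A toy sketch of the token bookkeeping in `CogVideoXBlock.forward`: text and video tokens are processed jointly by concatenating along the sequence dimension and are split back by slicing on the text length. The dimensions below roughly follow the 2B defaults with a single latent frame and are illustrative only.

import torch

batch, text_len, video_len, dim = 2, 226, 1350, 1920
encoder_hidden_states = torch.randn(batch, text_len, dim)    # text tokens
hidden_states = torch.randn(batch, video_len, dim)           # video patch tokens

# The feed-forward runs on the concatenation [text | video] ...
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert joint.shape == (batch, text_len + video_len, dim)

# ... and the output is split back by slicing on the known text length.
text_out, video_out = joint[:, :text_len], joint[:, text_len:]
assert text_out.shape == encoder_hidden_states.shape and video_out.shape == hidden_states.shape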
159
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
160
+ """
161
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
162
+
163
+ Parameters:
164
+ num_attention_heads (`int`, defaults to `30`):
165
+ The number of heads to use for multi-head attention.
166
+ attention_head_dim (`int`, defaults to `64`):
167
+ The number of channels in each head.
168
+ in_channels (`int`, defaults to `16`):
169
+ The number of channels in the input.
170
+ out_channels (`int`, *optional*, defaults to `16`):
171
+ The number of channels in the output.
172
+ flip_sin_to_cos (`bool`, defaults to `True`):
173
+ Whether to flip the sin to cos in the time embedding.
174
+ time_embed_dim (`int`, defaults to `512`):
175
+ Output dimension of timestep embeddings.
176
+ text_embed_dim (`int`, defaults to `4096`):
177
+ Input dimension of text embeddings from the text encoder.
178
+ num_layers (`int`, defaults to `30`):
179
+ The number of layers of Transformer blocks to use.
180
+ dropout (`float`, defaults to `0.0`):
181
+ The dropout probability to use.
182
+ attention_bias (`bool`, defaults to `True`):
183
+ Whether or not to use bias in the attention projection layers.
184
+ sample_width (`int`, defaults to `90`):
185
+ The width of the input latents.
186
+ sample_height (`int`, defaults to `60`):
187
+ The height of the input latents.
188
+ sample_frames (`int`, defaults to `49`):
189
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
190
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
191
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
192
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
193
+ patch_size (`int`, defaults to `2`):
194
+ The size of the patches to use in the patch embedding layer.
195
+ temporal_compression_ratio (`int`, defaults to `4`):
196
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
197
+ max_text_seq_length (`int`, defaults to `226`):
198
+ The maximum sequence length of the input text embeddings.
199
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
200
+ Activation function to use in feed-forward.
201
+ timestep_activation_fn (`str`, defaults to `"silu"`):
202
+ Activation function to use when generating the timestep embeddings.
203
+ norm_elementwise_affine (`bool`, defaults to `True`):
204
+ Whether or not to use elementwise affine in normalization layers.
205
+ norm_eps (`float`, defaults to `1e-5`):
206
+ The epsilon value to use in normalization layers.
207
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
208
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
209
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
210
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
211
+ """
212
+
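A quick numeric check of the `sample_frames` relationship described in the docstring above:

temporal_compression_ratio = 4
K = 13                                                        # latent frames
sample_frames = (K - 1) * temporal_compression_ratio + 1      # 49
assert (sample_frames - 1) // temporal_compression_ratio + 1 == K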
213
+ _supports_gradient_checkpointing = True
214
+
215
+ @register_to_config
216
+ def __init__(
217
+ self,
218
+ num_attention_heads: int = 30,
219
+ attention_head_dim: int = 64,
220
+ in_channels: int = 16,
221
+ out_channels: Optional[int] = 16,
222
+ flip_sin_to_cos: bool = True,
223
+ freq_shift: int = 0,
224
+ time_embed_dim: int = 512,
225
+ text_embed_dim: int = 4096,
226
+ num_layers: int = 30,
227
+ dropout: float = 0.0,
228
+ attention_bias: bool = True,
229
+ sample_width: int = 90,
230
+ sample_height: int = 60,
231
+ sample_frames: int = 49,
232
+ patch_size: int = 2,
233
+ temporal_compression_ratio: int = 4,
234
+ max_text_seq_length: int = 226,
235
+ activation_fn: str = "gelu-approximate",
236
+ timestep_activation_fn: str = "silu",
237
+ norm_elementwise_affine: bool = True,
238
+ norm_eps: float = 1e-5,
239
+ spatial_interpolation_scale: float = 1.875,
240
+ temporal_interpolation_scale: float = 1.0,
241
+ use_rotary_positional_embeddings: bool = False,
242
+ ):
243
+ super().__init__()
244
+ inner_dim = num_attention_heads * attention_head_dim
245
+
246
+ post_patch_height = sample_height // patch_size
247
+ post_patch_width = sample_width // patch_size
248
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
249
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
250
+ self.post_patch_height = post_patch_height
251
+ self.post_patch_width = post_patch_width
252
+ self.post_time_compression_frames = post_time_compression_frames
253
+ self.patch_size = patch_size
254
+
255
+ # 1. Patch embedding
256
+ self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
257
+ self.embedding_dropout = nn.Dropout(dropout)
258
+
259
+ # 2. 3D positional embeddings
260
+ spatial_pos_embedding = get_3d_sincos_pos_embed(
261
+ inner_dim,
262
+ (post_patch_width, post_patch_height),
263
+ post_time_compression_frames,
264
+ spatial_interpolation_scale,
265
+ temporal_interpolation_scale,
266
+ )
267
+ spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
268
+ pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
269
+ pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
270
+ self.register_buffer("pos_embedding", pos_embedding, persistent=False)
271
+
272
+ # 3. Time embeddings
273
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
274
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
275
+
276
+ # 4. Define spatio-temporal transformer blocks
277
+ self.transformer_blocks = nn.ModuleList(
278
+ [
279
+ CogVideoXBlock(
280
+ dim=inner_dim,
281
+ num_attention_heads=num_attention_heads,
282
+ attention_head_dim=attention_head_dim,
283
+ time_embed_dim=time_embed_dim,
284
+ dropout=dropout,
285
+ activation_fn=activation_fn,
286
+ attention_bias=attention_bias,
287
+ norm_elementwise_affine=norm_elementwise_affine,
288
+ norm_eps=norm_eps,
289
+ )
290
+ for _ in range(num_layers)
291
+ ]
292
+ )
293
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
294
+
295
+ # 5. Output blocks
296
+ self.norm_out = AdaLayerNorm(
297
+ embedding_dim=time_embed_dim,
298
+ output_dim=2 * inner_dim,
299
+ norm_elementwise_affine=norm_elementwise_affine,
300
+ norm_eps=norm_eps,
301
+ chunk_dim=1,
302
+ )
303
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
304
+
305
+ self.gradient_checkpointing = False
306
+
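Sequence-length arithmetic for the default configuration above, spelled out (the values mirror the defaults in the signature):

sample_height, sample_width, sample_frames = 60, 90, 49
patch_size, temporal_compression_ratio = 2, 4
post_patch_height = sample_height // patch_size                                        # 30
post_patch_width = sample_width // patch_size                                          # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1   # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames      # 17550
# The positional-embedding buffer registered above therefore has shape
# (1, max_text_seq_length + num_patches, inner_dim) = (1, 226 + 17550, 30 * 64).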
307
+ def _set_gradient_checkpointing(self, module, value=False):
308
+ self.gradient_checkpointing = value
309
+
310
+ @property
311
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
312
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
313
+ r"""
314
+ Returns:
315
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
316
+ indexed by their weight names.
317
+ """
318
+ # set recursively
319
+ processors = {}
320
+
321
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
322
+ if hasattr(module, "get_processor"):
323
+ processors[f"{name}.processor"] = module.get_processor()
324
+
325
+ for sub_name, child in module.named_children():
326
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
327
+
328
+ return processors
329
+
330
+ for name, module in self.named_children():
331
+ fn_recursive_add_processors(name, module, processors)
332
+
333
+ return processors
334
+
335
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
336
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
337
+ r"""
338
+ Sets the attention processor to use to compute attention.
339
+
340
+ Parameters:
341
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
342
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
343
+ for **all** `Attention` layers.
344
+
345
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
346
+ processor. This is strongly recommended when setting trainable attention processors.
347
+
348
+ """
349
+ count = len(self.attn_processors.keys())
350
+
351
+ if isinstance(processor, dict) and len(processor) != count:
352
+ raise ValueError(
353
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
354
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
355
+ )
356
+
357
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
358
+ if hasattr(module, "set_processor"):
359
+ if not isinstance(processor, dict):
360
+ module.set_processor(processor)
361
+ else:
362
+ module.set_processor(processor.pop(f"{name}.processor"))
363
+
364
+ for sub_name, child in module.named_children():
365
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
366
+
367
+ for name, module in self.named_children():
368
+ fn_recursive_attn_processor(name, module, processor)
369
+
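A short usage sketch, assuming `model` is an instance of this transformer: a single processor instance is broadcast to every `Attention` layer, while a dict keyed by the paths returned from `attn_processors` targets layers individually.

model.set_attn_processor(CogVideoXAttnProcessor2_0())            # same processor for every Attention layer
per_layer = {name: CogVideoXAttnProcessor2_0() for name in model.attn_processors}
model.set_attn_processor(per_layer)                              # or one processor per "<path>.processor" key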
370
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
371
+ def fuse_qkv_projections(self):
372
+ """
373
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
374
+ are fused. For cross-attention modules, key and value projection matrices are fused.
375
+
376
+ <Tip warning={true}>
377
+
378
+ This API is 🧪 experimental.
379
+
380
+ </Tip>
381
+ """
382
+ self.original_attn_processors = None
383
+
384
+ for _, attn_processor in self.attn_processors.items():
385
+ if "Added" in str(attn_processor.__class__.__name__):
386
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
387
+
388
+ self.original_attn_processors = self.attn_processors
389
+
390
+ for module in self.modules():
391
+ if isinstance(module, Attention):
392
+ module.fuse_projections(fuse=True)
393
+
394
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
395
+
396
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
397
+ def unfuse_qkv_projections(self):
398
+ """Disables the fused QKV projection if enabled.
399
+
400
+ <Tip warning={true}>
401
+
402
+ This API is 🧪 experimental.
403
+
404
+ </Tip>
405
+
406
+ """
407
+ if self.original_attn_processors is not None:
408
+ self.set_attn_processor(self.original_attn_processors)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ encoder_hidden_states: torch.Tensor,
414
+ timestep: Union[int, float, torch.LongTensor],
415
+ timestep_cond: Optional[torch.Tensor] = None,
416
+ inpaint_latents: Optional[torch.Tensor] = None,
417
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
418
+ return_dict: bool = True,
419
+ ):
420
+ batch_size, num_frames, channels, height, width = hidden_states.shape
421
+
422
+ # 1. Time embedding
423
+ timesteps = timestep
424
+ t_emb = self.time_proj(timesteps)
425
+
426
+ # timesteps does not contain any weights and will always return f32 tensors
427
+ # but time_embedding might actually be running in fp16. so we need to cast here.
428
+ # there might be better ways to encapsulate this.
429
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
430
+ emb = self.time_embedding(t_emb, timestep_cond)
431
+
432
+ # 2. Patch embedding
433
+ if inpaint_latents is not None:
434
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
435
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
436
+
437
+ # 3. Position embedding
438
+ text_seq_length = encoder_hidden_states.shape[1]
439
+ if not self.config.use_rotary_positional_embeddings:
440
+ seq_length = height * width * num_frames // (self.config.patch_size**2)
441
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
442
+ pos_embeds = self.pos_embedding
443
+ emb_size = hidden_states.size()[-1]
444
+ pos_embeds_without_text = pos_embeds[:, text_seq_length:].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
445
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
446
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text, size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size], mode='trilinear', align_corners=False)
447
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
448
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim=1)
449
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
450
+ hidden_states = hidden_states + pos_embeds
451
+ hidden_states = self.embedding_dropout(hidden_states)
452
+
453
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
454
+ hidden_states = hidden_states[:, text_seq_length:]
455
+
456
+ # 4. Transformer blocks
457
+ for i, block in enumerate(self.transformer_blocks):
458
+ if self.training and self.gradient_checkpointing:
459
+
460
+ def create_custom_forward(module):
461
+ def custom_forward(*inputs):
462
+ return module(*inputs)
463
+
464
+ return custom_forward
465
+
466
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
467
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
468
+ create_custom_forward(block),
469
+ hidden_states,
470
+ encoder_hidden_states,
471
+ emb,
472
+ image_rotary_emb,
473
+ **ckpt_kwargs,
474
+ )
475
+ else:
476
+ hidden_states, encoder_hidden_states = block(
477
+ hidden_states=hidden_states,
478
+ encoder_hidden_states=encoder_hidden_states,
479
+ temb=emb,
480
+ image_rotary_emb=image_rotary_emb,
481
+ )
482
+
483
+ if not self.config.use_rotary_positional_embeddings:
484
+ # CogVideoX-2B
485
+ hidden_states = self.norm_final(hidden_states)
486
+ else:
487
+ # CogVideoX-5B
488
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
489
+ hidden_states = self.norm_final(hidden_states)
490
+ hidden_states = hidden_states[:, text_seq_length:]
491
+
492
+ # 5. Final block
493
+ hidden_states = self.norm_out(hidden_states, temb=emb)
494
+ hidden_states = self.proj_out(hidden_states)
495
+
496
+ # 6. Unpatchify
497
+ p = self.config.patch_size
498
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
499
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
500
+
501
+ if not return_dict:
502
+ return (output,)
503
+ return Transformer2DModelOutput(sample=output)
504
+
505
+ @classmethod
506
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
507
+ if subfolder is not None:
508
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
509
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
510
+
511
+ config_file = os.path.join(pretrained_model_path, 'config.json')
512
+ if not os.path.isfile(config_file):
513
+ raise RuntimeError(f"{config_file} does not exist")
514
+ with open(config_file, "r") as f:
515
+ config = json.load(f)
516
+
517
+ from diffusers.utils import WEIGHTS_NAME
518
+ model = cls.from_config(config, **transformer_additional_kwargs)
519
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
520
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
521
+ if os.path.exists(model_file):
522
+ state_dict = torch.load(model_file, map_location="cpu")
523
+ elif os.path.exists(model_file_safetensors):
524
+ from safetensors.torch import load_file, safe_open
525
+ state_dict = load_file(model_file_safetensors)
526
+ else:
527
+ from safetensors.torch import load_file, safe_open
528
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
529
+ state_dict = {}
530
+ for model_file_safetensors in model_files_safetensors:
531
+ _state_dict = load_file(model_file_safetensors)
532
+ for key in _state_dict:
533
+ state_dict[key] = _state_dict[key]
534
+
535
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
536
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
537
+ if len(new_shape) == 5:
538
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
539
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
540
+ else:
541
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
542
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
543
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
544
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
545
+ else:
546
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
547
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
548
+
549
+ tmp_state_dict = {}
550
+ for key in state_dict:
551
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
552
+ tmp_state_dict[key] = state_dict[key]
553
+ else:
554
+ print(key, "Size doesn't match, skip")
555
+ state_dict = tmp_state_dict
556
+
557
+ m, u = model.load_state_dict(state_dict, strict=False)
558
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
559
+ print(m)
560
+
561
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
562
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
563
+
564
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
565
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
566
+
567
+ return model
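A hypothetical call to the loader above; the local path is a placeholder for a directory containing `config.json` plus `.bin` or `.safetensors` weights.

transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
    "path/to/CogVideoX-2b",        # placeholder checkpoint directory
    subfolder="transformer",
)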
cogvideox/pipeline/pipeline_cogvideox.py ADDED
@@ -0,0 +1,751 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```python
40
+ >>> import torch
41
+ >>> from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
42
+ >>> from diffusers.utils import export_to_video
43
+
44
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
45
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
46
+ >>> prompt = (
47
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
48
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
49
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
50
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
51
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
52
+ ... "atmosphere of this unique musical performance."
53
+ ... )
54
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
55
+ >>> export_to_video(video, "output.mp4", fps=8)
56
+ ```
57
+ """
58
+
59
+
60
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
61
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
62
+ tw = tgt_width
63
+ th = tgt_height
64
+ h, w = src
65
+ r = h / w
66
+ if r > (th / tw):
67
+ resize_height = th
68
+ resize_width = int(round(th / h * w))
69
+ else:
70
+ resize_width = tw
71
+ resize_height = int(round(tw / w * h))
72
+
73
+ crop_top = int(round((th - resize_height) / 2.0))
74
+ crop_left = int(round((tw - resize_width) / 2.0))
75
+
76
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
77
+
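A worked example for the default 480x720 resolution, where the latent patch grid (30x45) already matches the base grid so no cropping occurs (using the function defined above):

# grid = 480 / (8 * 2) x 720 / (8 * 2) = 30 x 45; base grid = 45 (width) x 30 (height).
top_left, bottom_right = get_resize_crop_region_for_grid((30, 45), 45, 30)
assert (top_left, bottom_right) == ((0, 0), (30, 45))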
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
80
+ def retrieve_timesteps(
81
+ scheduler,
82
+ num_inference_steps: Optional[int] = None,
83
+ device: Optional[Union[str, torch.device]] = None,
84
+ timesteps: Optional[List[int]] = None,
85
+ sigmas: Optional[List[float]] = None,
86
+ **kwargs,
87
+ ):
88
+ """
89
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
90
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
91
+
92
+ Args:
93
+ scheduler (`SchedulerMixin`):
94
+ The scheduler to get timesteps from.
95
+ num_inference_steps (`int`):
96
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
97
+ must be `None`.
98
+ device (`str` or `torch.device`, *optional*):
99
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
100
+ timesteps (`List[int]`, *optional*):
101
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
102
+ `num_inference_steps` and `sigmas` must be `None`.
103
+ sigmas (`List[float]`, *optional*):
104
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
105
+ `num_inference_steps` and `timesteps` must be `None`.
106
+
107
+ Returns:
108
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
109
+ second element is the number of inference steps.
110
+ """
111
+ if timesteps is not None and sigmas is not None:
112
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
113
+ if timesteps is not None:
114
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115
+ if not accepts_timesteps:
116
+ raise ValueError(
117
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118
+ f" timestep schedules. Please check whether you are using the correct scheduler."
119
+ )
120
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ num_inference_steps = len(timesteps)
123
+ elif sigmas is not None:
124
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
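A minimal, hedged usage sketch of `retrieve_timesteps` with one of the schedulers imported above; the checkpoint id follows the example docstring above and is illustrative only.

# Standard schedule (no custom timesteps/sigmas): the helper just forwards to set_timesteps.
scheduler = CogVideoXDDIMScheduler.from_pretrained("THUDM/CogVideoX-2b", subfolder="scheduler")
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50)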
138
+
139
+ @dataclass
140
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
141
+ r"""
142
+ Output class for CogVideo pipelines.
143
+
144
+ Args:
145
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
146
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
147
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
148
+ `(batch_size, num_frames, channels, height, width)`.
149
+ """
150
+
151
+ videos: torch.Tensor
152
+
153
+
154
+ class CogVideoX_Fun_Pipeline(DiffusionPipeline):
155
+ r"""
156
+ Pipeline for text-to-video generation using CogVideoX_Fun.
157
+
158
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
159
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
+
161
+ Args:
162
+ vae ([`AutoencoderKL`]):
163
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
164
+ text_encoder ([`T5EncoderModel`]):
165
+ Frozen text-encoder. CogVideoX uses
166
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
167
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
168
+ tokenizer (`T5Tokenizer`):
169
+ Tokenizer of class
170
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
171
+ transformer ([`CogVideoXTransformer3DModel`]):
172
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
173
+ scheduler ([`SchedulerMixin`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
175
+ """
176
+
177
+ _optional_components = []
178
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
179
+
180
+ _callback_tensor_inputs = [
181
+ "latents",
182
+ "prompt_embeds",
183
+ "negative_prompt_embeds",
184
+ ]
185
+
186
+ def __init__(
187
+ self,
188
+ tokenizer: T5Tokenizer,
189
+ text_encoder: T5EncoderModel,
190
+ vae: AutoencoderKLCogVideoX,
191
+ transformer: CogVideoXTransformer3DModel,
192
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.register_modules(
197
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
198
+ )
199
+ self.vae_scale_factor_spatial = (
200
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
201
+ )
202
+ self.vae_scale_factor_temporal = (
203
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
204
+ )
205
+
206
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
207
+
208
+ def _get_t5_prompt_embeds(
209
+ self,
210
+ prompt: Union[str, List[str]] = None,
211
+ num_videos_per_prompt: int = 1,
212
+ max_sequence_length: int = 226,
213
+ device: Optional[torch.device] = None,
214
+ dtype: Optional[torch.dtype] = None,
215
+ ):
216
+ device = device or self._execution_device
217
+ dtype = dtype or self.text_encoder.dtype
218
+
219
+ prompt = [prompt] if isinstance(prompt, str) else prompt
220
+ batch_size = len(prompt)
221
+
222
+ text_inputs = self.tokenizer(
223
+ prompt,
224
+ padding="max_length",
225
+ max_length=max_sequence_length,
226
+ truncation=True,
227
+ add_special_tokens=True,
228
+ return_tensors="pt",
229
+ )
230
+ text_input_ids = text_inputs.input_ids
231
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
232
+
233
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
234
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
235
+ logger.warning(
236
+ "The following part of your input was truncated because `max_sequence_length` is set to "
237
+ f" {max_sequence_length} tokens: {removed_text}"
238
+ )
239
+
240
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
241
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
242
+
243
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
244
+ _, seq_len, _ = prompt_embeds.shape
245
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
246
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
247
+
248
+ return prompt_embeds
249
+
250
+ def encode_prompt(
251
+ self,
252
+ prompt: Union[str, List[str]],
253
+ negative_prompt: Optional[Union[str, List[str]]] = None,
254
+ do_classifier_free_guidance: bool = True,
255
+ num_videos_per_prompt: int = 1,
256
+ prompt_embeds: Optional[torch.Tensor] = None,
257
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
258
+ max_sequence_length: int = 226,
259
+ device: Optional[torch.device] = None,
260
+ dtype: Optional[torch.dtype] = None,
261
+ ):
262
+ r"""
263
+ Encodes the prompt into text encoder hidden states.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ prompt to be encoded
268
+ negative_prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
270
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
271
+ less than `1`).
272
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
273
+ Whether to use classifier free guidance or not.
274
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
275
+ Number of videos that should be generated per prompt.
276
+ prompt_embeds (`torch.Tensor`, *optional*):
277
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
278
+ provided, text embeddings will be generated from `prompt` input argument.
279
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
280
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
281
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
282
+ argument.
283
+ device: (`torch.device`, *optional*):
284
+ torch device
285
+ dtype: (`torch.dtype`, *optional*):
286
+ torch dtype
287
+ """
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ if prompt is not None:
292
+ batch_size = len(prompt)
293
+ else:
294
+ batch_size = prompt_embeds.shape[0]
295
+
296
+ if prompt_embeds is None:
297
+ prompt_embeds = self._get_t5_prompt_embeds(
298
+ prompt=prompt,
299
+ num_videos_per_prompt=num_videos_per_prompt,
300
+ max_sequence_length=max_sequence_length,
301
+ device=device,
302
+ dtype=dtype,
303
+ )
304
+
305
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
306
+ negative_prompt = negative_prompt or ""
307
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
308
+
309
+ if prompt is not None and type(prompt) is not type(negative_prompt):
310
+ raise TypeError(
311
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
312
+ f" {type(prompt)}."
313
+ )
314
+ elif batch_size != len(negative_prompt):
315
+ raise ValueError(
316
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
317
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
318
+ " the batch size of `prompt`."
319
+ )
320
+
321
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
322
+ prompt=negative_prompt,
323
+ num_videos_per_prompt=num_videos_per_prompt,
324
+ max_sequence_length=max_sequence_length,
325
+ device=device,
326
+ dtype=dtype,
327
+ )
328
+
329
+ return prompt_embeds, negative_prompt_embeds
330
+
331
+ def prepare_latents(
332
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
333
+ ):
334
+ shape = (
335
+ batch_size,
336
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
337
+ num_channels_latents,
338
+ height // self.vae_scale_factor_spatial,
339
+ width // self.vae_scale_factor_spatial,
340
+ )
341
+ if isinstance(generator, list) and len(generator) != batch_size:
342
+ raise ValueError(
343
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
344
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
345
+ )
346
+
347
+ if latents is None:
348
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
349
+ else:
350
+ latents = latents.to(device)
351
+
352
+ # scale the initial noise by the standard deviation required by the scheduler
353
+ latents = latents * self.scheduler.init_noise_sigma
354
+ return latents
355
+
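The latent shape produced by `prepare_latents` for the pipeline defaults (49 frames at 480x720 with 16 latent channels), spelled out:

batch_size, num_frames, height, width = 1, 49, 480, 720
vae_scale_factor_temporal, vae_scale_factor_spatial, num_channels_latents = 4, 8, 16
shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,   # 13 latent frames
    num_channels_latents,
    height // vae_scale_factor_spatial,                  # 60
    width // vae_scale_factor_spatial,                   # 90
)
assert shape == (1, 13, 16, 60, 90)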
356
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
357
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
358
+ latents = 1 / self.vae.config.scaling_factor * latents
359
+
360
+ frames = self.vae.decode(latents).sample
361
+ frames = (frames / 2 + 0.5).clamp(0, 1)
362
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
363
+ frames = frames.cpu().float().numpy()
364
+ return frames
365
+
366
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
367
+ def prepare_extra_step_kwargs(self, generator, eta):
368
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
369
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
370
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
371
+ # and should be between [0, 1]
372
+
373
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
374
+ extra_step_kwargs = {}
375
+ if accepts_eta:
376
+ extra_step_kwargs["eta"] = eta
377
+
378
+ # check if the scheduler accepts generator
379
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
380
+ if accepts_generator:
381
+ extra_step_kwargs["generator"] = generator
382
+ return extra_step_kwargs
383
+
384
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
385
+ def check_inputs(
386
+ self,
387
+ prompt,
388
+ height,
389
+ width,
390
+ negative_prompt,
391
+ callback_on_step_end_tensor_inputs,
392
+ prompt_embeds=None,
393
+ negative_prompt_embeds=None,
394
+ ):
395
+ if height % 8 != 0 or width % 8 != 0:
396
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
397
+
398
+ if callback_on_step_end_tensor_inputs is not None and not all(
399
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
400
+ ):
401
+ raise ValueError(
402
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
403
+ )
404
+ if prompt is not None and prompt_embeds is not None:
405
+ raise ValueError(
406
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
407
+ " only forward one of the two."
408
+ )
409
+ elif prompt is None and prompt_embeds is None:
410
+ raise ValueError(
411
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
412
+ )
413
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
414
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
415
+
416
+ if prompt is not None and negative_prompt_embeds is not None:
417
+ raise ValueError(
418
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
419
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
420
+ )
421
+
422
+ if negative_prompt is not None and negative_prompt_embeds is not None:
423
+ raise ValueError(
424
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
425
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
426
+ )
427
+
428
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
429
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
430
+ raise ValueError(
431
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
432
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
433
+ f" {negative_prompt_embeds.shape}."
434
+ )
435
+
436
+ def fuse_qkv_projections(self) -> None:
437
+ r"""Enables fused QKV projections."""
438
+ self.fusing_transformer = True
439
+ self.transformer.fuse_qkv_projections()
440
+
441
+ def unfuse_qkv_projections(self) -> None:
442
+ r"""Disable QKV projection fusion if enabled."""
443
+ if not self.fusing_transformer:
444
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
445
+ else:
446
+ self.transformer.unfuse_qkv_projections()
447
+ self.fusing_transformer = False
448
+
449
+ def _prepare_rotary_positional_embeddings(
450
+ self,
451
+ height: int,
452
+ width: int,
453
+ num_frames: int,
454
+ device: torch.device,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
457
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
458
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
459
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
460
+
461
+ grid_crops_coords = get_resize_crop_region_for_grid(
462
+ (grid_height, grid_width), base_size_width, base_size_height
463
+ )
464
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
465
+ embed_dim=self.transformer.config.attention_head_dim,
466
+ crops_coords=grid_crops_coords,
467
+ grid_size=(grid_height, grid_width),
468
+ temporal_size=num_frames,
469
+ use_real=True,
470
+ )
471
+
472
+ freqs_cos = freqs_cos.to(device=device)
473
+ freqs_sin = freqs_sin.to(device=device)
474
+ return freqs_cos, freqs_sin
475
+
476
+ @property
477
+ def guidance_scale(self):
478
+ return self._guidance_scale
479
+
480
+ @property
481
+ def num_timesteps(self):
482
+ return self._num_timesteps
483
+
484
+ @property
485
+ def interrupt(self):
486
+ return self._interrupt
487
+
488
+ @torch.no_grad()
489
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
490
+ def __call__(
491
+ self,
492
+ prompt: Optional[Union[str, List[str]]] = None,
493
+ negative_prompt: Optional[Union[str, List[str]]] = None,
494
+ height: int = 480,
495
+ width: int = 720,
496
+ num_frames: int = 49,
497
+ num_inference_steps: int = 50,
498
+ timesteps: Optional[List[int]] = None,
499
+ guidance_scale: float = 6,
500
+ use_dynamic_cfg: bool = False,
501
+ num_videos_per_prompt: int = 1,
502
+ eta: float = 0.0,
503
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
504
+ latents: Optional[torch.FloatTensor] = None,
505
+ prompt_embeds: Optional[torch.FloatTensor] = None,
506
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
507
+ output_type: str = "numpy",
508
+ return_dict: bool = False,
509
+ callback_on_step_end: Optional[
510
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
511
+ ] = None,
512
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
513
+ max_sequence_length: int = 226,
514
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
515
+ """
516
+ Function invoked when calling the pipeline for generation.
517
+
518
+ Args:
519
+ prompt (`str` or `List[str]`, *optional*):
520
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
521
+ instead.
522
+ negative_prompt (`str` or `List[str]`, *optional*):
523
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
524
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
525
+ less than `1`).
526
+ height (`int`, *optional*, defaults to `480`):
527
+ The height in pixels of the generated video.
528
+ width (`int`, *optional*, defaults to `720`):
529
+ The width in pixels of the generated video.
530
+ num_frames (`int`, defaults to `49`):
531
+ Number of frames to generate. `num_frames - 1` must be divisible by self.vae_scale_factor_temporal. The
532
+ generated video contains 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1)
533
+ frames where num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only
534
+ condition that needs to be satisfied is the divisibility mentioned above.
535
+ num_inference_steps (`int`, *optional*, defaults to 50):
536
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
537
+ expense of slower inference.
538
+ timesteps (`List[int]`, *optional*):
539
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
540
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
541
+ passed will be used. Must be in descending order.
542
+ guidance_scale (`float`, *optional*, defaults to 6):
543
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
544
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
545
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
546
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
547
+ usually at the expense of lower image quality.
548
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
549
+ The number of videos to generate per prompt.
550
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
551
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
552
+ to make generation deterministic.
553
+ latents (`torch.FloatTensor`, *optional*):
554
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
555
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
556
+ tensor will be generated by sampling using the supplied random `generator`.
557
+ prompt_embeds (`torch.FloatTensor`, *optional*):
558
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
559
+ provided, text embeddings will be generated from `prompt` input argument.
560
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
561
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
562
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
563
+ argument.
564
+ output_type (`str`, *optional*, defaults to `"numpy"`):
565
+ The output format of the generated video. Choose between
566
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
567
+ return_dict (`bool`, *optional*, defaults to `False`):
568
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] instead
569
+ of a plain tuple.
570
+ callback_on_step_end (`Callable`, *optional*):
571
+ A function that is called at the end of each denoising step during inference. The function is called
572
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
573
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
574
+ `callback_on_step_end_tensor_inputs`.
575
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
576
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
577
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
578
+ `._callback_tensor_inputs` attribute of your pipeline class.
579
+ max_sequence_length (`int`, defaults to `226`):
580
+ Maximum sequence length in encoded prompt. Must be consistent with
581
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
582
+
583
+ Examples:
584
+
585
+ Returns:
586
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
587
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
588
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
589
+ """
590
+
591
+ if num_frames > 49:
592
+ raise ValueError(
593
+ "The number of frames must not exceed 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
594
+ )
595
+
596
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
597
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
598
+
599
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
600
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
601
+ num_videos_per_prompt = 1
602
+
603
+ # 1. Check inputs. Raise error if not correct
604
+ self.check_inputs(
605
+ prompt,
606
+ height,
607
+ width,
608
+ negative_prompt,
609
+ callback_on_step_end_tensor_inputs,
610
+ prompt_embeds,
611
+ negative_prompt_embeds,
612
+ )
613
+ self._guidance_scale = guidance_scale
614
+ self._interrupt = False
615
+
616
+ # 2. Default call parameters
617
+ if prompt is not None and isinstance(prompt, str):
618
+ batch_size = 1
619
+ elif prompt is not None and isinstance(prompt, list):
620
+ batch_size = len(prompt)
621
+ else:
622
+ batch_size = prompt_embeds.shape[0]
623
+
624
+ device = self._execution_device
625
+
626
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
627
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
628
+ # corresponds to doing no classifier free guidance.
629
+ do_classifier_free_guidance = guidance_scale > 1.0
630
+
631
+ # 3. Encode input prompt
632
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
633
+ prompt,
634
+ negative_prompt,
635
+ do_classifier_free_guidance,
636
+ num_videos_per_prompt=num_videos_per_prompt,
637
+ prompt_embeds=prompt_embeds,
638
+ negative_prompt_embeds=negative_prompt_embeds,
639
+ max_sequence_length=max_sequence_length,
640
+ device=device,
641
+ )
642
+ if do_classifier_free_guidance:
643
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
644
+
645
+ # 4. Prepare timesteps
646
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
647
+ self._num_timesteps = len(timesteps)
648
+
649
+ # 5. Prepare latents.
650
+ latent_channels = self.transformer.config.in_channels
651
+ latents = self.prepare_latents(
652
+ batch_size * num_videos_per_prompt,
653
+ latent_channels,
654
+ num_frames,
655
+ height,
656
+ width,
657
+ prompt_embeds.dtype,
658
+ device,
659
+ generator,
660
+ latents,
661
+ )
662
+
663
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
664
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
665
+
666
+ # 7. Create rotary embeds if required
667
+ image_rotary_emb = (
668
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
669
+ if self.transformer.config.use_rotary_positional_embeddings
670
+ else None
671
+ )
672
+
673
+ # 8. Denoising loop
674
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
675
+
676
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
677
+ # for DPM-solver++
678
+ old_pred_original_sample = None
679
+ for i, t in enumerate(timesteps):
680
+ if self.interrupt:
681
+ continue
682
+
683
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
684
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
685
+
686
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
687
+ timestep = t.expand(latent_model_input.shape[0])
688
+
689
+ # predict noise model_output
690
+ noise_pred = self.transformer(
691
+ hidden_states=latent_model_input,
692
+ encoder_hidden_states=prompt_embeds,
693
+ timestep=timestep,
694
+ image_rotary_emb=image_rotary_emb,
695
+ return_dict=False,
696
+ )[0]
697
+ noise_pred = noise_pred.float()
698
+
699
+ # perform guidance
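# When `use_dynamic_cfg` is set, the effective guidance scale below is recomputed every step with a
# cosine-based schedule of the current timestep instead of staying constant.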
700
+ if use_dynamic_cfg:
701
+ self._guidance_scale = 1 + guidance_scale * (
702
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
703
+ )
704
+ if do_classifier_free_guidance:
705
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
706
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
707
+
708
+ # compute the previous noisy sample x_t -> x_t-1
709
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
710
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
711
+ else:
712
+ latents, old_pred_original_sample = self.scheduler.step(
713
+ noise_pred,
714
+ old_pred_original_sample,
715
+ t,
716
+ timesteps[i - 1] if i > 0 else None,
717
+ latents,
718
+ **extra_step_kwargs,
719
+ return_dict=False,
720
+ )
721
+ latents = latents.to(prompt_embeds.dtype)
722
+
723
+ # call the callback, if provided
724
+ if callback_on_step_end is not None:
725
+ callback_kwargs = {}
726
+ for k in callback_on_step_end_tensor_inputs:
727
+ callback_kwargs[k] = locals()[k]
728
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
729
+
730
+ latents = callback_outputs.pop("latents", latents)
731
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
732
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
733
+
734
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
735
+ progress_bar.update()
736
+
737
+ if output_type == "numpy":
738
+ video = self.decode_latents(latents)
739
+ elif output_type != "latent":
740
+ video = self.decode_latents(latents)
741
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
742
+ else:
743
+ video = latents
744
+
745
+ # Offload all models
746
+ self.maybe_free_model_hooks()
747
+
748
+ if not return_dict:
749
+ video = torch.from_numpy(video)
750
+
751
+ return CogVideoX_Fun_PipelineOutput(videos=video)
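As a quick orientation for the pipeline defined above, here is a minimal usage sketch. It only mirrors the `__call__` signature shown in this file; the import path and checkpoint directory are assumptions (hypothetical), and with the defaults (`output_type="numpy"`, `return_dict=False`) the decoded frames come back as a tensor on the `videos` field of `CogVideoX_Fun_PipelineOutput`.

```python
# Minimal sketch; the module path and checkpoint directory below are assumptions, not part of the commit.
import torch
from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline  # assumed import path

pipe = CogVideoX_Fun_Pipeline.from_pretrained(
    "models/CogVideoX-2b",  # hypothetical local checkpoint directory
    torch_dtype=torch.float16,
).to("cuda")

output = pipe(
    prompt="A panda playing a tiny guitar in a bamboo forest",
    height=480,
    width=720,
    num_frames=49,                 # 49 is the maximum accepted by this pipeline
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
video = output.videos              # decoded frames, shape (batch, channels, frames, height, width)
```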
cogvideox/pipeline/pipeline_cogvideox_inpaint.py ADDED
@@ -0,0 +1,1003 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from transformers import T5EncoderModel, T5Tokenizer
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.image_processor import VaeImageProcessor
35
+ from einops import rearrange
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ >>> import torch
45
+ >>> from diffusers import CogVideoX_Fun_Pipeline
46
+ >>> from diffusers.utils import export_to_video
47
+
48
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
49
+ >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
50
+ >>> prompt = (
51
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
52
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
53
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
54
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
55
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
56
+ ... "atmosphere of this unique musical performance."
57
+ ... )
58
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).videos[0]
59
+ >>> export_to_video(video, "output.mp4", fps=8)
60
+ ```
61
+ """
62
+
63
+
64
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
65
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
66
+ tw = tgt_width
67
+ th = tgt_height
68
+ h, w = src
69
+ r = h / w
70
+ if r > (th / tw):
71
+ resize_height = th
72
+ resize_width = int(round(th / h * w))
73
+ else:
74
+ resize_width = tw
75
+ resize_height = int(round(tw / w * h))
76
+
77
+ crop_top = int(round((th - resize_height) / 2.0))
78
+ crop_left = int(round((tw - resize_width) / 2.0))
79
+
80
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
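# Worked example: at the default 480x720 resolution with the typical CogVideoX values
# vae_scale_factor_spatial = 8 and patch_size = 2, the grid is (30, 45) and the base size is 45 x 30, so
# get_resize_crop_region_for_grid((30, 45), 45, 30) returns ((0, 0), (30, 45)), i.e. the crop spans the whole grid.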
81
+
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
84
+ def retrieve_timesteps(
85
+ scheduler,
86
+ num_inference_steps: Optional[int] = None,
87
+ device: Optional[Union[str, torch.device]] = None,
88
+ timesteps: Optional[List[int]] = None,
89
+ sigmas: Optional[List[float]] = None,
90
+ **kwargs,
91
+ ):
92
+ """
93
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
94
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
95
+
96
+ Args:
97
+ scheduler (`SchedulerMixin`):
98
+ The scheduler to get timesteps from.
99
+ num_inference_steps (`int`):
100
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
101
+ must be `None`.
102
+ device (`str` or `torch.device`, *optional*):
103
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
104
+ timesteps (`List[int]`, *optional*):
105
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
106
+ `num_inference_steps` and `sigmas` must be `None`.
107
+ sigmas (`List[float]`, *optional*):
108
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
109
+ `num_inference_steps` and `timesteps` must be `None`.
110
+
111
+ Returns:
112
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
113
+ second element is the number of inference steps.
114
+ """
115
+ if timesteps is not None and sigmas is not None:
116
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ if not accepts_timesteps:
120
+ raise ValueError(
121
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ elif sigmas is not None:
128
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accept_sigmas:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ else:
138
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ return timesteps, num_inference_steps
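# Usage sketch (illustrative values): either let the scheduler build the schedule,
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device=device)
# or, for schedulers whose set_timesteps accepts them, pass explicit `timesteps` or `sigmas` instead.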
141
+
142
+
143
+ def resize_mask(mask, latent, process_first_frame_only=True):
144
+ latent_size = latent.size()
145
+ batch_size, channels, num_frames, height, width = mask.shape
146
+
147
+ if process_first_frame_only:
148
+ target_size = list(latent_size[2:])
149
+ target_size[0] = 1
150
+ first_frame_resized = F.interpolate(
151
+ mask[:, :, 0:1, :, :],
152
+ size=target_size,
153
+ mode='trilinear',
154
+ align_corners=False
155
+ )
156
+
157
+ target_size = list(latent_size[2:])
158
+ target_size[0] = target_size[0] - 1
159
+ if target_size[0] != 0:
160
+ remaining_frames_resized = F.interpolate(
161
+ mask[:, :, 1:, :, :],
162
+ size=target_size,
163
+ mode='trilinear',
164
+ align_corners=False
165
+ )
166
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
167
+ else:
168
+ resized_mask = first_frame_resized
169
+ else:
170
+ target_size = list(latent_size[2:])
171
+ resized_mask = F.interpolate(
172
+ mask,
173
+ size=target_size,
174
+ mode='trilinear',
175
+ align_corners=False
176
+ )
177
+ return resized_mask
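# Note: `mask` is expected as (batch, channels, frames, height, width) and is trilinearly interpolated to the
# latent's (frames, height, width). With process_first_frame_only=True the first input frame is mapped to the
# first latent frame and the remaining frames are resized to the remaining latent frames, mirroring how the
# VAE treats the first frame separately when compressing time.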
178
+
179
+
180
+ @dataclass
181
+ class CogVideoX_Fun_PipelineOutput(BaseOutput):
182
+ r"""
183
+ Output class for CogVideo pipelines.
184
+
185
+ Args:
186
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
187
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
188
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
189
+ `(batch_size, num_frames, channels, height, width)`.
190
+ """
191
+
192
+ videos: torch.Tensor
193
+
194
+
195
+ class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
196
+ r"""
197
+ Pipeline for text-guided video inpainting and video-to-video generation using CogVideoX.
198
+
199
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
200
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
201
+
202
+ Args:
203
+ vae ([`AutoencoderKL`]):
204
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
205
+ text_encoder ([`T5EncoderModel`]):
206
+ Frozen text-encoder. CogVideoX_Fun uses
207
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
208
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
209
+ tokenizer (`T5Tokenizer`):
210
+ Tokenizer of class
211
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
212
+ transformer ([`CogVideoXTransformer3DModel`]):
213
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
214
+ scheduler ([`SchedulerMixin`]):
215
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
216
+ """
217
+
218
+ _optional_components = []
219
+ model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
220
+
221
+ _callback_tensor_inputs = [
222
+ "latents",
223
+ "prompt_embeds",
224
+ "negative_prompt_embeds",
225
+ ]
226
+
227
+ def __init__(
228
+ self,
229
+ tokenizer: T5Tokenizer,
230
+ text_encoder: T5EncoderModel,
231
+ vae: AutoencoderKLCogVideoX,
232
+ transformer: CogVideoXTransformer3DModel,
233
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
234
+ ):
235
+ super().__init__()
236
+
237
+ self.register_modules(
238
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
239
+ )
240
+ self.vae_scale_factor_spatial = (
241
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
242
+ )
243
+ self.vae_scale_factor_temporal = (
244
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
245
+ )
246
+
247
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
248
+
249
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
250
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
251
+ self.mask_processor = VaeImageProcessor(
252
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
253
+ )
254
+
255
+ def _get_t5_prompt_embeds(
256
+ self,
257
+ prompt: Union[str, List[str]] = None,
258
+ num_videos_per_prompt: int = 1,
259
+ max_sequence_length: int = 226,
260
+ device: Optional[torch.device] = None,
261
+ dtype: Optional[torch.dtype] = None,
262
+ ):
263
+ device = device or self._execution_device
264
+ dtype = dtype or self.text_encoder.dtype
265
+
266
+ prompt = [prompt] if isinstance(prompt, str) else prompt
267
+ batch_size = len(prompt)
268
+
269
+ text_inputs = self.tokenizer(
270
+ prompt,
271
+ padding="max_length",
272
+ max_length=max_sequence_length,
273
+ truncation=True,
274
+ add_special_tokens=True,
275
+ return_tensors="pt",
276
+ )
277
+ text_input_ids = text_inputs.input_ids
278
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
279
+
280
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
281
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
282
+ logger.warning(
283
+ "The following part of your input was truncated because `max_sequence_length` is set to "
284
+ f" {max_sequence_length} tokens: {removed_text}"
285
+ )
286
+
287
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
288
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
289
+
290
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
291
+ _, seq_len, _ = prompt_embeds.shape
292
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
293
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
294
+
295
+ return prompt_embeds
296
+
297
+ def encode_prompt(
298
+ self,
299
+ prompt: Union[str, List[str]],
300
+ negative_prompt: Optional[Union[str, List[str]]] = None,
301
+ do_classifier_free_guidance: bool = True,
302
+ num_videos_per_prompt: int = 1,
303
+ prompt_embeds: Optional[torch.Tensor] = None,
304
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
305
+ max_sequence_length: int = 226,
306
+ device: Optional[torch.device] = None,
307
+ dtype: Optional[torch.dtype] = None,
308
+ ):
309
+ r"""
310
+ Encodes the prompt into text encoder hidden states.
311
+
312
+ Args:
313
+ prompt (`str` or `List[str]`, *optional*):
314
+ prompt to be encoded
315
+ negative_prompt (`str` or `List[str]`, *optional*):
316
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
317
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
318
+ less than `1`).
319
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
320
+ Whether to use classifier free guidance or not.
321
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
322
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
323
+ prompt_embeds (`torch.Tensor`, *optional*):
324
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
325
+ provided, text embeddings will be generated from `prompt` input argument.
326
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
327
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
328
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
329
+ argument.
330
+ device: (`torch.device`, *optional*):
331
+ torch device
332
+ dtype: (`torch.dtype`, *optional*):
333
+ torch dtype
334
+ """
335
+ device = device or self._execution_device
336
+
337
+ prompt = [prompt] if isinstance(prompt, str) else prompt
338
+ if prompt is not None:
339
+ batch_size = len(prompt)
340
+ else:
341
+ batch_size = prompt_embeds.shape[0]
342
+
343
+ if prompt_embeds is None:
344
+ prompt_embeds = self._get_t5_prompt_embeds(
345
+ prompt=prompt,
346
+ num_videos_per_prompt=num_videos_per_prompt,
347
+ max_sequence_length=max_sequence_length,
348
+ device=device,
349
+ dtype=dtype,
350
+ )
351
+
352
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
353
+ negative_prompt = negative_prompt or ""
354
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
355
+
356
+ if prompt is not None and type(prompt) is not type(negative_prompt):
357
+ raise TypeError(
358
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
359
+ f" {type(prompt)}."
360
+ )
361
+ elif batch_size != len(negative_prompt):
362
+ raise ValueError(
363
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
+ " the batch size of `prompt`."
366
+ )
367
+
368
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
369
+ prompt=negative_prompt,
370
+ num_videos_per_prompt=num_videos_per_prompt,
371
+ max_sequence_length=max_sequence_length,
372
+ device=device,
373
+ dtype=dtype,
374
+ )
375
+
376
+ return prompt_embeds, negative_prompt_embeds
377
+
378
+ def prepare_latents(
379
+ self,
380
+ batch_size,
381
+ num_channels_latents,
382
+ height,
383
+ width,
384
+ video_length,
385
+ dtype,
386
+ device,
387
+ generator,
388
+ latents=None,
389
+ video=None,
390
+ timestep=None,
391
+ is_strength_max=True,
392
+ return_noise=False,
393
+ return_video_latents=False,
394
+ ):
395
+ shape = (
396
+ batch_size,
397
+ (video_length - 1) // self.vae_scale_factor_temporal + 1,
398
+ num_channels_latents,
399
+ height // self.vae_scale_factor_spatial,
400
+ width // self.vae_scale_factor_spatial,
401
+ )
402
+ if isinstance(generator, list) and len(generator) != batch_size:
403
+ raise ValueError(
404
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
405
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
406
+ )
407
+
408
+ if return_video_latents or (latents is None and not is_strength_max):
409
+ video = video.to(device=device, dtype=self.vae.dtype)
410
+
411
+ bs = 1
412
+ new_video = []
413
+ for i in range(0, video.shape[0], bs):
414
+ video_bs = video[i : i + bs]
415
+ video_bs = self.vae.encode(video_bs)[0]
416
+ video_bs = video_bs.sample()
417
+ new_video.append(video_bs)
418
+ video = torch.cat(new_video, dim = 0)
419
+ video = video * self.vae.config.scaling_factor
420
+
421
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
422
+ video_latents = video_latents.to(device=device, dtype=dtype)
423
+ video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
424
+
425
+ if latents is None:
426
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
427
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
428
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
429
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
430
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
431
+ else:
432
+ noise = latents.to(device)
433
+ latents = noise * self.scheduler.init_noise_sigma
434
+
435
+ # scale the initial noise by the standard deviation required by the scheduler
436
+ outputs = (latents,)
437
+
438
+ if return_noise:
439
+ outputs += (noise,)
440
+
441
+ if return_video_latents:
442
+ outputs += (video_latents,)
443
+
444
+ return outputs
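# Note (summary of the logic above): when is_strength_max (strength == 1.0) the latents start from pure noise
# scaled by scheduler.init_noise_sigma; otherwise the encoded video latents are noised to the given `timestep`
# via scheduler.add_noise, which pairs with the partial schedule produced by get_timesteps.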
445
+
446
+ def prepare_mask_latents(
447
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
448
+ ):
449
+ # resize the mask to latents shape as we concatenate the mask to the latents
450
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
451
+ # and half precision
452
+
453
+ if mask is not None:
454
+ mask = mask.to(device=device, dtype=self.vae.dtype)
455
+ bs = 1
456
+ new_mask = []
457
+ for i in range(0, mask.shape[0], bs):
458
+ mask_bs = mask[i : i + bs]
459
+ mask_bs = self.vae.encode(mask_bs)[0]
460
+ mask_bs = mask_bs.mode()
461
+ new_mask.append(mask_bs)
462
+ mask = torch.cat(new_mask, dim = 0)
463
+ mask = mask * self.vae.config.scaling_factor
464
+
465
+ if masked_image is not None:
466
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
467
+ bs = 1
468
+ new_mask_pixel_values = []
469
+ for i in range(0, masked_image.shape[0], bs):
470
+ mask_pixel_values_bs = masked_image[i : i + bs]
471
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
472
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
473
+ new_mask_pixel_values.append(mask_pixel_values_bs)
474
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
475
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
476
+ else:
477
+ masked_image_latents = None
478
+
479
+ return mask, masked_image_latents
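# Note: the mask and the masked frames are pushed through the VAE one sample at a time (bs = 1) and the
# posterior mode (`.mode()`) is used instead of sampling, so the conditioning latents are deterministic.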
480
+
481
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
482
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
483
+ latents = 1 / self.vae.config.scaling_factor * latents
484
+
485
+ frames = self.vae.decode(latents).sample
486
+ frames = (frames / 2 + 0.5).clamp(0, 1)
487
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
488
+ frames = frames.cpu().float().numpy()
489
+ return frames
490
+
491
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
492
+ def prepare_extra_step_kwargs(self, generator, eta):
493
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
494
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
495
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
496
+ # and should be between [0, 1]
497
+
498
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
499
+ extra_step_kwargs = {}
500
+ if accepts_eta:
501
+ extra_step_kwargs["eta"] = eta
502
+
503
+ # check if the scheduler accepts generator
504
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
505
+ if accepts_generator:
506
+ extra_step_kwargs["generator"] = generator
507
+ return extra_step_kwargs
508
+
509
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
510
+ def check_inputs(
511
+ self,
512
+ prompt,
513
+ height,
514
+ width,
515
+ negative_prompt,
516
+ callback_on_step_end_tensor_inputs,
517
+ prompt_embeds=None,
518
+ negative_prompt_embeds=None,
519
+ ):
520
+ if height % 8 != 0 or width % 8 != 0:
521
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
522
+
523
+ if callback_on_step_end_tensor_inputs is not None and not all(
524
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
525
+ ):
526
+ raise ValueError(
527
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
528
+ )
529
+ if prompt is not None and prompt_embeds is not None:
530
+ raise ValueError(
531
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
532
+ " only forward one of the two."
533
+ )
534
+ elif prompt is None and prompt_embeds is None:
535
+ raise ValueError(
536
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
537
+ )
538
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
539
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
540
+
541
+ if prompt is not None and negative_prompt_embeds is not None:
542
+ raise ValueError(
543
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
544
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
545
+ )
546
+
547
+ if negative_prompt is not None and negative_prompt_embeds is not None:
548
+ raise ValueError(
549
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
550
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
551
+ )
552
+
553
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
554
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
555
+ raise ValueError(
556
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
557
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
558
+ f" {negative_prompt_embeds.shape}."
559
+ )
560
+
561
+ def fuse_qkv_projections(self) -> None:
562
+ r"""Enables fused QKV projections."""
563
+ self.fusing_transformer = True
564
+ self.transformer.fuse_qkv_projections()
565
+
566
+ def unfuse_qkv_projections(self) -> None:
567
+ r"""Disable QKV projection fusion if enabled."""
568
+ if not self.fusing_transformer:
569
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
570
+ else:
571
+ self.transformer.unfuse_qkv_projections()
572
+ self.fusing_transformer = False
573
+
574
+ def _prepare_rotary_positional_embeddings(
575
+ self,
576
+ height: int,
577
+ width: int,
578
+ num_frames: int,
579
+ device: torch.device,
580
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
581
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
582
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
583
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
584
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
585
+
586
+ grid_crops_coords = get_resize_crop_region_for_grid(
587
+ (grid_height, grid_width), base_size_width, base_size_height
588
+ )
589
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
590
+ embed_dim=self.transformer.config.attention_head_dim,
591
+ crops_coords=grid_crops_coords,
592
+ grid_size=(grid_height, grid_width),
593
+ temporal_size=num_frames,
594
+ use_real=True,
595
+ )
596
+
597
+ freqs_cos = freqs_cos.to(device=device)
598
+ freqs_sin = freqs_sin.to(device=device)
599
+ return freqs_cos, freqs_sin
600
+
601
+ @property
602
+ def guidance_scale(self):
603
+ return self._guidance_scale
604
+
605
+ @property
606
+ def num_timesteps(self):
607
+ return self._num_timesteps
608
+
609
+ @property
610
+ def interrupt(self):
611
+ return self._interrupt
612
+
613
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
614
+ def get_timesteps(self, num_inference_steps, strength, device):
615
+ # get the original timestep using init_timestep
616
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
617
+
618
+ t_start = max(num_inference_steps - init_timestep, 0)
619
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
620
+
621
+ return timesteps, num_inference_steps - t_start
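# Worked example: with num_inference_steps = 50 and strength = 0.6, init_timestep = 30 and t_start = 20, so
# denoising runs over the last 30 of the 50 scheduled timesteps; strength = 1.0 keeps the full schedule
# (pure-noise start).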
622
+
623
+ @torch.no_grad()
624
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
625
+ def __call__(
626
+ self,
627
+ prompt: Optional[Union[str, List[str]]] = None,
628
+ negative_prompt: Optional[Union[str, List[str]]] = None,
629
+ height: int = 480,
630
+ width: int = 720,
631
+ video: Union[torch.FloatTensor] = None,
632
+ mask_video: Union[torch.FloatTensor] = None,
633
+ masked_video_latents: Union[torch.FloatTensor] = None,
634
+ num_frames: int = 49,
635
+ num_inference_steps: int = 50,
636
+ timesteps: Optional[List[int]] = None,
637
+ guidance_scale: float = 6,
638
+ use_dynamic_cfg: bool = False,
639
+ num_videos_per_prompt: int = 1,
640
+ eta: float = 0.0,
641
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
642
+ latents: Optional[torch.FloatTensor] = None,
643
+ prompt_embeds: Optional[torch.FloatTensor] = None,
644
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
645
+ output_type: str = "numpy",
646
+ return_dict: bool = False,
647
+ callback_on_step_end: Optional[
648
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
649
+ ] = None,
650
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
651
+ max_sequence_length: int = 226,
652
+ strength: float = 1,
653
+ comfyui_progressbar: bool = False,
654
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
655
+ """
656
+ Function invoked when calling the pipeline for generation.
657
+
658
+ Args:
659
+ prompt (`str` or `List[str]`, *optional*):
660
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
661
+ instead.
662
+ negative_prompt (`str` or `List[str]`, *optional*):
663
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
664
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
665
+ less than `1`).
666
+ height (`int`, *optional*, defaults to `480`):
667
+ The height in pixels of the generated video.
668
+ width (`int`, *optional*, defaults to `720`):
669
+ The width in pixels of the generated video.
+ video (`torch.FloatTensor`, *optional*):
+ The input video of shape `(batch, channels, frames, height, width)` to be inpainted.
+ mask_video (`torch.FloatTensor`, *optional*):
+ Mask video with the same layout; white (non-zero) regions are regenerated while the rest is kept from `video`.
670
+ num_frames (`int`, defaults to `49`):
671
+ Number of frames to generate. `num_frames - 1` must be divisible by self.vae_scale_factor_temporal. The
672
+ generated video contains 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1)
673
+ frames where num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only
674
+ condition that needs to be satisfied is the divisibility mentioned above.
675
+ num_inference_steps (`int`, *optional*, defaults to 50):
676
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
677
+ expense of slower inference.
678
+ timesteps (`List[int]`, *optional*):
679
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
680
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
681
+ passed will be used. Must be in descending order.
682
+ guidance_scale (`float`, *optional*, defaults to 6):
683
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
684
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
685
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
686
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
687
+ usually at the expense of lower image quality.
688
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
689
+ The number of videos to generate per prompt.
690
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
691
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
692
+ to make generation deterministic.
693
+ latents (`torch.FloatTensor`, *optional*):
694
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
695
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
696
+ tensor will be generated by sampling using the supplied random `generator`.
697
+ prompt_embeds (`torch.FloatTensor`, *optional*):
698
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
699
+ provided, text embeddings will be generated from `prompt` input argument.
700
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
701
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
702
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
703
+ argument.
704
+ output_type (`str`, *optional*, defaults to `"numpy"`):
705
+ The output format of the generated video. Choose between
706
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
707
+ return_dict (`bool`, *optional*, defaults to `False`):
708
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] instead
709
+ of a plain tuple.
710
+ callback_on_step_end (`Callable`, *optional*):
711
+ A function that is called at the end of each denoising step during inference. The function is called
712
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
713
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
714
+ `callback_on_step_end_tensor_inputs`.
715
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
716
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
717
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
718
+ `._callback_tensor_inputs` attribute of your pipeline class.
719
+ max_sequence_length (`int`, defaults to `226`):
720
+ Maximum sequence length in encoded prompt. Must be consistent with
721
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
722
+
723
+ Examples:
724
+
725
+ Returns:
726
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
727
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
728
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
729
+ """
730
+
731
+ if num_frames > 49:
732
+ raise ValueError(
733
+ "The number of frames must not exceed 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
734
+ )
735
+
736
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
737
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
738
+
739
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
740
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
741
+ num_videos_per_prompt = 1
742
+
743
+ # 1. Check inputs. Raise error if not correct
744
+ self.check_inputs(
745
+ prompt,
746
+ height,
747
+ width,
748
+ negative_prompt,
749
+ callback_on_step_end_tensor_inputs,
750
+ prompt_embeds,
751
+ negative_prompt_embeds,
752
+ )
753
+ self._guidance_scale = guidance_scale
754
+ self._interrupt = False
755
+
756
+ # 2. Default call parameters
757
+ if prompt is not None and isinstance(prompt, str):
758
+ batch_size = 1
759
+ elif prompt is not None and isinstance(prompt, list):
760
+ batch_size = len(prompt)
761
+ else:
762
+ batch_size = prompt_embeds.shape[0]
763
+
764
+ device = self._execution_device
765
+
766
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
767
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
768
+ # corresponds to doing no classifier free guidance.
769
+ do_classifier_free_guidance = guidance_scale > 1.0
770
+
771
+ # 3. Encode input prompt
772
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
773
+ prompt,
774
+ negative_prompt,
775
+ do_classifier_free_guidance,
776
+ num_videos_per_prompt=num_videos_per_prompt,
777
+ prompt_embeds=prompt_embeds,
778
+ negative_prompt_embeds=negative_prompt_embeds,
779
+ max_sequence_length=max_sequence_length,
780
+ device=device,
781
+ )
782
+ if do_classifier_free_guidance:
783
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
784
+
785
+ # 4. set timesteps
786
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
787
+ timesteps, num_inference_steps = self.get_timesteps(
788
+ num_inference_steps=num_inference_steps, strength=strength, device=device
789
+ )
790
+ self._num_timesteps = len(timesteps)
791
+ if comfyui_progressbar:
792
+ from comfy.utils import ProgressBar
793
+ pbar = ProgressBar(num_inference_steps + 2)
794
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
795
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
796
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
797
+ is_strength_max = strength == 1.0
798
+
799
+ # 5. Prepare latents.
800
+ if video is not None:
801
+ video_length = video.shape[2]
802
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
803
+ init_video = init_video.to(dtype=torch.float32)
804
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
805
+ else:
806
+ init_video = None
+ video_length = num_frames
807
+
808
+ num_channels_latents = self.vae.config.latent_channels
809
+ num_channels_transformer = self.transformer.config.in_channels
810
+ return_image_latents = num_channels_transformer == num_channels_latents
811
+
812
+ latents_outputs = self.prepare_latents(
813
+ batch_size * num_videos_per_prompt,
814
+ num_channels_latents,
815
+ height,
816
+ width,
817
+ video_length,
818
+ prompt_embeds.dtype,
819
+ device,
820
+ generator,
821
+ latents,
822
+ video=init_video,
823
+ timestep=latent_timestep,
824
+ is_strength_max=is_strength_max,
825
+ return_noise=True,
826
+ return_video_latents=return_image_latents,
827
+ )
828
+ if return_image_latents:
829
+ latents, noise, image_latents = latents_outputs
830
+ else:
831
+ latents, noise = latents_outputs
832
+ if comfyui_progressbar:
833
+ pbar.update(1)
834
+
835
+ if mask_video is not None:
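# A mask that is entirely 255 (fully white) means the whole clip is regenerated: zero mask and
# masked-video latents are used as the inpaint conditioning instead of encoding the input video.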
836
+ if (mask_video == 255).all():
837
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
838
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
839
+
840
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
841
+ masked_video_latents_input = (
842
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
843
+ )
844
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
845
+ else:
846
+ # Prepare mask latent variables
847
+ video_length = video.shape[2]
848
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
849
+ mask_condition = mask_condition.to(dtype=torch.float32)
850
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
851
+
852
+ if num_channels_transformer != num_channels_latents:
853
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
854
+ if masked_video_latents is None:
855
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
856
+ else:
857
+ masked_video = masked_video_latents
858
+
859
+ _, masked_video_latents = self.prepare_mask_latents(
860
+ None,
861
+ masked_video,
862
+ batch_size,
863
+ height,
864
+ width,
865
+ prompt_embeds.dtype,
866
+ device,
867
+ generator,
868
+ do_classifier_free_guidance,
869
+ )
870
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
871
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
872
+
873
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
874
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
875
+
876
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
877
+ masked_video_latents_input = (
878
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
879
+ )
880
+
881
+ mask = rearrange(mask, "b c f h w -> b f c h w")
882
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
883
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
884
+
885
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
886
+ else:
887
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
888
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+ mask = rearrange(mask, "b c f h w -> b f c h w")
+
+ inpaint_latents = None
+ else:
+ if num_channels_transformer != num_channels_latents:
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
+ else:
+ mask = torch.zeros_like(init_video[:, :1])
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+ mask = rearrange(mask, "b c f h w -> b f c h w")
+
+ inpaint_latents = None
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ inpaint_latents=inpaint_latents,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif output_type != "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return CogVideoX_Fun_PipelineOutput(videos=video)
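For context on how this inpaint pipeline is meant to be driven, here is a minimal, hedged sketch (not part of this commit): the checkpoint path, start image, dtype, and the choice to let from_pretrained load the remaining components (tokenizer, text encoder, scheduler) are placeholder assumptions; the call signature mirrors the UI code added below.

# Sketch only; assumes a CUDA device and a checkpoint directory laid out like models/Diffusion_Transformer.
import torch
from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid

model_path = "models/Diffusion_Transformer/your-cogvideox-fun-checkpoint"  # placeholder path
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae").to(torch.bfloat16)
transformer = CogVideoXTransformer3DModel.from_pretrained_2d(model_path, subfolder="transformer").to(torch.bfloat16)
pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
    model_path, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()

# Build the video/mask conditioning the same way the Gradio UI does, then sample.
input_video, input_video_mask, _ = get_image_to_video_latent(
    "start.png", None, video_length=49, sample_size=(384, 672)  # placeholder start image
)
sample = pipeline(
    "a sailboat sailing in rough seas with a dramatic sunset",
    negative_prompt="",
    num_inference_steps=50,
    guidance_scale=7.0,
    width=672, height=384, num_frames=49,
    generator=torch.Generator(device="cuda").manual_seed(43),
    video=input_video, mask_video=input_video_mask, strength=1,
).videos
save_videos_grid(sample, "sample.mp4", fps=8)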
cogvideox/ui/ui.py ADDED
@@ -0,0 +1,1403 @@
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import base64
4
+ import gc
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from glob import glob
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import pkg_resources
15
+ import requests
16
+ import torch
17
+ from diffusers import (AutoencoderKL, AutoencoderKLCogVideoX,
18
+ CogVideoXDDIMScheduler, DDIMScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
21
+ PNDMScheduler)
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from omegaconf import OmegaConf
24
+ from PIL import Image
25
+ from safetensors import safe_open
26
+ from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
27
+ T5EncoderModel, T5Tokenizer)
28
+
29
+ from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
30
+ from ..models.autoencoder_magvit import AutoencoderKLCogVideoX
31
+ from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
32
+ from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
33
+ from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
34
+ CogVideoX_Fun_Pipeline_Inpaint
35
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
36
+ from cogvideox.utils.utils import (
37
+ get_image_to_video_latent, get_video_to_video_latent,
38
+ get_width_and_height_from_image_and_base_resolution, save_videos_grid)
39
+
40
+ scheduler_dict = {
41
+ "Euler": EulerDiscreteScheduler,
42
+ "Euler A": EulerAncestralDiscreteScheduler,
43
+ "DPM++": DPMSolverMultistepScheduler,
44
+ "PNDM": PNDMScheduler,
45
+ "DDIM_Cog": CogVideoXDDIMScheduler,
46
+ "DDIM_Origin": DDIMScheduler,
47
+ }
48
+
49
+ gradio_version = pkg_resources.get_distribution("gradio").version
50
+ gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
51
+
52
+ css = """
53
+ .toolbutton {
54
+ margin-buttom: 0em 0em 0em 0em;
55
+ max-width: 2.5em;
56
+ min-width: 2.5em !important;
57
+ height: 2.5em;
58
+ }
59
+ """
60
+
61
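As a quick illustration of how scheduler_dict is used further down: the controllers swap samplers at run time by rebuilding the scheduler from the current one's config. Sketch only; `pipeline` stands for a loaded CogVideoX_Fun pipeline.

# Map the sampler label chosen in the UI to a diffusers scheduler class and rebuild it
# from the existing config so the checkpoint's timestep settings carry over.
sampler_name = "DPM++"  # any key of scheduler_dict
pipeline.scheduler = scheduler_dict[sampler_name].from_config(pipeline.scheduler.config)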
+ class CogVideoX_I2VController:
62
+ def __init__(self, low_gpu_memory_mode, weight_dtype):
63
+ # config dirs
64
+ self.basedir = os.getcwd()
65
+ self.config_dir = os.path.join(self.basedir, "config")
66
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
67
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
68
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
69
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
70
+ self.savedir_sample = os.path.join(self.savedir, "sample")
71
+ os.makedirs(self.savedir, exist_ok=True)
72
+
73
+ self.diffusion_transformer_list = []
74
+ self.motion_module_list = []
75
+ self.personalized_model_list = []
76
+
77
+ self.refresh_diffusion_transformer()
78
+ self.refresh_motion_module()
79
+ self.refresh_personalized_model()
80
+
81
+ # config models
82
+ self.tokenizer = None
83
+ self.text_encoder = None
84
+ self.vae = None
85
+ self.transformer = None
86
+ self.pipeline = None
87
+ self.motion_module_path = "none"
88
+ self.base_model_path = "none"
89
+ self.lora_model_path = "none"
90
+ self.low_gpu_memory_mode = low_gpu_memory_mode
91
+
92
+ self.weight_dtype = weight_dtype
93
+
94
+ def refresh_diffusion_transformer(self):
95
+ self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
96
+
97
+ def refresh_motion_module(self):
98
+ motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
99
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
100
+
101
+ def refresh_personalized_model(self):
102
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
103
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
104
+
105
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
106
+ print("Update diffusion transformer")
107
+ if diffusion_transformer_dropdown == "none":
108
+ return gr.update()
109
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
110
+ diffusion_transformer_dropdown,
111
+ subfolder="vae",
112
+ ).to(self.weight_dtype)
113
+
114
+ # Get Transformer
115
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
116
+ diffusion_transformer_dropdown,
117
+ subfolder="transformer",
118
+ ).to(self.weight_dtype)
119
+
120
+ # Get pipeline
121
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
122
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
123
+ diffusion_transformer_dropdown,
124
+ vae=self.vae,
125
+ transformer=self.transformer,
126
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
127
+ torch_dtype=self.weight_dtype
128
+ )
129
+ else:
130
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
131
+ diffusion_transformer_dropdown,
132
+ vae=self.vae,
133
+ transformer=self.transformer,
134
+ scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
135
+ torch_dtype=self.weight_dtype
136
+ )
137
+
138
+ if self.low_gpu_memory_mode:
139
+ self.pipeline.enable_sequential_cpu_offload()
140
+ else:
141
+ self.pipeline.enable_model_cpu_offload()
142
+ print("Update diffusion transformer done")
143
+ return gr.update()
144
+
145
+ def update_base_model(self, base_model_dropdown):
146
+ self.base_model_path = base_model_dropdown
147
+ print("Update base model")
148
+ if base_model_dropdown == "none":
149
+ return gr.update()
150
+ if self.transformer is None:
151
+ gr.Info(f"Please select a pretrained model path.")
152
+ return gr.update(value=None)
153
+ else:
154
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
155
+ base_model_state_dict = {}
156
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
157
+ for key in f.keys():
158
+ base_model_state_dict[key] = f.get_tensor(key)
159
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
160
+ print("Update base done")
161
+ return gr.update()
162
+
163
+ def update_lora_model(self, lora_model_dropdown):
164
+ print("Update lora model")
165
+ if lora_model_dropdown == "none":
166
+ self.lora_model_path = "none"
167
+ return gr.update()
168
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
169
+ self.lora_model_path = lora_model_dropdown
170
+ return gr.update()
171
+
172
+ def generate(
173
+ self,
174
+ diffusion_transformer_dropdown,
175
+ base_model_dropdown,
176
+ lora_model_dropdown,
177
+ lora_alpha_slider,
178
+ prompt_textbox,
179
+ negative_prompt_textbox,
180
+ sampler_dropdown,
181
+ sample_step_slider,
182
+ resize_method,
183
+ width_slider,
184
+ height_slider,
185
+ base_resolution,
186
+ generation_method,
187
+ length_slider,
188
+ overlap_video_length,
189
+ partial_video_length,
190
+ cfg_scale_slider,
191
+ start_image,
192
+ end_image,
193
+ validation_video,
194
+ denoise_strength,
195
+ seed_textbox,
196
+ is_api = False,
197
+ ):
198
+ gc.collect()
199
+ torch.cuda.empty_cache()
200
+ torch.cuda.ipc_collect()
201
+
202
+ if self.transformer is None:
203
+ raise gr.Error(f"Please select a pretrained model path.")
204
+
205
+ if self.base_model_path != base_model_dropdown:
206
+ self.update_base_model(base_model_dropdown)
207
+
208
+ if self.lora_model_path != lora_model_dropdown:
209
+ print("Update lora model")
210
+ self.update_lora_model(lora_model_dropdown)
211
+
212
+ if resize_method == "Resize according to Reference":
213
+ if start_image is None and validation_video is None:
214
+ if is_api:
215
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
216
+ else:
217
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
218
+
219
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
220
+
221
+ if validation_video is not None:
222
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
223
+ else:
224
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
225
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
226
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
227
+
228
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
229
+ if is_api:
230
+ return "", f"Please select an image to video pretrained model while using image to video."
231
+ else:
232
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
233
+
234
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
235
+ if is_api:
236
+ return "", f"Please select an image to video pretrained model while using long video generation."
237
+ else:
238
+ raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
239
+
240
+ if start_image is None and end_image is not None:
241
+ if is_api:
242
+ return "", f"If specifying the ending image of the video, please specify a starting image of the video."
243
+ else:
244
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
245
+
246
+ is_image = True if generation_method == "Image Generation" else False
247
+
248
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
249
+ if self.lora_model_path != "none":
250
+ # lora part
251
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
252
+
253
+ if seed_textbox != "" and int(seed_textbox) != -1: torch.manual_seed(int(seed_textbox))
+ else: seed_textbox = np.random.randint(0, 1e10)
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
256
+
257
+ try:
258
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
259
+ if generation_method == "Long Video Generation":
260
+ if validation_video is not None:
261
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
262
+ init_frames = 0
263
+ last_frames = init_frames + partial_video_length
264
+ while init_frames < length_slider:
265
+ if last_frames >= length_slider:
266
+ _partial_video_length = length_slider - init_frames
267
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
268
+
269
+ if _partial_video_length <= 0:
270
+ break
271
+ else:
272
+ _partial_video_length = partial_video_length
273
+
274
+ if last_frames >= length_slider:
275
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
276
+ else:
277
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
278
+
279
+ with torch.no_grad():
280
+ sample = self.pipeline(
281
+ prompt_textbox,
282
+ negative_prompt = negative_prompt_textbox,
283
+ num_inference_steps = sample_step_slider,
284
+ guidance_scale = cfg_scale_slider,
285
+ width = width_slider,
286
+ height = height_slider,
287
+ num_frames = _partial_video_length,
288
+ generator = generator,
289
+
290
+ video = input_video,
291
+ mask_video = input_video_mask,
292
+ strength = 1,
293
+ ).videos
294
+
295
+ if init_frames != 0:
296
+ mix_ratio = torch.from_numpy(
297
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
298
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
299
+
300
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
301
+ sample[:, :, :overlap_video_length] * mix_ratio
302
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
303
+
304
+ sample = new_sample
305
+ else:
306
+ new_sample = sample
307
+
308
+ if last_frames >= length_slider:
309
+ break
310
+
311
+ start_image = [
312
+ Image.fromarray(
313
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
314
+ ) for _index in range(-overlap_video_length, 0)
315
+ ]
316
+
317
+ init_frames = init_frames + _partial_video_length - overlap_video_length
318
+ last_frames = init_frames + _partial_video_length
319
+ else:
320
+ if validation_video is not None:
321
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
322
+ strength = denoise_strength
323
+ else:
324
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
325
+ strength = 1
326
+
327
+ sample = self.pipeline(
328
+ prompt_textbox,
329
+ negative_prompt = negative_prompt_textbox,
330
+ num_inference_steps = sample_step_slider,
331
+ guidance_scale = cfg_scale_slider,
332
+ width = width_slider,
333
+ height = height_slider,
334
+ num_frames = length_slider if not is_image else 1,
335
+ generator = generator,
336
+
337
+ video = input_video,
338
+ mask_video = input_video_mask,
339
+ strength = strength,
340
+ ).videos
341
+ else:
342
+ sample = self.pipeline(
343
+ prompt_textbox,
344
+ negative_prompt = negative_prompt_textbox,
345
+ num_inference_steps = sample_step_slider,
346
+ guidance_scale = cfg_scale_slider,
347
+ width = width_slider,
348
+ height = height_slider,
349
+ num_frames = length_slider if not is_image else 1,
350
+ generator = generator
351
+ ).videos
352
+ except Exception as e:
353
+ gc.collect()
354
+ torch.cuda.empty_cache()
355
+ torch.cuda.ipc_collect()
356
+ if self.lora_model_path != "none":
357
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
358
+ if is_api:
359
+ return "", f"Error. error information is {str(e)}"
360
+ else:
361
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
362
+
363
+ gc.collect()
364
+ torch.cuda.empty_cache()
365
+ torch.cuda.ipc_collect()
366
+
367
+ # lora part
368
+ if self.lora_model_path != "none":
369
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
370
+
371
+ sample_config = {
372
+ "prompt": prompt_textbox,
373
+ "n_prompt": negative_prompt_textbox,
374
+ "sampler": sampler_dropdown,
375
+ "num_inference_steps": sample_step_slider,
376
+ "guidance_scale": cfg_scale_slider,
377
+ "width": width_slider,
378
+ "height": height_slider,
379
+ "video_length": length_slider,
380
+ "seed_textbox": seed_textbox
381
+ }
382
+ json_str = json.dumps(sample_config, indent=4)
383
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
384
+ f.write(json_str)
385
+ f.write("\n\n")
386
+
387
+ if not os.path.exists(self.savedir_sample):
388
+ os.makedirs(self.savedir_sample, exist_ok=True)
389
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
390
+ prefix = str(index).zfill(3)
391
+
392
+ gc.collect()
393
+ torch.cuda.empty_cache()
394
+ torch.cuda.ipc_collect()
395
+ if is_image or length_slider == 1:
396
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
397
+
398
+ image = sample[0, :, 0]
399
+ image = image.transpose(0, 1).transpose(1, 2)
400
+ image = (image * 255).numpy().astype(np.uint8)
401
+ image = Image.fromarray(image)
402
+ image.save(save_sample_path)
403
+
404
+ if is_api:
405
+ return save_sample_path, "Success"
406
+ else:
407
+ if gradio_version_is_above_4:
408
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
409
+ else:
410
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
411
+ else:
412
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
413
+ save_videos_grid(sample, save_sample_path, fps=8)
414
+
415
+ if is_api:
416
+ return save_sample_path, "Success"
417
+ else:
418
+ if gradio_version_is_above_4:
419
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
420
+ else:
421
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
422
+
423
+
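The long-video branch of generate() above stitches consecutive partial clips by linearly cross-fading their overlapping frames; the sketch below isolates that blend with hypothetical tensors so the indexing is easier to follow.

# Sketch of the overlap blend used in the long-video loop above (hypothetical shapes).
import numpy as np
import torch

overlap_video_length = 4
prev_clip = torch.rand(1, 3, 25, 384, 672)   # frames generated so far (B, C, F, H, W)
next_clip = torch.rand(1, 3, 25, 384, 672)   # newly generated partial clip

# Ramp from 0 to 1 across the overlapping frames, broadcastable over (B, C, F, H, W).
mix_ratio = torch.from_numpy(
    np.array([i / overlap_video_length for i in range(overlap_video_length)], np.float32)
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

# Cross-fade the tail of the previous clip into the head of the new one, then append the rest.
prev_clip[:, :, -overlap_video_length:] = (
    prev_clip[:, :, -overlap_video_length:] * (1 - mix_ratio)
    + next_clip[:, :, :overlap_video_length] * mix_ratio
)
stitched = torch.cat([prev_clip, next_clip[:, :, overlap_video_length:]], dim=2)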
424
+ def ui(low_gpu_memory_mode, weight_dtype):
425
+ controller = CogVideoX_I2VController(low_gpu_memory_mode, weight_dtype)
426
+
427
+ with gr.Blocks(css=css) as demo:
428
+ gr.Markdown(
429
+ """
430
+ # CogVideoX-Fun:
431
+
432
+ A CogVideoX variant with more flexible generation conditions: it can produce videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), and can also generate videos from images.
433
+
434
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
435
+ """
436
+ )
437
+ with gr.Column(variant="panel"):
438
+ gr.Markdown(
439
+ """
440
+ ### 1. Model checkpoints (模型路径).
441
+ """
442
+ )
443
+ with gr.Row():
444
+ diffusion_transformer_dropdown = gr.Dropdown(
445
+ label="Pretrained Model Path (预训练模型路径)",
446
+ choices=controller.diffusion_transformer_list,
447
+ value="none",
448
+ interactive=True,
449
+ )
450
+ diffusion_transformer_dropdown.change(
451
+ fn=controller.update_diffusion_transformer,
452
+ inputs=[diffusion_transformer_dropdown],
453
+ outputs=[diffusion_transformer_dropdown]
454
+ )
455
+
456
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
457
+ def refresh_diffusion_transformer():
458
+ controller.refresh_diffusion_transformer()
459
+ return gr.update(choices=controller.diffusion_transformer_list)
460
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
461
+
462
+ with gr.Row():
463
+ base_model_dropdown = gr.Dropdown(
464
+ label="Select base Dreambooth model (选择基模型[非必需])",
465
+ choices=controller.personalized_model_list,
466
+ value="none",
467
+ interactive=True,
468
+ )
469
+
470
+ lora_model_dropdown = gr.Dropdown(
471
+ label="Select LoRA model (选择LoRA模型[非必需])",
472
+ choices=["none"] + controller.personalized_model_list,
473
+ value="none",
474
+ interactive=True,
475
+ )
476
+
477
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
478
+
479
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
480
+ def update_personalized_model():
481
+ controller.refresh_personalized_model()
482
+ return [
483
+ gr.update(choices=controller.personalized_model_list),
484
+ gr.update(choices=["none"] + controller.personalized_model_list)
485
+ ]
486
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
487
+
488
+ with gr.Column(variant="panel"):
489
+ gr.Markdown(
490
+ """
491
+ ### 2. Configs for Generation (生成参数配置).
492
+ """
493
+ )
494
+
495
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
496
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
497
+
498
+ with gr.Row():
499
+ with gr.Column():
500
+ with gr.Row():
501
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
502
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
503
+
504
+ resize_method = gr.Radio(
505
+ ["Generate by", "Resize according to Reference"],
506
+ value="Generate by",
507
+ show_label=False,
508
+ )
509
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1344, step=16)
510
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1344, step=16)
511
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
512
+
513
+ with gr.Group():
514
+ generation_method = gr.Radio(
515
+ ["Video Generation", "Image Generation", "Long Video Generation"],
516
+ value="Video Generation",
517
+ show_label=False,
518
+ )
519
+ with gr.Row():
520
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=1, maximum=49, step=4)
521
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
522
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
523
+
524
+ source_method = gr.Radio(
525
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
526
+ value="Text to Video (文本到视频)",
527
+ show_label=False,
528
+ )
529
+ with gr.Column(visible = False) as image_to_video_col:
530
+ start_image = gr.Image(
531
+ label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True,
532
+ elem_id="i2v_start", sources="upload", type="filepath",
533
+ )
534
+
535
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
536
+ def select_template(evt: gr.SelectData):
537
+ text = {
538
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
539
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
540
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
541
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
542
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
543
+ }[template_gallery_path[evt.index]]
544
+ return template_gallery_path[evt.index], text
545
+
546
+ template_gallery = gr.Gallery(
547
+ template_gallery_path,
548
+ columns=5, rows=1,
549
+ height=140,
550
+ allow_preview=False,
551
+ container=False,
552
+ label="Template Examples",
553
+ )
554
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
555
+
556
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
557
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
558
+
559
+ with gr.Column(visible = False) as video_to_video_col:
560
+ validation_video = gr.Video(
561
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
562
+ elem_id="v2v", sources="upload",
563
+ )
564
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
565
+
566
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
567
+
568
+ with gr.Row():
569
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
570
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
571
+ seed_button.click(
572
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
573
+ inputs=[],
574
+ outputs=[seed_textbox]
575
+ )
576
+
577
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
578
+
579
+ with gr.Column():
580
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
581
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
582
+ infer_progress = gr.Textbox(
583
+ label="Generation Info (生成信息)",
584
+ value="No task currently",
585
+ interactive=False
586
+ )
587
+
588
+ def upload_generation_method(generation_method):
589
+ if generation_method == "Video Generation":
590
+ return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
591
+ elif generation_method == "Image Generation":
592
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
593
+ else:
594
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
595
+ generation_method.change(
596
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
597
+ )
598
+
599
+ def upload_source_method(source_method):
600
+ if source_method == "Text to Video (文本到视频)":
601
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
602
+ elif source_method == "Image to Video (图片到视频)":
603
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
604
+ else:
605
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
606
+ source_method.change(
607
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
608
+ )
609
+
610
+ def upload_resize_method(resize_method):
611
+ if resize_method == "Generate by":
612
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
613
+ else:
614
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
615
+ resize_method.change(
616
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
617
+ )
618
+
619
+ generate_button.click(
620
+ fn=controller.generate,
621
+ inputs=[
622
+ diffusion_transformer_dropdown,
623
+ base_model_dropdown,
624
+ lora_model_dropdown,
625
+ lora_alpha_slider,
626
+ prompt_textbox,
627
+ negative_prompt_textbox,
628
+ sampler_dropdown,
629
+ sample_step_slider,
630
+ resize_method,
631
+ width_slider,
632
+ height_slider,
633
+ base_resolution,
634
+ generation_method,
635
+ length_slider,
636
+ overlap_video_length,
637
+ partial_video_length,
638
+ cfg_scale_slider,
639
+ start_image,
640
+ end_image,
641
+ validation_video,
642
+ denoise_strength,
643
+ seed_textbox,
644
+ ],
645
+ outputs=[result_image, result_video, infer_progress]
646
+ )
647
+ return demo, controller
648
+
649
+
650
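A minimal sketch of how the ui() builder above might be launched from an entry script; the script layout and launch flags are assumptions, not part of this commit.

# Hypothetical entry script for the local Gradio UI defined above.
import torch
from cogvideox.ui.ui import ui

if __name__ == "__main__":
    demo, controller = ui(low_gpu_memory_mode=False, weight_dtype=torch.bfloat16)
    demo.queue()  # serialize GPU-bound requests
    demo.launch(server_name="0.0.0.0", server_port=7860)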
+ class CogVideoX_I2VController_Modelscope:
651
+ def __init__(self, model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
652
+ # Basic dir
653
+ self.basedir = os.getcwd()
654
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
655
+ self.lora_model_path = "none"
656
+ self.savedir_sample = savedir_sample
657
+ self.refresh_personalized_model()
658
+ os.makedirs(self.savedir_sample, exist_ok=True)
659
+
660
+ # model path
661
+ self.weight_dtype = weight_dtype
662
+
663
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
664
+ model_name,
665
+ subfolder="vae",
666
+ ).to(self.weight_dtype)
667
+
668
+ # Get Transformer
669
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
670
+ model_name,
671
+ subfolder="transformer",
672
+ ).to(self.weight_dtype)
673
+
674
+ # Get pipeline
675
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
676
+ self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
677
+ model_name,
678
+ vae=self.vae,
679
+ transformer=self.transformer,
680
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
681
+ torch_dtype=self.weight_dtype
682
+ )
683
+ else:
684
+ self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
685
+ model_name,
686
+ vae=self.vae,
687
+ transformer=self.transformer,
688
+ scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
689
+ torch_dtype=self.weight_dtype
690
+ )
691
+
692
+ if low_gpu_memory_mode:
693
+ self.pipeline.enable_sequential_cpu_offload()
694
+ else:
695
+ self.pipeline.enable_model_cpu_offload()
696
+ print("Update diffusion transformer done")
697
+
698
+
699
+ def refresh_personalized_model(self):
700
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
701
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
702
+
703
+
704
+ def update_lora_model(self, lora_model_dropdown):
705
+ print("Update lora model")
706
+ if lora_model_dropdown == "none":
707
+ self.lora_model_path = "none"
708
+ return gr.update()
709
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
710
+ self.lora_model_path = lora_model_dropdown
711
+ return gr.update()
712
+
713
+
714
+ def generate(
715
+ self,
716
+ diffusion_transformer_dropdown,
717
+ base_model_dropdown,
718
+ lora_model_dropdown,
719
+ lora_alpha_slider,
720
+ prompt_textbox,
721
+ negative_prompt_textbox,
722
+ sampler_dropdown,
723
+ sample_step_slider,
724
+ resize_method,
725
+ width_slider,
726
+ height_slider,
727
+ base_resolution,
728
+ generation_method,
729
+ length_slider,
730
+ overlap_video_length,
731
+ partial_video_length,
732
+ cfg_scale_slider,
733
+ start_image,
734
+ end_image,
735
+ validation_video,
736
+ denoise_strength,
737
+ seed_textbox,
738
+ is_api = False,
739
+ ):
740
+ gc.collect()
741
+ torch.cuda.empty_cache()
742
+ torch.cuda.ipc_collect()
743
+
744
+ if self.transformer is None:
745
+ raise gr.Error(f"Please select a pretrained model path.")
746
+
747
+ if self.lora_model_path != lora_model_dropdown:
748
+ print("Update lora model")
749
+ self.update_lora_model(lora_model_dropdown)
750
+
751
+ if resize_method == "Resize according to Reference":
752
+ if start_image is None and validation_video is None:
753
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
754
+
755
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
756
+
757
+ if validation_video is not None:
758
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
759
+ else:
760
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
761
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
762
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
763
+
764
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
765
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
766
+
767
+ if start_image is None and end_image is not None:
768
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
769
+
770
+ is_image = True if generation_method == "Image Generation" else False
771
+
772
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
773
+ if self.lora_model_path != "none":
774
+ # lora part
775
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
776
+
777
+ if seed_textbox != "" and int(seed_textbox) != -1: torch.manual_seed(int(seed_textbox))
+ else: seed_textbox = np.random.randint(0, 1e10)
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
780
+
781
+ try:
782
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
783
+ if validation_video is not None:
784
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
785
+ strength = denoise_strength
786
+ else:
787
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
788
+ strength = 1
789
+
790
+ sample = self.pipeline(
791
+ prompt_textbox,
792
+ negative_prompt = negative_prompt_textbox,
793
+ num_inference_steps = sample_step_slider,
794
+ guidance_scale = cfg_scale_slider,
795
+ width = width_slider,
796
+ height = height_slider,
797
+ num_frames = length_slider if not is_image else 1,
798
+ generator = generator,
799
+
800
+ video = input_video,
801
+ mask_video = input_video_mask,
802
+ strength = strength,
803
+ ).videos
804
+ else:
805
+ sample = self.pipeline(
806
+ prompt_textbox,
807
+ negative_prompt = negative_prompt_textbox,
808
+ num_inference_steps = sample_step_slider,
809
+ guidance_scale = cfg_scale_slider,
810
+ width = width_slider,
811
+ height = height_slider,
812
+ num_frames = length_slider if not is_image else 1,
813
+ generator = generator
814
+ ).videos
815
+ except Exception as e:
816
+ gc.collect()
817
+ torch.cuda.empty_cache()
818
+ torch.cuda.ipc_collect()
819
+ if self.lora_model_path != "none":
820
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
821
+ if is_api:
822
+ return "", f"Error. error information is {str(e)}"
823
+ else:
824
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
825
+
826
+ gc.collect()
827
+ torch.cuda.empty_cache()
828
+ torch.cuda.ipc_collect()
829
+
830
+ # lora part
831
+ if self.lora_model_path != "none":
832
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
833
+
834
+ if not os.path.exists(self.savedir_sample):
835
+ os.makedirs(self.savedir_sample, exist_ok=True)
836
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
837
+ prefix = str(index).zfill(3)
838
+
839
+ gc.collect()
840
+ torch.cuda.empty_cache()
841
+ torch.cuda.ipc_collect()
842
+ if is_image or length_slider == 1:
843
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
844
+
845
+ image = sample[0, :, 0]
846
+ image = image.transpose(0, 1).transpose(1, 2)
847
+ image = (image * 255).numpy().astype(np.uint8)
848
+ image = Image.fromarray(image)
849
+ image.save(save_sample_path)
850
+ if is_api:
851
+ return save_sample_path, "Success"
852
+ else:
853
+ if gradio_version_is_above_4:
854
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
855
+ else:
856
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
857
+ else:
858
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
859
+ save_videos_grid(sample, save_sample_path, fps=8)
860
+ if is_api:
861
+ return save_sample_path, "Success"
862
+ else:
863
+ if gradio_version_is_above_4:
864
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
865
+ else:
866
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
867
+
868
+
869
+ def ui_modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype):
870
+ controller = CogVideoX_I2VController_Modelscope(model_name, savedir_sample, low_gpu_memory_mode, weight_dtype)
871
+
872
+ with gr.Blocks(css=css) as demo:
873
+ gr.Markdown(
874
+ """
875
+ # CogVideoX-Fun
876
+
877
+ A CogVideoX variant with more flexible generation conditions: it can produce videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), and can also generate videos from images.
878
+
879
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
880
+ """
881
+ )
882
+ with gr.Column(variant="panel"):
883
+ gr.Markdown(
884
+ """
885
+ ### 1. Model checkpoints (模型路径).
886
+ """
887
+ )
888
+ with gr.Row():
889
+ diffusion_transformer_dropdown = gr.Dropdown(
890
+ label="Pretrained Model Path (预训练模型路径)",
891
+ choices=[model_name],
892
+ value=model_name,
893
+ interactive=False,
894
+ )
895
+ with gr.Row():
896
+ base_model_dropdown = gr.Dropdown(
897
+ label="Select base Dreambooth model (选择基模型[非必需])",
898
+ choices=["none"],
899
+ value="none",
900
+ interactive=False,
901
+ visible=False
902
+ )
903
+ with gr.Column(visible=False):
904
+ gr.Markdown(
905
+ """
906
+ ### Minimalism is an example portrait LoRA, triggered by specific prompt words. More details can be found on the [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
907
+ """
908
+ )
909
+ with gr.Row():
910
+ lora_model_dropdown = gr.Dropdown(
911
+ label="Select LoRA model",
912
+ choices=["none"],
913
+ value="none",
914
+ interactive=True,
915
+ )
916
+
917
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
918
+
919
+ with gr.Column(variant="panel"):
920
+ gr.Markdown(
921
+ """
922
+ ### 2. Configs for Generation (生成参数配置).
923
+ """
924
+ )
925
+
926
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
927
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
928
+
929
+ with gr.Row():
930
+ with gr.Column():
931
+ with gr.Row():
932
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
933
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
934
+
935
+ resize_method = gr.Radio(
936
+ ["Generate by", "Resize according to Reference"],
937
+ value="Generate by",
938
+ show_label=False,
939
+ )
940
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
941
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
942
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
943
+
944
+ with gr.Group():
945
+ generation_method = gr.Radio(
946
+ ["Video Generation", "Image Generation"],
947
+ value="Video Generation",
948
+ show_label=False,
949
+ visible=True,
950
+ )
951
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5, maximum=49, step=4)
952
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
953
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False)
954
+
955
+ source_method = gr.Radio(
956
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
957
+ value="Text to Video (文本到视频)",
958
+ show_label=False,
959
+ )
960
+ with gr.Column(visible = False) as image_to_video_col:
961
+ with gr.Row():
962
+ start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
963
+
964
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
965
+ def select_template(evt: gr.SelectData):
966
+ text = {
967
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
968
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
969
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
970
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
971
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
972
+ }[template_gallery_path[evt.index]]
973
+ return template_gallery_path[evt.index], text
974
+
975
+ template_gallery = gr.Gallery(
976
+ template_gallery_path,
977
+ columns=5, rows=1,
978
+ height=140,
979
+ allow_preview=False,
980
+ container=False,
981
+ label="Template Examples",
982
+ )
983
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
984
+
985
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
986
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
987
+
988
+ with gr.Column(visible = False) as video_to_video_col:
989
+ validation_video = gr.Video(
990
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
991
+ elem_id="v2v", sources="upload",
992
+ )
993
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
994
+
995
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
996
+
997
+ with gr.Row():
998
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
999
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1000
+ seed_button.click(
1001
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
1002
+ inputs=[],
1003
+ outputs=[seed_textbox]
1004
+ )
1005
+
1006
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
1007
+
1008
+ with gr.Column():
1009
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
1010
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
1011
+ infer_progress = gr.Textbox(
1012
+ label="Generation Info (生成信息)",
1013
+ value="No task currently",
1014
+ interactive=False
1015
+ )
1016
+
1017
+ def upload_generation_method(generation_method):
1018
+ if generation_method == "Video Generation":
1019
+ return gr.update(visible=True, minimum=8, maximum=49, value=49, interactive=True)
1020
+ elif generation_method == "Image Generation":
1021
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1022
+ generation_method.change(
1023
+ upload_generation_method, generation_method, [length_slider]
1024
+ )
1025
+
1026
+ def upload_source_method(source_method):
1027
+ if source_method == "Text to Video (文本到视频)":
1028
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1029
+ elif source_method == "Image to Video (图片到视频)":
1030
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
1031
+ else:
1032
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
1033
+ source_method.change(
1034
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
1035
+ )
1036
+
1037
+ def upload_resize_method(resize_method):
1038
+ if resize_method == "Generate by":
1039
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1040
+ else:
1041
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1042
+ resize_method.change(
1043
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1044
+ )
1045
+
1046
+ generate_button.click(
1047
+ fn=controller.generate,
1048
+ inputs=[
1049
+ diffusion_transformer_dropdown,
1050
+ base_model_dropdown,
1051
+ lora_model_dropdown,
1052
+ lora_alpha_slider,
1053
+ prompt_textbox,
1054
+ negative_prompt_textbox,
1055
+ sampler_dropdown,
1056
+ sample_step_slider,
1057
+ resize_method,
1058
+ width_slider,
1059
+ height_slider,
1060
+ base_resolution,
1061
+ generation_method,
1062
+ length_slider,
1063
+ overlap_video_length,
1064
+ partial_video_length,
1065
+ cfg_scale_slider,
1066
+ start_image,
1067
+ end_image,
1068
+ validation_video,
1069
+ denoise_strength,
1070
+ seed_textbox,
1071
+ ],
1072
+ outputs=[result_image, result_video, infer_progress]
1073
+ )
1074
+ return demo, controller
1075
+
1076
+
1077
+ def post_eas(
1078
+ diffusion_transformer_dropdown,
1079
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1080
+ prompt_textbox, negative_prompt_textbox,
1081
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1082
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1083
+ start_image, end_image, validation_video, denoise_strength, seed_textbox,
1084
+ ):
1085
+ if start_image is not None:
1086
+ with open(start_image, 'rb') as file:
1087
+ file_content = file.read()
1088
+ start_image_encoded_content = base64.b64encode(file_content)
1089
+ start_image = start_image_encoded_content.decode('utf-8')
1090
+
1091
+ if end_image is not None:
1092
+ with open(end_image, 'rb') as file:
1093
+ file_content = file.read()
1094
+ end_image_encoded_content = base64.b64encode(file_content)
1095
+ end_image = end_image_encoded_content.decode('utf-8')
1096
+
1097
+ if validation_video is not None:
1098
+ with open(validation_video, 'rb') as file:
1099
+ file_content = file.read()
1100
+ validation_video_encoded_content = base64.b64encode(file_content)
1101
+ validation_video = validation_video_encoded_content.decode('utf-8')
1102
+
1103
+ datas = {
1104
+ "base_model_path": base_model_dropdown,
1105
+ "lora_model_path": lora_model_dropdown,
1106
+ "lora_alpha_slider": lora_alpha_slider,
1107
+ "prompt_textbox": prompt_textbox,
1108
+ "negative_prompt_textbox": negative_prompt_textbox,
1109
+ "sampler_dropdown": sampler_dropdown,
1110
+ "sample_step_slider": sample_step_slider,
1111
+ "resize_method": resize_method,
1112
+ "width_slider": width_slider,
1113
+ "height_slider": height_slider,
1114
+ "base_resolution": base_resolution,
1115
+ "generation_method": generation_method,
1116
+ "length_slider": length_slider,
1117
+ "cfg_scale_slider": cfg_scale_slider,
1118
+ "start_image": start_image,
1119
+ "end_image": end_image,
1120
+ "validation_video": validation_video,
1121
+ "denoise_strength": denoise_strength,
1122
+ "seed_textbox": seed_textbox,
1123
+ }
1124
+
1125
+ session = requests.session()
1126
+ session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
1127
+
1128
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/cogvideox_fun/infer_forward', json=datas, timeout=300)
1129
+
1130
+ outputs = response.json()
1131
+ return outputs
1132
+
1133
+
1134
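post_eas reads the inference endpoint and token from the environment; the sketch below shows one way a caller might configure and invoke it, with placeholder credentials and example arguments in the same order as the function signature.

# Sketch only; endpoint and token are placeholders.
import os

os.environ["EAS_URL"] = "http://your-eas-endpoint"  # placeholder
os.environ["EAS_TOKEN"] = "your-eas-token"          # placeholder

outputs = post_eas(
    "none", "none", "none", 0.55,
    "a sailboat sailing in rough seas with a dramatic sunset", "",
    "Euler", 50, "Generate by", 672, 384,
    512, "Video Generation", 49, 7.0,
    None, None, None, 0.70, 43,
)
print(outputs.get("message"))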
+ class CogVideoX_I2VController_EAS:
1135
+ def __init__(self, model_name, savedir_sample):
1136
+ self.savedir_sample = savedir_sample
1137
+ os.makedirs(self.savedir_sample, exist_ok=True)
1138
+
1139
+ def generate(
1140
+ self,
1141
+ diffusion_transformer_dropdown,
1142
+ base_model_dropdown,
1143
+ lora_model_dropdown,
1144
+ lora_alpha_slider,
1145
+ prompt_textbox,
1146
+ negative_prompt_textbox,
1147
+ sampler_dropdown,
1148
+ sample_step_slider,
1149
+ resize_method,
1150
+ width_slider,
1151
+ height_slider,
1152
+ base_resolution,
1153
+ generation_method,
1154
+ length_slider,
1155
+ cfg_scale_slider,
1156
+ start_image,
1157
+ end_image,
1158
+ validation_video,
1159
+ denoise_strength,
1160
+ seed_textbox
1161
+ ):
1162
+ is_image = True if generation_method == "Image Generation" else False
1163
+
1164
+ outputs = post_eas(
1165
+ diffusion_transformer_dropdown,
1166
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1167
+ prompt_textbox, negative_prompt_textbox,
1168
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1169
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1170
+ start_image, end_image, validation_video, denoise_strength,
1171
+ seed_textbox
1172
+ )
1173
+ try:
1174
+ base64_encoding = outputs["base64_encoding"]
1175
+ except KeyError:
1176
+ return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
1177
+
1178
+ decoded_data = base64.b64decode(base64_encoding)
1179
+
1180
+ if not os.path.exists(self.savedir_sample):
1181
+ os.makedirs(self.savedir_sample, exist_ok=True)
1182
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
1183
+ prefix = str(index).zfill(3)
1184
+
1185
+ if is_image or length_slider == 1:
1186
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
1187
+ with open(save_sample_path, "wb") as file:
1188
+ file.write(decoded_data)
1189
+ if gradio_version_is_above_4:
1190
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
1191
+ else:
1192
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
1193
+ else:
1194
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
1195
+ with open(save_sample_path, "wb") as file:
1196
+ file.write(decoded_data)
1197
+ if gradio_version_is_above_4:
1198
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
1199
+ else:
1200
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
1201
+
1202
+
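The repeated `gr.Image(...)` versus `gr.Image.update(...)` branching in the returns above is the Gradio 3/4 compatibility pattern; it could be folded into a small helper like the hypothetical sketch below.

# Hypothetical helper; Gradio 4 dropped Component.update() in favour of returning a new component.
def compat_update(component_cls, **kwargs):
    if gradio_version_is_above_4:
        return component_cls(**kwargs)
    return component_cls.update(**kwargs)

# e.g. compat_update(gr.Image, value=save_sample_path, visible=True)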
1203
+ def ui_eas(model_name, savedir_sample):
1204
+ controller = CogVideoX_I2VController_EAS(model_name, savedir_sample)
1205
+
1206
+ with gr.Blocks(css=css) as demo:
1207
+ gr.Markdown(
1208
+ """
1209
+ # CogVideoX-Fun
1210
+
1211
+ A CogVideoX variant with more flexible generation conditions: it can produce videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), and can also generate videos from images.
1212
+
1213
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
1214
+ """
1215
+ )
1216
+ with gr.Column(variant="panel"):
1217
+ gr.Markdown(
1218
+ """
1219
+ ### 1. Model checkpoints.
1220
+ """
1221
+ )
1222
+ with gr.Row():
1223
+ diffusion_transformer_dropdown = gr.Dropdown(
1224
+ label="Pretrained Model Path",
1225
+ choices=[model_name],
1226
+ value=model_name,
1227
+ interactive=False,
1228
+ )
1229
+ with gr.Row():
1230
+ base_model_dropdown = gr.Dropdown(
1231
+ label="Select base Dreambooth model",
1232
+ choices=["none"],
1233
+ value="none",
1234
+ interactive=False,
1235
+ visible=False
1236
+ )
1237
+ with gr.Column(visible=False):
1238
+ gr.Markdown(
1239
+ """
1240
+ ### Minimalism is an example portrait LoRA, triggered by specific prompt words. More details can be found on the [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
1241
+ """
1242
+ )
1243
+ with gr.Row():
1244
+ lora_model_dropdown = gr.Dropdown(
1245
+ label="Select LoRA model",
1246
+ choices=["none"],
1247
+ value="none",
1248
+ interactive=True,
1249
+ )
1250
+
1251
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
1252
+
1253
+ with gr.Column(variant="panel"):
1254
+ gr.Markdown(
1255
+ """
1256
+ ### 2. Configs for Generation.
1257
+ """
1258
+ )
1259
+
1260
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1261
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. " )
1262
+
1263
+ with gr.Row():
1264
+ with gr.Column():
1265
+ with gr.Row():
1266
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1267
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=50, step=1, interactive=False)
1268
+
1269
+ resize_method = gr.Radio(
1270
+ ["Generate by", "Resize according to Reference"],
1271
+ value="Generate by",
1272
+ show_label=False,
1273
+ )
1274
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1275
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1276
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1277
+
1278
+ with gr.Group():
1279
+ generation_method = gr.Radio(
1280
+ ["Video Generation", "Image Generation"],
1281
+ value="Video Generation",
1282
+ show_label=False,
1283
+ visible=True,
1284
+ )
1285
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5, maximum=49, step=4)
1286
+
1287
+ source_method = gr.Radio(
1288
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
1289
+ value="Text to Video (文本到视频)",
1290
+ show_label=False,
1291
+ )
1292
+ with gr.Column(visible = False) as image_to_video_col:
1293
+ start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1294
+
1295
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1296
+ def select_template(evt: gr.SelectData):
1297
+ text = {
1298
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1299
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1300
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1301
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1302
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1303
+ }[template_gallery_path[evt.index]]
1304
+ return template_gallery_path[evt.index], text
1305
+
1306
+ template_gallery = gr.Gallery(
1307
+ template_gallery_path,
1308
+ columns=5, rows=1,
1309
+ height=140,
1310
+ allow_preview=False,
1311
+ container=False,
1312
+ label="Template Examples",
1313
+ )
1314
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
1315
+
1316
+ with gr.Accordion("The image at the ending of the video (Optional)", open=False):
1317
+ end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1318
+
1319
+ with gr.Column(visible = False) as video_to_video_col:
1320
+ validation_video = gr.Video(
1321
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
1322
+ elem_id="v2v", sources="upload",
1323
+ )
1324
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=0.95, step=0.01)
1325
+
1326
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
1327
+
1328
+ with gr.Row():
1329
+ seed_textbox = gr.Textbox(label="Seed", value=43)
1330
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1331
+ seed_button.click(
1332
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
1333
+ inputs=[],
1334
+ outputs=[seed_textbox]
1335
+ )
1336
+
1337
+ generate_button = gr.Button(value="Generate", variant='primary')
1338
+
1339
+ with gr.Column():
1340
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
1341
+ result_video = gr.Video(label="Generated Animation", interactive=False)
1342
+ infer_progress = gr.Textbox(
1343
+ label="Generation Info",
1344
+ value="No task currently",
1345
+ interactive=False
1346
+ )
1347
+
1348
+ def upload_generation_method(generation_method):
1349
+ if generation_method == "Video Generation":
1350
+ return gr.update(visible=True, minimum=5, maximum=49, value=49, interactive=True)
1351
+ elif generation_method == "Image Generation":
1352
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1353
+ generation_method.change(
1354
+ upload_generation_method, generation_method, [length_slider]
1355
+ )
1356
+
1357
+ def upload_source_method(source_method):
1358
+ if source_method == "Text to Video (文本到视频)":
1359
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
1360
+ elif source_method == "Image to Video (图片到视频)":
1361
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None)]
1362
+ else:
1363
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update()]
1364
+ source_method.change(
1365
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video]
1366
+ )
1367
+
1368
+ def upload_resize_method(resize_method):
1369
+ if resize_method == "Generate by":
1370
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1371
+ else:
1372
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1373
+ resize_method.change(
1374
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1375
+ )
1376
+
1377
+ generate_button.click(
1378
+ fn=controller.generate,
1379
+ inputs=[
1380
+ diffusion_transformer_dropdown,
1381
+ base_model_dropdown,
1382
+ lora_model_dropdown,
1383
+ lora_alpha_slider,
1384
+ prompt_textbox,
1385
+ negative_prompt_textbox,
1386
+ sampler_dropdown,
1387
+ sample_step_slider,
1388
+ resize_method,
1389
+ width_slider,
1390
+ height_slider,
1391
+ base_resolution,
1392
+ generation_method,
1393
+ length_slider,
1394
+ cfg_scale_slider,
1395
+ start_image,
1396
+ end_image,
1397
+ validation_video,
1398
+ denoise_strength,
1399
+ seed_textbox,
1400
+ ],
1401
+ outputs=[result_image, result_video, infer_progress]
1402
+ )
1403
+ return demo, controller
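For completeness, here is a minimal launch sketch for the `ui_eas` builder above. The model path and sample directory are placeholder values, not paths taken from this repository, and the port settings are arbitrary.

```python
# Minimal sketch: build and launch the EAS demo defined by ui_eas above.
# "models/CogVideoX-Fun-2b-InP" and "samples/" are placeholder paths.
if __name__ == "__main__":
    demo, controller = ui_eas("models/CogVideoX-Fun-2b-InP", "samples/")
    demo.launch(server_name="0.0.0.0", server_port=7860)
```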
cogvideox/utils/__init__.py ADDED
File without changes
cogvideox/utils/lora_utils.py ADDED
@@ -0,0 +1,477 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
+ class LoRAModule(torch.nn.Module):
23
+ """
24
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ lora_name,
30
+ org_module: torch.nn.Module,
31
+ multiplier=1.0,
32
+ lora_dim=4,
33
+ alpha=1,
34
+ dropout=None,
35
+ rank_dropout=None,
36
+ module_dropout=None,
37
+ ):
38
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
39
+ super().__init__()
40
+ self.lora_name = lora_name
41
+
42
+ if org_module.__class__.__name__ == "Conv2d":
43
+ in_dim = org_module.in_channels
44
+ out_dim = org_module.out_channels
45
+ else:
46
+ in_dim = org_module.in_features
47
+ out_dim = org_module.out_features
48
+
49
+ self.lora_dim = lora_dim
50
+ if org_module.__class__.__name__ == "Conv2d":
51
+ kernel_size = org_module.kernel_size
52
+ stride = org_module.stride
53
+ padding = org_module.padding
54
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
55
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
56
+ else:
57
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
58
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
59
+
60
+ if type(alpha) == torch.Tensor:
61
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
62
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
63
+ self.scale = alpha / self.lora_dim
64
+ self.register_buffer("alpha", torch.tensor(alpha))
65
+
66
+ # same as microsoft's
67
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
68
+ torch.nn.init.zeros_(self.lora_up.weight)
69
+
70
+ self.multiplier = multiplier
71
+ self.org_module = org_module # remove in applying
72
+ self.dropout = dropout
73
+ self.rank_dropout = rank_dropout
74
+ self.module_dropout = module_dropout
75
+
76
+ def apply_to(self):
77
+ self.org_forward = self.org_module.forward
78
+ self.org_module.forward = self.forward
79
+ del self.org_module
80
+
81
+ def forward(self, x, *args, **kwargs):
82
+ weight_dtype = x.dtype
83
+ org_forwarded = self.org_forward(x)
84
+
85
+ # module dropout
86
+ if self.module_dropout is not None and self.training:
87
+ if torch.rand(1) < self.module_dropout:
88
+ return org_forwarded
89
+
90
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
91
+
92
+ # normal dropout
93
+ if self.dropout is not None and self.training:
94
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
95
+
96
+ # rank dropout
97
+ if self.rank_dropout is not None and self.training:
98
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
99
+ if len(lx.size()) == 3:
100
+ mask = mask.unsqueeze(1) # for Text Encoder
101
+ elif len(lx.size()) == 4:
102
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
103
+ lx = lx * mask
104
+
105
+ # scaling for rank dropout: treat as if the rank is changed
106
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
107
+ else:
108
+ scale = self.scale
109
+
110
+ lx = self.lora_up(lx)
111
+
112
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
+ def addnet_hash_legacy(b):
116
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
117
+ m = hashlib.sha256()
118
+
119
+ b.seek(0x100000)
120
+ m.update(b.read(0x10000))
121
+ return m.hexdigest()[0:8]
122
+
123
+
124
+ def addnet_hash_safetensors(b):
125
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
126
+ hash_sha256 = hashlib.sha256()
127
+ blksize = 1024 * 1024
128
+
129
+ b.seek(0)
130
+ header = b.read(8)
131
+ n = int.from_bytes(header, "little")
132
+
133
+ offset = n + 8
134
+ b.seek(offset)
135
+ for chunk in iter(lambda: b.read(blksize), b""):
136
+ hash_sha256.update(chunk)
137
+
138
+ return hash_sha256.hexdigest()
139
+
140
+
141
+ def precalculate_safetensors_hashes(tensors, metadata):
142
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
143
+ save time on indexing the model later."""
144
+
145
+ # Because writing user metadata to the file can change the result of
146
+ # sd_models.model_hash(), only retain the training metadata for purposes of
147
+ # calculating the hash, as they are meant to be immutable
148
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
149
+
150
+ bytes = safetensors.torch.save(tensors, metadata)
151
+ b = BytesIO(bytes)
152
+
153
+ model_hash = addnet_hash_safetensors(b)
154
+ legacy_hash = addnet_hash_legacy(b)
155
+ return model_hash, legacy_hash
156
+
157
+
158
+ class LoRANetwork(torch.nn.Module):
159
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel"]
160
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
161
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
162
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
+ def __init__(
164
+ self,
165
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
166
+ unet,
167
+ multiplier: float = 1.0,
168
+ lora_dim: int = 4,
169
+ alpha: float = 1,
170
+ dropout: Optional[float] = None,
171
+ module_class: Type[object] = LoRAModule,
172
+ add_lora_in_attn_temporal: bool = False,
173
+ varbose: Optional[bool] = False,
174
+ ) -> None:
175
+ super().__init__()
176
+ self.multiplier = multiplier
177
+
178
+ self.lora_dim = lora_dim
179
+ self.alpha = alpha
180
+ self.dropout = dropout
181
+
182
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
183
+ print(f"neuron dropout: p={self.dropout}")
184
+
185
+ # create module instances
186
+ def create_modules(
187
+ is_unet: bool,
188
+ root_module: torch.nn.Module,
189
+ target_replace_modules: List[torch.nn.Module],
190
+ ) -> List[LoRAModule]:
191
+ prefix = (
192
+ self.LORA_PREFIX_TRANSFORMER
193
+ if is_unet
194
+ else self.LORA_PREFIX_TEXT_ENCODER
195
+ )
196
+ loras = []
197
+ skipped = []
198
+ for name, module in root_module.named_modules():
199
+ if module.__class__.__name__ in target_replace_modules:
200
+ for child_name, child_module in module.named_modules():
201
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
202
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
203
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
204
+
205
+ if not add_lora_in_attn_temporal:
206
+ if "attn_temporal" in child_name:
207
+ continue
208
+
209
+ if is_linear or is_conv2d:
210
+ lora_name = prefix + "." + name + "." + child_name
211
+ lora_name = lora_name.replace(".", "_")
212
+
213
+ dim = None
214
+ alpha = None
215
+
216
+ if is_linear or is_conv2d_1x1:
217
+ dim = self.lora_dim
218
+ alpha = self.alpha
219
+
220
+ if dim is None or dim == 0:
221
+ if is_linear or is_conv2d_1x1:
222
+ skipped.append(lora_name)
223
+ continue
224
+
225
+ lora = module_class(
226
+ lora_name,
227
+ child_module,
228
+ self.multiplier,
229
+ dim,
230
+ alpha,
231
+ dropout=dropout,
232
+ )
233
+ loras.append(lora)
234
+ return loras, skipped
235
+
236
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
237
+
238
+ self.text_encoder_loras = []
239
+ skipped_te = []
240
+ for i, text_encoder in enumerate(text_encoders):
241
+ if text_encoder is not None:
242
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
243
+ self.text_encoder_loras.extend(text_encoder_loras)
244
+ skipped_te += skipped
245
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
246
+
247
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
248
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
249
+
250
+ # assertion
251
+ names = set()
252
+ for lora in self.text_encoder_loras + self.unet_loras:
253
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
254
+ names.add(lora.lora_name)
255
+
256
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
257
+ if apply_text_encoder:
258
+ print("enable LoRA for text encoder")
259
+ else:
260
+ self.text_encoder_loras = []
261
+
262
+ if apply_unet:
263
+ print("enable LoRA for U-Net")
264
+ else:
265
+ self.unet_loras = []
266
+
267
+ for lora in self.text_encoder_loras + self.unet_loras:
268
+ lora.apply_to()
269
+ self.add_module(lora.lora_name, lora)
270
+
271
+ def set_multiplier(self, multiplier):
272
+ self.multiplier = multiplier
273
+ for lora in self.text_encoder_loras + self.unet_loras:
274
+ lora.multiplier = self.multiplier
275
+
276
+ def load_weights(self, file):
277
+ if os.path.splitext(file)[1] == ".safetensors":
278
+ from safetensors.torch import load_file
279
+
280
+ weights_sd = load_file(file)
281
+ else:
282
+ weights_sd = torch.load(file, map_location="cpu")
283
+ info = self.load_state_dict(weights_sd, False)
284
+ return info
285
+
286
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
287
+ self.requires_grad_(True)
288
+ all_params = []
289
+
290
+ def enumerate_params(loras):
291
+ params = []
292
+ for lora in loras:
293
+ params.extend(lora.parameters())
294
+ return params
295
+
296
+ if self.text_encoder_loras:
297
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
298
+ if text_encoder_lr is not None:
299
+ param_data["lr"] = text_encoder_lr
300
+ all_params.append(param_data)
301
+
302
+ if self.unet_loras:
303
+ param_data = {"params": enumerate_params(self.unet_loras)}
304
+ if unet_lr is not None:
305
+ param_data["lr"] = unet_lr
306
+ all_params.append(param_data)
307
+
308
+ return all_params
309
+
310
+ def enable_gradient_checkpointing(self):
311
+ pass
312
+
313
+ def get_trainable_params(self):
314
+ return self.parameters()
315
+
316
+ def save_weights(self, file, dtype, metadata):
317
+ if metadata is not None and len(metadata) == 0:
318
+ metadata = None
319
+
320
+ state_dict = self.state_dict()
321
+
322
+ if dtype is not None:
323
+ for key in list(state_dict.keys()):
324
+ v = state_dict[key]
325
+ v = v.detach().clone().to("cpu").to(dtype)
326
+ state_dict[key] = v
327
+
328
+ if os.path.splitext(file)[1] == ".safetensors":
329
+ from safetensors.torch import save_file
330
+
331
+ # Precalculate model hashes to save time on indexing
332
+ if metadata is None:
333
+ metadata = {}
334
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
335
+ metadata["sshs_model_hash"] = model_hash
336
+ metadata["sshs_legacy_hash"] = legacy_hash
337
+
338
+ save_file(state_dict, file, metadata)
339
+ else:
340
+ torch.save(state_dict, file)
341
+
342
+ def create_network(
343
+ multiplier: float,
344
+ network_dim: Optional[int],
345
+ network_alpha: Optional[float],
346
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
347
+ transformer,
348
+ neuron_dropout: Optional[float] = None,
349
+ add_lora_in_attn_temporal: bool = False,
350
+ **kwargs,
351
+ ):
352
+ if network_dim is None:
353
+ network_dim = 4 # default
354
+ if network_alpha is None:
355
+ network_alpha = 1.0
356
+
357
+ network = LoRANetwork(
358
+ text_encoder,
359
+ transformer,
360
+ multiplier=multiplier,
361
+ lora_dim=network_dim,
362
+ alpha=network_alpha,
363
+ dropout=neuron_dropout,
364
+ add_lora_in_attn_temporal=add_lora_in_attn_temporal,
365
+ varbose=True,
366
+ )
367
+ return network
368
+
369
+ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
371
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
372
+ if state_dict is None:
373
+ state_dict = load_file(lora_path, device=device)
374
+ else:
375
+ state_dict = state_dict
376
+ updates = defaultdict(dict)
377
+ for key, value in state_dict.items():
378
+ layer, elem = key.split('.', 1)
379
+ updates[layer][elem] = value
380
+
381
+ for layer, elems in updates.items():
382
+
383
+ if "lora_te" in layer:
384
+ if transformer_only:
385
+ continue
386
+ else:
387
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
388
+ curr_layer = pipeline.text_encoder
389
+ else:
390
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
391
+ curr_layer = pipeline.transformer
392
+
393
+ temp_name = layer_infos.pop(0)
394
+ while len(layer_infos) > -1:
395
+ try:
396
+ curr_layer = curr_layer.__getattr__(temp_name)
397
+ if len(layer_infos) > 0:
398
+ temp_name = layer_infos.pop(0)
399
+ elif len(layer_infos) == 0:
400
+ break
401
+ except Exception:
402
+ if len(layer_infos) == 0:
403
+ print('Error loading layer')
404
+ if len(temp_name) > 0:
405
+ temp_name += "_" + layer_infos.pop(0)
406
+ else:
407
+ temp_name = layer_infos.pop(0)
408
+
409
+ weight_up = elems['lora_up.weight'].to(dtype)
410
+ weight_down = elems['lora_down.weight'].to(dtype)
411
+ if 'alpha' in elems.keys():
412
+ alpha = elems['alpha'].item() / weight_up.shape[1]
413
+ else:
414
+ alpha = 1.0
415
+
416
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
417
+ if len(weight_up.shape) == 4:
418
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
419
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(
420
+ 2).unsqueeze(3)
421
+ else:
422
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
423
+
424
+ return pipeline
425
+
426
+ # TODO: Refactor with merge_lora.
427
+ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
428
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
429
+ LORA_PREFIX_UNET = "lora_unet"
430
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
431
+ state_dict = load_file(lora_path, device=device)
432
+
433
+ updates = defaultdict(dict)
434
+ for key, value in state_dict.items():
435
+ layer, elem = key.split('.', 1)
436
+ updates[layer][elem] = value
437
+
438
+ for layer, elems in updates.items():
439
+
440
+ if "lora_te" in layer:
441
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
442
+ curr_layer = pipeline.text_encoder
443
+ else:
444
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
445
+ curr_layer = pipeline.transformer
446
+
447
+ temp_name = layer_infos.pop(0)
448
+ while len(layer_infos) > -1:
449
+ try:
450
+ curr_layer = curr_layer.__getattr__(temp_name)
451
+ if len(layer_infos) > 0:
452
+ temp_name = layer_infos.pop(0)
453
+ elif len(layer_infos) == 0:
454
+ break
455
+ except Exception:
456
+ if len(layer_infos) == 0:
457
+ print('Error loading layer')
458
+ if len(temp_name) > 0:
459
+ temp_name += "_" + layer_infos.pop(0)
460
+ else:
461
+ temp_name = layer_infos.pop(0)
462
+
463
+ weight_up = elems['lora_up.weight'].to(dtype)
464
+ weight_down = elems['lora_down.weight'].to(dtype)
465
+ if 'alpha' in elems.keys():
466
+ alpha = elems['alpha'].item() / weight_up.shape[1]
467
+ else:
468
+ alpha = 1.0
469
+
470
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
471
+ if len(weight_up.shape) == 4:
472
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
473
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
474
+ else:
475
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
476
+
477
+ return pipeline
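As a usage note, here is a hedged sketch of how `merge_lora` and `unmerge_lora` are typically paired around an inference call. The `pipeline` object and LoRA path are placeholders; the sketch only assumes the pipeline exposes `.transformer` and `.text_encoder`, which is what the helpers above expect.

```python
import torch

def apply_lora_for_inference(pipeline, lora_path, lora_weight=0.55):
    """Sketch: merge a LoRA into the pipeline weights, run inference, then restore the base weights.

    `pipeline` is assumed to expose `.transformer` and `.text_encoder`,
    matching what merge_lora / unmerge_lora (defined above) expect.
    """
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device="cuda", dtype=torch.bfloat16)
    try:
        # ... call pipeline(...) here with the merged weights ...
        pass
    finally:
        # Subtract the same LoRA delta so the base weights are left unchanged.
        pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device="cuda", dtype=torch.bfloat16)
    return pipeline
```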
cogvideox/utils/utils.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import gc
3
+ import imageio
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import cv2
8
+ from einops import rearrange
9
+ from PIL import Image
10
+
11
+ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
12
+ target_pixels = int(base_resolution) * int(base_resolution)
13
+ original_width, original_height = Image.open(image).size
14
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
15
+ width_slider = round(original_width * ratio)
16
+ height_slider = round(original_height * ratio)
17
+ return height_slider, width_slider
18
+
19
+ def color_transfer(sc, dc):
20
+ """
21
+ Transfer the color distribution of sc to match that of the reference image dc.
22
+
23
+ Args:
24
+ sc (numpy.ndarray): input image whose color distribution will be transferred.
25
+ dc (numpy.ndarray): reference image
26
+
27
+ Returns:
28
+ numpy.ndarray: the color-transferred version of sc.
29
+ """
30
+
31
+ def get_mean_and_std(img):
32
+ x_mean, x_std = cv2.meanStdDev(img)
33
+ x_mean = np.hstack(np.around(x_mean, 2))
34
+ x_std = np.hstack(np.around(x_std, 2))
35
+ return x_mean, x_std
36
+
37
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
38
+ s_mean, s_std = get_mean_and_std(sc)
39
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
40
+ t_mean, t_std = get_mean_and_std(dc)
41
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
42
+ np.putmask(img_n, img_n > 255, 255)
43
+ np.putmask(img_n, img_n < 0, 0)
44
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
45
+ return dst
46
+
47
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
48
+ videos = rearrange(videos, "b c t h w -> t b c h w")
49
+ outputs = []
50
+ for x in videos:
51
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
52
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
53
+ if rescale:
54
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
55
+ x = (x * 255).numpy().astype(np.uint8)
56
+ outputs.append(Image.fromarray(x))
57
+
58
+ if color_transfer_post_process:
59
+ for i in range(1, len(outputs)):
60
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
61
+
62
+ os.makedirs(os.path.dirname(path), exist_ok=True)
63
+ if imageio_backend:
64
+ if path.endswith("mp4"):
65
+ imageio.mimsave(path, outputs, fps=fps)
66
+ else:
67
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
68
+ else:
69
+ if path.endswith("mp4"):
70
+ path = path.replace('.mp4', '.gif')
71
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
72
+
73
+ def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
74
+ if validation_image_start is not None and validation_image_end is not None:
75
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
76
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
77
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
78
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
79
+ else:
80
+ image_start = clip_image = validation_image_start
81
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
82
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
83
+
84
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
85
+ image_end = Image.open(validation_image_end).convert("RGB")
86
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
87
+ else:
88
+ image_end = validation_image_end
89
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
90
+
91
+ if type(image_start) is list:
92
+ clip_image = clip_image[0]
93
+ start_video = torch.cat(
94
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
95
+ dim=2
96
+ )
97
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
98
+ input_video[:, :, :len(image_start)] = start_video
99
+
100
+ input_video_mask = torch.zeros_like(input_video[:, :1])
101
+ input_video_mask[:, :, len(image_start):] = 255
102
+ else:
103
+ input_video = torch.tile(
104
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
105
+ [1, 1, video_length, 1, 1]
106
+ )
107
+ input_video_mask = torch.zeros_like(input_video[:, :1])
108
+ input_video_mask[:, :, 1:] = 255
109
+
110
+ if type(image_end) is list:
111
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
112
+ end_video = torch.cat(
113
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
114
+ dim=2
115
+ )
116
+ input_video[:, :, -len(end_video):] = end_video
117
+
118
+ input_video_mask[:, :, -len(image_end):] = 0
119
+ else:
120
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
121
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
122
+ input_video_mask[:, :, -1:] = 0
123
+
124
+ input_video = input_video / 255
125
+
126
+ elif validation_image_start is not None:
127
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
128
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
129
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
130
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
131
+ else:
132
+ image_start = clip_image = validation_image_start
133
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
134
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
135
+ image_end = None
136
+
137
+ if type(image_start) is list:
138
+ clip_image = clip_image[0]
139
+ start_video = torch.cat(
140
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
141
+ dim=2
142
+ )
143
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
144
+ input_video[:, :, :len(image_start)] = start_video
145
+ input_video = input_video / 255
146
+
147
+ input_video_mask = torch.zeros_like(input_video[:, :1])
148
+ input_video_mask[:, :, len(image_start):] = 255
149
+ else:
150
+ input_video = torch.tile(
151
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
152
+ [1, 1, video_length, 1, 1]
153
+ ) / 255
154
+ input_video_mask = torch.zeros_like(input_video[:, :1])
155
+ input_video_mask[:, :, 1:, ] = 255
156
+ else:
157
+ image_start = None
158
+ image_end = None
159
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
160
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
161
+ clip_image = None
162
+
163
+ del image_start
164
+ del image_end
165
+ gc.collect()
166
+
167
+ return input_video, input_video_mask, clip_image
168
+
169
+ def get_video_to_video_latent(input_video_path, video_length, sample_size):
170
+ if type(input_video_path) is str:
171
+ cap = cv2.VideoCapture(input_video_path)
172
+ input_video = []
173
+ while True:
174
+ ret, frame = cap.read()
175
+ if not ret:
176
+ break
177
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
178
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
179
+ cap.release()
180
+ else:
181
+ input_video = input_video_path
182
+
183
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
184
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
185
+
186
+ input_video_mask = torch.zeros_like(input_video[:, :1])
187
+ input_video_mask[:, :, :] = 255
188
+
189
+ return input_video, input_video_mask, None
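A small usage sketch of these helpers follows. The image path mirrors the template images referenced by the Gradio UI, and the `(height, width)` sample size and frame count are example values matching the UI defaults rather than required settings.

```python
# Sketch: prepare image-to-video conditioning tensors with the helpers above.
# "asset/1.png" is one of the template images referenced by the UI; adjust the path as needed.
if __name__ == "__main__":
    sample_size = (384, 672)   # (height, width)
    video_length = 49

    input_video, input_video_mask, clip_image = get_image_to_video_latent(
        "asset/1.png", None, video_length=video_length, sample_size=sample_size
    )
    print(input_video.shape)        # torch.Size([1, 3, 49, 384, 672])
    print(input_video_mask.shape)   # torch.Size([1, 1, 49, 384, 672])
    # These tensors are then passed to the CogVideoX-Fun inpaint pipeline as video/mask conditioning.
```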
reports/report_v1.md ADDED
@@ -0,0 +1,36 @@
1
+ # CogVideoX FUN v1 Report
2
+ In CogVideoX-FUN, we trained on approximately 1.2 million samples on top of CogVideoX, supporting both image and video prediction. The model generates videos at resolutions of 512x512x49, 768x768x49, and 1024x1024x49, as well as videos with different aspect ratios. In addition, it supports generating videos from images and reconstructing videos from other videos.
3
+
4
+ Compared to CogVideoX, CogVideoX FUN also highlights the following features:
5
+ - Introduction of the InPaint model, enabling the generation of videos from images with specified starting and ending images.
6
+ - Training the model based on token lengths. This allows for the implementation of various sizes and resolutions within the same model.
7
+
8
+ ## InPaint Model
9
+ We used [CogVideoX](https://github.com/THUDM/CogVideo/) as the foundational structure, referencing [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) for the model training to generate videos from images.
10
+
11
+ During video generation, the **reference video** is encoded using the VAE, with the **black area of the mask representing the part to be reconstructed and the white area representing the start image**. This is stacked with the noise latents and fed into the Transformer for video generation. We apply a 3D resize to the **masked area**, resizing it directly to the canvas size of the video to be reconstructed.
12
+
13
+ Then we concatenate the noise latents, the encoded reference video, and the masked area, and feed them into the DiT for noise prediction to obtain the final video.
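A toy sketch of this concatenation step is shown below. The channel counts and latent shapes are illustrative placeholders, not the exact values used by the released models.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: batch, channels, latent frames, latent height, latent width.
b, c, f, h, w = 1, 16, 13, 48, 84

noise_latents     = torch.randn(b, c, f, h, w)   # latents being denoised
reference_latents = torch.randn(b, c, f, h, w)   # VAE-encoded reference video

# Pixel-space mask: 1 marks the region to reconstruct, 0 marks the given start image.
pixel_mask = torch.ones(b, 1, 49, 384, 672)
pixel_mask[:, :, 0] = 0                          # the first frame is given
# 3D-resize the mask directly to the latent canvas of the video to reconstruct.
mask = F.interpolate(pixel_mask, size=(f, h, w), mode="trilinear", align_corners=False)

# Channel-wise concatenation forms the DiT input for noise prediction.
dit_input = torch.cat([noise_latents, reference_latents, mask], dim=1)
print(dit_input.shape)  # torch.Size([1, 33, 13, 48, 84])
```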
14
+ The pipeline structure of CogVideoX FUN is as follows:
15
+ <img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/pipeline.jpg" alt="ui" style="zoom:50%;" />
16
+
17
+ ## Token Length-Based Model Training
18
+ We collected approximately 1.2 million high-quality samples for training CogVideoX-Fun. During training, we resized the videos according to different token lengths. The training process is divided into three phases, corresponding to token lengths of 13312 (512x512x49 videos), 29952 (768x768x49 videos), and 53248 (1024x1024x49 videos).
19
+
20
+ Taking CogVideoX-Fun-2B as an example:
21
+ - In the 13312 phase, the batch size is 128 with 7k training steps.
22
+ - In the 29952 phase, the batch size is 256 with 6.5k training steps.
23
+ - In the 53248 phase, the batch size is 128 with 5k training steps.
24
+
25
+ During training, we mixed high and low resolutions, enabling the model to support video generation at any resolution from 512 to 1280. For example, with a token length of 13312:
26
+ - At a resolution of 512x512, the number of video frames is 49.
27
+ - At a resolution of 768x768, the number of video frames is 21.
28
+ - At a resolution of 1024x1024, the number of video frames is 9.
29
+
30
+ These resolutions and corresponding lengths were mixed for training, allowing the model to generate videos at different resolutions.
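The token budgets above can be reproduced with simple arithmetic. The helper below is a hypothetical sketch that assumes the standard CogVideoX factors of 8x spatial VAE downsampling, a 2x2 spatial patch size, and 4x temporal compression (plus one latent frame for the first video frame); these factors are our assumptions, not figures stated in this report.

```python
def video_token_length(width, height, frames,
                       vae_spatial=8, patch_size=2, vae_temporal=4):
    """Approximate transformer token count for a video of the given pixel size."""
    latent_frames = (frames - 1) // vae_temporal + 1
    tokens_per_frame = (height // vae_spatial // patch_size) * (width // vae_spatial // patch_size)
    return latent_frames * tokens_per_frame

print(video_token_length(512, 512, 49))    # 13312
print(video_token_length(768, 768, 49))    # 29952
print(video_token_length(1024, 1024, 49))  # 53248
```

Under the 13312 budget, the mixed-resolution settings listed above are approximate fits: for example, 1024x1024x9 corresponds to 3 x 4096 = 12288 tokens, just under the budget.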
31
+
32
+ ## Resize 3D Embedding
33
+ While adapting CogVideoX-2B to the CogVideoX-Fun framework, we found that the source code obtains its 3D positional embeddings by truncation. This approach only accommodates a single resolution; when the resolution changes, the embedding must change as well.
34
+ <img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/PE_Interpolation.jpg" alt="ui" style="zoom:50%;" />
35
+
36
+ Following Pixart-Sigma (the image above is taken from the Pixart-Sigma paper), we use Positional Embedding Interpolation (PE Interpolation) to resize the 3D embeddings. PE Interpolation converges more readily than directly generating cosine and sine embeddings for each new resolution.
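A hedged sketch of the resizing idea is shown below: a 3D positional embedding laid out over a base latent grid is interpolated to a new grid with trilinear interpolation. The function and grid sizes are illustrative; the exact embedding construction in CogVideoX-Fun may differ.

```python
import torch
import torch.nn.functional as F

def resize_3d_pos_embed(pos_embed, base_grid, new_grid):
    """Resize a 3D positional embedding via interpolation (PE Interpolation).

    pos_embed: tensor of shape (T*H*W, C) laid out over base_grid = (T, H, W).
    new_grid:  target (T', H', W') latent grid.
    """
    t, h, w = base_grid
    c = pos_embed.shape[-1]
    # (T*H*W, C) -> (1, C, T, H, W) so trilinear interpolation can be applied.
    grid = pos_embed.reshape(t, h, w, c).permute(3, 0, 1, 2).unsqueeze(0)
    grid = F.interpolate(grid, size=new_grid, mode="trilinear", align_corners=False)
    # Back to the flattened (T'*H'*W', C) layout expected by the transformer.
    return grid.squeeze(0).permute(1, 2, 3, 0).reshape(-1, c)

# Example: a base 13x32x32 grid (512x512 input) resized for a 13x48x48 grid (768x768 input).
base = torch.randn(13 * 32 * 32, 64)
resized = resize_3d_pos_embed(base, (13, 32, 32), (13, 48, 48))
print(resized.shape)  # torch.Size([29952, 64])
```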
reports/report_v1_zh-CN.md ADDED
@@ -0,0 +1,43 @@
1
+ # CogVideoX FUN v1 Report
2
+
3
+ In CogVideoX-FUN, we trained on approximately 1.2 million samples on top of CogVideoX, supporting image and video prediction and supporting video generation at pixel sizes of 512x512x49, 768x768x49, and 1024x1024x49, as well as different aspect ratios. In addition, we support generating videos from images and reconstructing videos from other videos.
4
+
5
+ Compared with CogVideoX, CogVideoX FUN also highlights the following features:
6
+
7
+ - Introduction of an InPaint model, enabling image-to-video generation in which the video can be specified by start and end images.
8
+ - Training based on token length, so that multiple sizes and resolutions are supported within a single model.
9
+
10
+ ## InPaint Model
11
+ We use [CogVideoX](https://github.com/THUDM/CogVideo/) as the base structure and follow [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) to train the image-to-video model.
12
+
13
+ During video generation, the **reference video** is encoded with the VAE; **the black area of the mask represents the part to be reconstructed and the white area represents the start image**. It is stacked with the noise latents and fed into the Transformer for video generation.
14
+
15
+ We apply a 3D resize to the **masked area**, resizing it directly to the canvas size of the video to be reconstructed.
16
+
17
+ Then the latents, the encoded reference video, and the masked area are concatenated and fed into the DiT for noise prediction to obtain the final video.
18
+
19
+ The pipeline structure of CogVideoX FUN is as follows:
20
+ <img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/pipeline.jpg" alt="ui" style="zoom:50%;" />
21
+
22
+ ## Token Length-Based Model Training
23
+ We collected approximately 1.2 million high-quality samples for training CogVideoX-Fun.
24
+
25
+ During training, we resized the videos according to different token lengths. The training process is divided into three phases, corresponding to token lengths of 13312 (512x512x49 videos), 29952 (768x768x49 videos), and 53248 (1024x1024x49 videos).
26
+
27
+ Taking CogVideoX-Fun-2B as an example:
28
+ - In the 13312 phase, the batch size is 128 and the number of training steps is 7k.
29
+ - In the 29952 phase, the batch size is 256 and the number of training steps is 6.5k.
30
+ - In the 53248 phase, the batch size is 128 and the number of training steps is 5k.
31
+
32
+ During training we mixed high and low resolutions, so the model supports video generation at any resolution from 512 to 1280. Taking a token length of 13312 as an example:
33
+ - at a resolution of 512x512, the number of video frames is 49;
34
+ - at a resolution of 768x768, the number of video frames is 21;
35
+ - at a resolution of 1024x1024, the number of video frames is 9;
36
+ These resolutions and their corresponding lengths were mixed for training, so the model can generate videos at different resolutions.
37
+
38
+ ## Resize 3D Embedding
39
+ While adapting CogVideoX-2B to the CogVideoX-Fun framework, we found that the source code obtains the 3D embedding by truncation. This approach only fits a single resolution; when the resolution changes, the embedding should change as well.
40
+
41
+ <img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/PE_Interpolation.jpg" alt="ui" style="zoom:50%;" />
42
+
43
+ Following Pixart-Sigma (the image above is from the Pixart-Sigma paper), we use Positional Embeddings Interpolation (PE Interpolation) to resize the 3D embedding. PE Interpolation converges more easily than directly generating cosine and sine embeddings for different resolutions.
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ Pillow
2
+ einops
3
+ safetensors
4
+ timm
5
+ tomesd
6
+ torch>=2.1.2
7
+ torchdiffeq
8
+ torchsde
9
+ xformers
10
+ decord
11
+ datasets
12
+ numpy
13
+ scikit-image
14
+ opencv-python
15
+ omegaconf
16
+ SentencePiece
17
+ albumentations
18
+ imageio[ffmpeg]
19
+ imageio[pyav]
20
+ tensorboard
21
+ beautifulsoup4
22
+ ftfy
23
+ func_timeout
24
+ deepspeed
25
+ accelerate>=0.25.0
26
+ gradio>=3.41.2
27
+ diffusers>=0.28.2
28
+ transformers>=4.37.2