diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..894895268b94a848acb556edef494b86ee4eb948
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,37 @@
+## Coding Standards
+
+### Unit Tests
+We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests.
+
+To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run
+```bash
+pip install -r requirements/requirements-test.txt
+```
+If you encounter an error telling "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your python version to 3.8 or 3.9 and try again.
+
+If you only want to run CPU tests, you can run
+
+```bash
+pytest -m cpu tests/
+```
+
+If you have 8 GPUs on your machine, you can run the full test
+
+```bash
+pytest tests/
+```
+
+If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch.
+
+
+### Code Style
+
+We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below.
+
+```shell
+# these commands are executed under the Colossal-AI directory
+pip install pre-commit
+pre-commit install
+```
+
+Code format checking will be automatically executed when you commit your changes.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f6aeff79d7c620029e3490a60fe27a0dd44e4f9c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,2140 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright VideoSys
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ ## Some of videosys's code is derived from others projects, which is subject to the following copyright notice:
+
+ ---------------- LICENSE FOR ColossalAI ----------------
+
+ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021- HPC-AI Technology Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ ---------------- LICENSE FOR Flash Attention ----------------
+
+ BSD 3-Clause License
+
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ ---------------- LICENSE FOR Meta DiT ----------------
+
+ Attribution-NonCommercial 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial 4.0 International Public
+ License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial 4.0 International Public License ("Public
+ License"). To the extent this Public License may be interpreted as a
+ contract, You are granted the Licensed Rights in consideration of Your
+ acceptance of these terms and conditions, and the Licensor grants You
+ such rights in consideration of benefits the Licensor receives from
+ making the Licensed Material available under these terms and
+ conditions.
+
+ Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+ Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+ Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+ Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+ Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the “Licensor.” The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
+
+ ---------------- LICENSE FOR OpenSoraPlan ----------------
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+ ---------------- LICENSE FOR OpenSora ----------------
+
+Copyright 2024 HPC-AI Technology Inc. All rights reserved.
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 HPC-AI Technology Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ =========================================================================
+ This project is inspired by the listed projects and is subject to the following licenses:
+
+ 1. Latte (https://github.com/Vchitect/Latte/blob/main/LICENSE)
+
+ Copyright 2024 Latte
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ 2. PixArt-alpha (https://github.com/PixArt-alpha/PixArt-alpha/blob/master/LICENSE)
+
+ Copyright (C) 2024 PixArt-alpha/PixArt-alpha
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+
+ 3. dpm-solver (https://github.com/LuChengTHU/dpm-solver/blob/main/LICENSE)
+
+ MIT License
+
+ Copyright (c) 2022 Cheng Lu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ 4. DiT (https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt)
+
+ Attribution-NonCommercial 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial 4.0 International Public
+ License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial 4.0 International Public License ("Public
+ License"). To the extent this Public License may be interpreted as a
+ contract, You are granted the Licensed Rights in consideration of Your
+ acceptance of these terms and conditions, and the Licensor grants You
+ such rights in consideration of benefits the Licensor receives from
+ making the Licensed Material available under these terms and
+ conditions.
+
+ Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+ Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+ Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+ Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+ Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the “Licensor.” The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
+
+ 5. OpenDiT (https://github.com/NUS-HPC-AI-Lab/OpenDiT/blob/master/LICENSE)
+
+ Copyright OpenDiT
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+ ---------------- LICENSE FOR Latte ----------------
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+ ---------------- LICENSE FOR CogVideo ----------------
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 CogVideo Model Team @ Zhipu AI
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index a21b7b3558f8e070648963032ae94630be4709e3..9c29ffb6076076a8bbf9b186467951e790f0f605 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,20 @@
---
-title: Demo
-emoji: 🔥
-colorFrom: indigo
-colorTo: yellow
+title: VideoSys-CogVideoX
+emoji: 🎥
+colorFrom: yellow
+colorTo: green
sdk: gradio
sdk_version: 4.42.0
+suggested_hardware: a10g-large
+suggested_storage: large
+app_port: 7860
app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+models:
+ - THUDM/CogVideoX-2b
+tags:
+ - cogvideox
+ - video-generation
+ - thudm
+short_description: Text-to-Video
+disable_embedding: false
+---
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad90a9cac3f85769241e2744da14557bd5b4ef78
--- /dev/null
+++ b/app.py
@@ -0,0 +1,508 @@
+# # import gradio as gr
+# # from videosys import CogVideoConfig, VideoSysEngine
+# # import tempfile
+# # import os
+# # import logging
+# # import uuid
+
+# # logging.basicConfig(level=logging.INFO)
+# # logger = logging.getLogger(__name__)
+
+# # config = CogVideoConfig(world_size=1)
+# # engine = VideoSysEngine(config)
+
+# # def generate_video(prompt):
+# # try:
+# # video = engine.generate(prompt).video[0]
+
+# # # 使用临时文件和唯一标识符
+# # with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+# # temp_filename = temp_file.name
+# # unique_filename = f"{uuid.uuid4().hex}.mp4"
+# # output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+# # engine.save_video(video, output_path)
+
+# # return output_path
+# # except Exception as e:
+# # logger.error(f"An error occurred: {str(e)}")
+# # return None # 返回 None 而不是错误消息
+
+# # iface = gr.Interface(
+# # fn=generate_video,
+# # inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
+# # outputs=gr.Video(label="Generated Video"),
+# # title="CogVideoX-2b: Text-to-Video Generation",
+# # description="Enter a text prompt to generate a video using CogVideoX-2b."
+# # )
+
+# # iface.launch()
+
+
+# from videosys import CogVideoConfig, VideoSysEngine
+# from videosys.models.cogvideo.pipeline import CogVideoPABConfig
+# import os
+
+# import gradio as gr
+# import numpy as np
+# import torch
+# from openai import OpenAI
+# from time import time
+# import tempfile
+# import uuid
+# import logging
+
+# logging.basicConfig(level=logging.INFO)
+# logger = logging.getLogger(__name__)
+
+# dtype = torch.bfloat16
+# sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+# For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+# There are a few rules to follow:
+
+# You will only ever output a single video description per user request.
+
+# When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+# Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+# Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+# """
+
+# def convert_prompt(prompt: str, retry_times: int = 3) -> str:
+# if not os.environ.get("OPENAI_API_KEY"):
+# return prompt
+# client = OpenAI()
+# text = prompt.strip()
+
+# for i in range(retry_times):
+# response = client.chat.completions.create(
+# messages=[
+# {"role": "system", "content": sys_prompt},
+# {
+# "role": "user",
+# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
+# },
+# {
+# "role": "assistant",
+# "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
+# },
+# {
+# "role": "user",
+# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
+# },
+# {
+# "role": "assistant",
+# "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
+# },
+# {
+# "role": "user",
+# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
+# },
+# {
+# "role": "assistant",
+# "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
+# },
+# {
+# "role": "user",
+# "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
+# },
+# ],
+# model="glm-4-0520",
+# temperature=0.01,
+# top_p=0.7,
+# stream=False,
+# max_tokens=250,
+# )
+# if response.choices:
+# return response.choices[0].message.content
+# return prompt
+
+# def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
+# pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
+# config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
+# engine = VideoSysEngine(config)
+# return engine
+
+
+
+# def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
+# try:
+# video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
+
+# with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+# temp_file.name
+# unique_filename = f"{uuid.uuid4().hex}.mp4"
+# output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+# engine.save_video(video, output_path)
+# return output_path
+# except Exception as e:
+# logger.error(f"An error occurred: {str(e)}")
+# return None
+
+
+
+# with gr.Blocks() as demo:
+# gr.Markdown("""
+#
+# VideoSys Huggingface Space🤗
+#
+#
+
+#
+# ⚠️ This demo is for academic research and experiential use only.
+# Users should strictly adhere to local laws and ethics.
+#
+#
+# 💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
+#
+# """)
+# with gr.Row():
+# with gr.Column():
+# prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
+# with gr.Row():
+# gr.Markdown(
+# "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
+# )
+# enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+
+# with gr.Column():
+# gr.Markdown(
+# "**Optional Parameters** (default values are recommended)
"
+# "Turn Inference Steps larger if you want more detailed video, but it will be slower.
"
+# "50 steps are recommended for most cases. will cause 120 seconds for inference.
"
+# )
+# with gr.Row():
+# num_inference_steps = gr.Number(label="Inference Steps", value=50)
+# guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+# pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
+# pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
+# with gr.Row():
+# generate_button = gr.Button("🎬 Generate Video")
+# generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
+
+# with gr.Column():
+# with gr.Row():
+# video_output = gr.Video(label="CogVideoX", width=720, height=480)
+# with gr.Row():
+# download_video_button = gr.File(label="📥 Download Video", visible=False)
+# elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+# with gr.Row():
+# video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
+# with gr.Row():
+# download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
+# elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+
+# def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+# # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+# engine = load_model()
+# t = time()
+# video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+# elapsed_time = time() - t
+# video_update = gr.update(visible=True, value=video_path)
+# elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+# return video_path, video_update, elapsed_time
+
+# def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
+# # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+# threshold = [int(i) for i in threshold.split(",")]
+# gap = int(gap)
+# engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
+# t = time()
+# video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+# elapsed_time = time() - t
+# video_update = gr.update(visible=True, value=video_path)
+# elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+# return video_path, video_update, elapsed_time
+
+
+# def enhance_prompt_func(prompt):
+# return convert_prompt(prompt, retry_times=1)
+
+# generate_button.click(
+# generate_vanilla,
+# inputs=[prompt, num_inference_steps, guidance_scale],
+# outputs=[video_output, download_video_button, elapsed_time],
+# )
+
+# generate_button_vs.click(
+# generate_vs,
+# inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
+# outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
+# )
+
+# enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+
+# if __name__ == "__main__":
+# demo.launch()
+
+
+
+import gradio as gr
+from videosys import CogVideoConfig, VideoSysEngine
+from videosys.models.cogvideo.pipeline import CogVideoPABConfig
+import os
+import numpy as np
+import torch
+from openai import OpenAI
+from time import time
+import tempfile
+import uuid
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+dtype = torch.bfloat16
+sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+There are a few rules to follow:
+
+You will only ever output a single video description per user request.
+
+When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+"""
+
+def convert_prompt(prompt: str, retry_times: int = 3) -> str:
+ if not os.environ.get("OPENAI_API_KEY"):
+ return prompt
+ client = OpenAI()
+ text = prompt.strip()
+
+ for i in range(retry_times):
+ response = client.chat.completions.create(
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
+ },
+ {
+ "role": "assistant",
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
+ },
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
+ },
+ {
+ "role": "assistant",
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
+ },
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
+ },
+ {
+ "role": "assistant",
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
+ },
+ {
+ "role": "user",
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
+ },
+ ],
+ model="glm-4-0520",
+ temperature=0.01,
+ top_p=0.7,
+ stream=False,
+ max_tokens=250,
+ )
+ if response.choices:
+ return response.choices[0].message.content
+ return prompt
+
+def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
+ pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
+ config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
+ engine = VideoSysEngine(config)
+ return engine
+
+def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
+ try:
+ video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
+
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+ temp_file.name
+ unique_filename = f"{uuid.uuid4().hex}.mp4"
+ output_path = os.path.join(tempfile.gettempdir(), unique_filename)
+
+ engine.save_video(video, output_path)
+ return output_path
+ except Exception as e:
+ logger.error(f"An error occurred: {str(e)}")
+ return None
+
+css = """
+body {
+ font-family: Arial, sans-serif;
+ line-height: 1.6;
+ color: #333;
+ max-width: 1200px;
+ margin: 0 auto;
+ padding: 20px;
+}
+
+.container {
+ display: flex;
+ flex-direction: column;
+ gap: 20px;
+}
+
+.row {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 20px;
+}
+
+.column {
+ flex: 1;
+ min-width: 0;
+}
+
+.textbox, .number-input, button {
+ width: 100%;
+ padding: 10px;
+ margin-bottom: 10px;
+ border: 1px solid #ddd;
+ border-radius: 4px;
+}
+
+button {
+ background-color: #4CAF50;
+ color: white;
+ border: none;
+ cursor: pointer;
+ transition: background-color 0.3s;
+}
+
+button:hover {
+ background-color: #45a049;
+}
+
+.video-output {
+ width: 100%;
+ max-width: 720px;
+ height: auto;
+ margin: 0 auto;
+}
+
+@media (max-width: 768px) {
+ .row {
+ flex-direction: column;
+ }
+
+ .column {
+ width: 100%;
+ }
+
+ .video-output {
+ width: 100%;
+ height: auto;
+ }
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+ gr.HTML("""
+
+ VideoSys Huggingface Space🤗
+
+
+
+ ⚠️ This demo is for academic research and experiential use only.
+ Users should strictly adhere to local laws and ethics.
+
+
+ 💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
+
+ """)
+
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
+ with gr.Row():
+ gr.Markdown(
+ "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
+ )
+ enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+
+ with gr.Column():
+ gr.Markdown(
+ "**Optional Parameters** (default values are recommended)
"
+ "Turn Inference Steps larger if you want more detailed video, but it will be slower.
"
+ "50 steps are recommended for most cases. will cause 120 seconds for inference.
"
+ )
+ with gr.Row():
+ num_inference_steps = gr.Number(label="Inference Steps", value=50)
+ guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+ pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
+ pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
+ with gr.Row():
+ generate_button = gr.Button("🎬 Generate Video")
+ generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
+
+ with gr.Column():
+ with gr.Row():
+ video_output = gr.Video(label="CogVideoX", width=720, height=480)
+ with gr.Row():
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
+ elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+ with gr.Row():
+ video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
+ with gr.Row():
+ download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
+ elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
+
+ def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+ engine = load_model()
+ t = time()
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ elapsed_time = time() - t
+ video_update = gr.update(visible=True, value=video_path)
+ elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+ return video_path, video_update, elapsed_time
+
+ def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
+ threshold = [int(i) for i in threshold.split(",")]
+ gap = int(gap)
+ engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
+ t = time()
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ elapsed_time = time() - t
+ video_update = gr.update(visible=True, value=video_path)
+ elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
+
+ return video_path, video_update, elapsed_time
+
+ def enhance_prompt_func(prompt):
+ return convert_prompt(prompt, retry_times=1)
+
+ generate_button.click(
+ generate_vanilla,
+ inputs=[prompt, num_inference_steps, guidance_scale],
+ outputs=[video_output, download_video_button, elapsed_time],
+ )
+
+ generate_button_vs.click(
+ generate_vs,
+ inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
+ outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
+ )
+
+ enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
+
+if __name__ == "__main__":
+ demo.launch()
\ No newline at end of file
diff --git a/docs/dsp.md b/docs/dsp.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a08cbc44db772909cb7763d449f0a6df51f10bb
--- /dev/null
+++ b/docs/dsp.md
@@ -0,0 +1,25 @@
+# DSP
+
+paper: https://arxiv.org/abs/2403.10266
+
+![dsp_overview](../assets/figures/dsp_overview.png)
+
+
+DSP (Dynamic Sequence Parallelism) is a novel, elegant and super efficient sequence parallelism for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
+
+The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the potential characteristics of multi-dimensional transformers. Compared with splitting head and sequence dimension as previous methods, it can reduce at least 75% of communication cost.
+
+It achieves **3x** speed for training and **2x** speed for inference in OpenSora compared with sota sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80 frames) of 512x512 video, the inference latency of OpenSora is:
+
+| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
+| ------ | ------ | ------ | ------ |
+| Latency(s) | 106 | 45 | 22 |
+
+The following is DSP's end-to-end throughput for training of OpenSora:
+
+![dsp_overview](../assets/figures/dsp_exp.png)
+
+
+### Usage
+
+DSP is currently supported for: OpenSora, OpenSoraPlan and Latte. To enable DSP, you just need to launch with multiple GPUs.
diff --git a/docs/pab.md b/docs/pab.md
new file mode 100644
index 0000000000000000000000000000000000000000..0de7b98139e52b17edc5fd43e5aad1a5f9a01525
--- /dev/null
+++ b/docs/pab.md
@@ -0,0 +1,121 @@
+# Pyramid Attention Broadcast(PAB)
+
+[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)]
+
+Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
+- [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
+ - [Insights](#insights)
+ - [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
+ - [Experimental Results](#experimental-results)
+ - [Usage](#usage)
+ - [Supported Models](#supported-models)
+ - [Configuration for PAB](#configuration-for-pab)
+ - [Parameters](#parameters)
+ - [Example Configuration](#example-configuration)
+
+
+We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can enpower any future DiT-based video generation models with real-time capabilities.
+
+## Insights
+
+![method](../assets/figures/pab_motivation.png)
+
+Our study reveals two key insights of three **attention mechanisms** within video diffusion transformers:
+- First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
+- Second, within the stable middle segment, the variability differs among attention types:
+ - **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
+ - **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
+ - **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
+
+## Pyramid Attention Broadcast (PAB) Mechanism
+
+![method](../assets/figures/pab_method.png)
+
+Building on these insights, we propose a **pyramid attention broadcast(PAB)** mechanism to minimize unnecessary computations and optimize the utility of each attention module, as shown in Figure[xx figure] below.
+
+In the middle segment, we broadcast one step's attention outputs to its subsequent several steps, thereby significantly reducing the computational cost on attention modules.
+
+For more efficient broadcast and minimum influence to effect, we set varied broadcast ranges for different attentions based on their stability and differences.
+**The smaller the variation in attention, the broader the potential broadcast range.**
+
+
+## Experimental Results
+Here are the results of our experiments, more results are shown in https://oahzxl.github.io/PAB:
+
+![pab_vis](../assets/figures/pab_vis.png)
+
+
+## Usage
+
+### Supported Models
+
+PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
+
+### Configuration for PAB
+
+To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
+
+#### Parameters
+
+- **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
+ - Type: `True` or `False`
+
+- **spatial_threshold**: Set the range of diffusion steps within which spatial attention is applied.
+ - Format: `[min_value, max_value]`
+
+- **spatial_gap**: Number of blocks in model to skip during broadcasting for spatial attention.
+ - Type: Integer
+
+- **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
+ - Type: `True` or `False`
+
+- **temporal_threshold**: Set the range of diffusion steps within which temporal attention is applied.
+ - Format: `[min_value, max_value]`
+
+- **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
+ - Type: Integer
+
+- **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
+ - Type: `True` or `False`
+
+- **cross_threshold**: Set the range of diffusion steps within which cross-modal attention is applied.
+ - Format: `[min_value, max_value]`
+
+- **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
+ - Type: Integer
+
+#### Example Configuration
+
+```yaml
+spatial_broadcast: True
+spatial_threshold: [100, 800]
+spatial_gap: 2
+
+temporal_broadcast: True
+temporal_threshold: [100, 800]
+temporal_gap: 3
+
+cross_broadcast: True
+cross_threshold: [100, 900]
+cross_gap: 5
+```
+
+Explanation:
+
+- **Spatial Attention**:
+ - Broadcasting enabled (`spatial_broadcast: True`)
+ - Applied within the threshold range of 100 to 800
+ - Skips every 2 steps (`spatial_gap: 2`)
+ - Active within the first 28 steps (`spatial_block: [0, 28]`)
+
+- **Temporal Attention**:
+ - Broadcasting enabled (`temporal_broadcast: True`)
+ - Applied within the threshold range of 100 to 800
+ - Skips every 3 steps (`temporal_gap: 3`)
+
+- **Cross-Modal Attention**:
+ - Broadcasting enabled (`cross_broadcast: True`)
+ - Applied within the threshold range of 100 to 900
+ - Skips every 5 steps (`cross_gap: 5`)
+
+Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
diff --git a/eval/pab/commom_metrics/README.md b/eval/pab/commom_metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1e595d9229094f1f165146b680843f93138a577d
--- /dev/null
+++ b/eval/pab/commom_metrics/README.md
@@ -0,0 +1,6 @@
+Common metrics
+
+Include LPIPS, PSNR and SSIM.
+
+The code is adapted from [common_metrics_on_video_quality
+](https://github.com/JunyaoHu/common_metrics_on_video_quality).
diff --git a/eval/pab/commom_metrics/__init__.py b/eval/pab/commom_metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/eval/pab/commom_metrics/calculate_lpips.py b/eval/pab/commom_metrics/calculate_lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d9efcf24235f6d91701541ab8acfa7279bbecf4
--- /dev/null
+++ b/eval/pab/commom_metrics/calculate_lpips.py
@@ -0,0 +1,97 @@
+import lpips
+import numpy as np
+import torch
+
+spatial = True # Return a spatial map of perceptual distance.
+
+# Linearly calibrated models (LPIPS)
+loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
+# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
+
+
+def trans(x):
+ # if greyscale images add channel
+ if x.shape[-3] == 1:
+ x = x.repeat(1, 1, 3, 1, 1)
+
+ # value range [0, 1] -> [-1, 1]
+ x = x * 2 - 1
+
+ return x
+
+
+def calculate_lpips(videos1, videos2, device):
+ # image should be RGB, IMPORTANT: normalized to [-1,1]
+
+ assert videos1.shape == videos2.shape
+
+ # videos [batch_size, timestamps, channel, h, w]
+
+ # support grayscale input, if grayscale -> channel*3
+ # value range [0, 1] -> [-1, 1]
+ videos1 = trans(videos1)
+ videos2 = trans(videos2)
+
+ lpips_results = []
+
+ for video_num in range(videos1.shape[0]):
+ # get a video
+ # video [timestamps, channel, h, w]
+ video1 = videos1[video_num]
+ video2 = videos2[video_num]
+
+ lpips_results_of_a_video = []
+ for clip_timestamp in range(len(video1)):
+ # get a img
+ # img [timestamps[x], channel, h, w]
+ # img [channel, h, w] tensor
+
+ img1 = video1[clip_timestamp].unsqueeze(0).to(device)
+ img2 = video2[clip_timestamp].unsqueeze(0).to(device)
+
+ loss_fn.to(device)
+
+ # calculate lpips of a video
+ lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
+ lpips_results.append(lpips_results_of_a_video)
+
+ lpips_results = np.array(lpips_results)
+
+ lpips = {}
+ lpips_std = {}
+
+ for clip_timestamp in range(len(video1)):
+ lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
+ lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
+
+ result = {
+ "value": lpips,
+ "value_std": lpips_std,
+ "video_setting": video1.shape,
+ "video_setting_name": "time, channel, heigth, width",
+ }
+
+ return result
+
+
+# test code / using example
+
+
+def main():
+ NUMBER_OF_VIDEOS = 8
+ VIDEO_LENGTH = 50
+ CHANNEL = 3
+ SIZE = 64
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+ videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+ device = torch.device("cuda")
+ # device = torch.device("cpu")
+
+ import json
+
+ result = calculate_lpips(videos1, videos2, device)
+ print(json.dumps(result, indent=4))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/eval/pab/commom_metrics/calculate_psnr.py b/eval/pab/commom_metrics/calculate_psnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..416bc48a94e5fa5c242cd92a89c9b165e522b86b
--- /dev/null
+++ b/eval/pab/commom_metrics/calculate_psnr.py
@@ -0,0 +1,90 @@
+import math
+
+import numpy as np
+import torch
+
+
+def img_psnr(img1, img2):
+ # [0,1]
+ # compute mse
+ # mse = np.mean((img1-img2)**2)
+ mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
+ # compute psnr
+ if mse < 1e-10:
+ return 100
+ psnr = 20 * math.log10(1 / math.sqrt(mse))
+ return psnr
+
+
+def trans(x):
+ return x
+
+
+def calculate_psnr(videos1, videos2):
+ # videos [batch_size, timestamps, channel, h, w]
+
+ assert videos1.shape == videos2.shape
+
+ videos1 = trans(videos1)
+ videos2 = trans(videos2)
+
+ psnr_results = []
+
+ for video_num in range(videos1.shape[0]):
+ # get a video
+ # video [timestamps, channel, h, w]
+ video1 = videos1[video_num]
+ video2 = videos2[video_num]
+
+ psnr_results_of_a_video = []
+ for clip_timestamp in range(len(video1)):
+ # get a img
+ # img [timestamps[x], channel, h, w]
+ # img [channel, h, w] numpy
+
+ img1 = video1[clip_timestamp].numpy()
+ img2 = video2[clip_timestamp].numpy()
+
+ # calculate psnr of a video
+ psnr_results_of_a_video.append(img_psnr(img1, img2))
+
+ psnr_results.append(psnr_results_of_a_video)
+
+ psnr_results = np.array(psnr_results)
+
+ psnr = {}
+ psnr_std = {}
+
+ for clip_timestamp in range(len(video1)):
+ psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
+ psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
+
+ result = {
+ "value": psnr,
+ "value_std": psnr_std,
+ "video_setting": video1.shape,
+ "video_setting_name": "time, channel, heigth, width",
+ }
+
+ return result
+
+
+# test code / using example
+
+
+def main():
+ NUMBER_OF_VIDEOS = 8
+ VIDEO_LENGTH = 50
+ CHANNEL = 3
+ SIZE = 64
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+
+ import json
+
+ result = calculate_psnr(videos1, videos2)
+ print(json.dumps(result, indent=4))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/eval/pab/commom_metrics/calculate_ssim.py b/eval/pab/commom_metrics/calculate_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa78bd5ca2cd826f04b4f42e7dd5c53d61ed7231
--- /dev/null
+++ b/eval/pab/commom_metrics/calculate_ssim.py
@@ -0,0 +1,116 @@
+import cv2
+import numpy as np
+import torch
+
+
+def ssim(img1, img2):
+ C1 = 0.01**2
+ C2 = 0.03**2
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+def calculate_ssim_function(img1, img2):
+ # [0,1]
+ # ssim is the only metric extremely sensitive to gray being compared to b/w
+ if not img1.shape == img2.shape:
+ raise ValueError("Input images must have the same dimensions.")
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[0] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[i], img2[i]))
+ return np.array(ssims).mean()
+ elif img1.shape[0] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError("Wrong input image dimensions.")
+
+
+def trans(x):
+ return x
+
+
+def calculate_ssim(videos1, videos2):
+ # videos [batch_size, timestamps, channel, h, w]
+
+ assert videos1.shape == videos2.shape
+
+ videos1 = trans(videos1)
+ videos2 = trans(videos2)
+
+ ssim_results = []
+
+ for video_num in range(videos1.shape[0]):
+ # get a video
+ # video [timestamps, channel, h, w]
+ video1 = videos1[video_num]
+ video2 = videos2[video_num]
+
+ ssim_results_of_a_video = []
+ for clip_timestamp in range(len(video1)):
+ # get a img
+ # img [timestamps[x], channel, h, w]
+ # img [channel, h, w] numpy
+
+ img1 = video1[clip_timestamp].numpy()
+ img2 = video2[clip_timestamp].numpy()
+
+ # calculate ssim of a video
+ ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
+
+ ssim_results.append(ssim_results_of_a_video)
+
+ ssim_results = np.array(ssim_results)
+
+ ssim = {}
+ ssim_std = {}
+
+ for clip_timestamp in range(len(video1)):
+ ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
+ ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
+
+ result = {
+ "value": ssim,
+ "value_std": ssim_std,
+ "video_setting": video1.shape,
+ "video_setting_name": "time, channel, heigth, width",
+ }
+
+ return result
+
+
+# test code / using example
+
+
+def main():
+ NUMBER_OF_VIDEOS = 8
+ VIDEO_LENGTH = 50
+ CHANNEL = 3
+ SIZE = 64
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
+ torch.device("cuda")
+
+ import json
+
+ result = calculate_ssim(videos1, videos2)
+ print(json.dumps(result, indent=4))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/eval/pab/commom_metrics/eval.py b/eval/pab/commom_metrics/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f300510d6ecedd5875486a8d260a6b4bbd22327
--- /dev/null
+++ b/eval/pab/commom_metrics/eval.py
@@ -0,0 +1,160 @@
+import argparse
+import os
+
+import imageio
+import torch
+import torchvision.transforms.functional as F
+import tqdm
+from calculate_lpips import calculate_lpips
+from calculate_psnr import calculate_psnr
+from calculate_ssim import calculate_ssim
+
+
+def load_videos(directory, video_ids, file_extension):
+ videos = []
+ for video_id in video_ids:
+ video_path = os.path.join(directory, f"{video_id}.{file_extension}")
+ if os.path.exists(video_path):
+ video = load_video(video_path) # Define load_video based on how videos are stored
+ videos.append(video)
+ else:
+ raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
+ return videos
+
+
+def load_video(video_path):
+ """
+ Load a video from the given path and convert it to a PyTorch tensor.
+ """
+ # Read the video using imageio
+ reader = imageio.get_reader(video_path, "ffmpeg")
+
+ # Extract frames and convert to a list of tensors
+ frames = []
+ for frame in reader:
+ # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
+ frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
+ frames.append(frame_tensor)
+
+ # Stack the list of tensors into a single tensor with shape (T, C, H, W)
+ video_tensor = torch.stack(frames)
+
+ return video_tensor
+
+
+def resize_video(video, target_height, target_width):
+ resized_frames = []
+ for frame in video:
+ resized_frame = F.resize(frame, [target_height, target_width])
+ resized_frames.append(resized_frame)
+ return torch.stack(resized_frames)
+
+
+def preprocess_eval_video(eval_video, generated_video_shape):
+ T_gen, _, H_gen, W_gen = generated_video_shape
+ T_eval, _, H_eval, W_eval = eval_video.shape
+
+ if T_eval < T_gen:
+ raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
+
+ if H_eval < H_gen or W_eval < W_gen:
+ # Resize the video maintaining the aspect ratio
+ resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
+ resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
+ eval_video = resize_video(eval_video, resize_height, resize_width)
+ # Recalculate the dimensions
+ T_eval, _, H_eval, W_eval = eval_video.shape
+
+ # Center crop
+ start_h = (H_eval - H_gen) // 2
+ start_w = (W_eval - W_gen) // 2
+ cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
+
+ return cropped_video
+
+
+def main(args):
+ device = "cuda"
+ gt_video_dir = args.gt_video_dir
+ generated_video_dir = args.generated_video_dir
+
+ video_ids = []
+ file_extension = "mp4"
+ for f in os.listdir(generated_video_dir):
+ if f.endswith(f".{file_extension}"):
+ video_ids.append(f.replace(f".{file_extension}", ""))
+ if not video_ids:
+ raise ValueError("No videos found in the generated video dataset. Exiting.")
+
+ print(f"Find {len(video_ids)} videos")
+ prompt_interval = 1
+ batch_size = 16
+ calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
+
+ lpips_results = []
+ psnr_results = []
+ ssim_results = []
+
+ total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
+
+ for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
+ gt_videos_tensor = []
+ generated_videos_tensor = []
+ for i in range(batch_size):
+ video_idx = idx * batch_size + i
+ if video_idx >= len(video_ids):
+ break
+ video_id = video_ids[video_idx]
+ generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
+ generated_videos_tensor.append(generated_video)
+ eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
+ gt_videos_tensor.append(eval_video)
+ gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
+ generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
+
+ if calculate_lpips_flag:
+ result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
+ result = result["value"].values()
+ result = sum(result) / len(result)
+ lpips_results.append(result)
+
+ if calculate_psnr_flag:
+ result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
+ result = result["value"].values()
+ result = sum(result) / len(result)
+ psnr_results.append(result)
+
+ if calculate_ssim_flag:
+ result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
+ result = result["value"].values()
+ result = sum(result) / len(result)
+ ssim_results.append(result)
+
+ if (idx + 1) % prompt_interval == 0:
+ out_str = ""
+ for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
+ result = sum(results) / len(results)
+ out_str += f"{name}: {result:.4f}, "
+ print(f"Processed {idx + 1} videos. {out_str[:-2]}")
+
+ out_str = ""
+ for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
+ result = sum(results) / len(results)
+ out_str += f"{name}: {result:.4f}, "
+ out_str = out_str[:-2]
+
+ # save
+ with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
+ f.write(out_str)
+
+ print(f"Processed all videos. {out_str}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--gt_video_dir", type=str)
+ parser.add_argument("--generated_video_dir", type=str)
+
+ args = parser.parse_args()
+
+ main(args)
diff --git a/eval/pab/experiments/__init__.py b/eval/pab/experiments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/eval/pab/experiments/attention_ablation.py b/eval/pab/experiments/attention_ablation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78964d09a1e1f74bcdf382d468c1d2ca03e5ce9
--- /dev/null
+++ b/eval/pab/experiments/attention_ablation.py
@@ -0,0 +1,60 @@
+from utils import generate_func, read_prompt_list
+
+import videosys
+from videosys import OpenSoraConfig, OpenSoraPipeline
+from videosys.models.open_sora import OpenSoraPABConfig
+
+
+def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
+ pab_config = OpenSoraPABConfig(**pab_kwargs)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, output_dir)
+
+
+def main(prompt_list):
+ # spatial
+ gap_list = [2, 3, 4, 5]
+ for gap in gap_list:
+ pab_kwargs = {
+ "spatial_broadcast": True,
+ "spatial_gap": gap,
+ "temporal_broadcast": False,
+ "cross_broadcast": False,
+ "mlp_skip": False,
+ }
+ output_dir = f"./samples/attention_ablation/spatial_g{gap}"
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
+
+ # temporal
+ gap_list = [3, 4, 5, 6]
+ for gap in gap_list:
+ pab_kwargs = {
+ "spatial_broadcast": False,
+ "temporal_broadcast": True,
+ "temporal_gap": gap,
+ "cross_broadcast": False,
+ "mlp_skip": False,
+ }
+ output_dir = f"./samples/attention_ablation/temporal_g{gap}"
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
+
+ # cross
+ gap_list = [5, 6, 7, 8]
+ for gap in gap_list:
+ pab_kwargs = {
+ "spatial_broadcast": False,
+ "temporal_broadcast": False,
+ "cross_broadcast": True,
+ "cross_gap": gap,
+ "mlp_skip": False,
+ }
+ output_dir = f"./samples/attention_ablation/cross_g{gap}"
+ attention_ablation_func(pab_kwargs, prompt_list, output_dir)
+
+
+if __name__ == "__main__":
+ videosys.initialize(42)
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
+ main(prompt_list)
diff --git a/eval/pab/experiments/components_ablation.py b/eval/pab/experiments/components_ablation.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d88f3a61f031d1f51035876d6660d950fa4575
--- /dev/null
+++ b/eval/pab/experiments/components_ablation.py
@@ -0,0 +1,46 @@
+from utils import generate_func, read_prompt_list
+
+import videosys
+from videosys import OpenSoraConfig, OpenSoraPipeline
+from videosys.models.open_sora import OpenSoraPABConfig
+
+
+def wo_spatial(prompt_list):
+ pab_config = OpenSoraPABConfig(spatial_broadcast=False)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
+
+
+def wo_temporal(prompt_list):
+ pab_config = OpenSoraPABConfig(temporal_broadcast=False)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
+
+
+def wo_cross(prompt_list):
+ pab_config = OpenSoraPABConfig(cross_broadcast=False)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
+
+
+def wo_mlp(prompt_list):
+ pab_config = OpenSoraPABConfig(mlp_skip=False)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
+
+
+if __name__ == "__main__":
+ videosys.initialize(42)
+ prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
+ wo_spatial(prompt_list)
+ wo_temporal(prompt_list)
+ wo_cross(prompt_list)
+ wo_mlp(prompt_list)
diff --git a/eval/pab/experiments/latte.py b/eval/pab/experiments/latte.py
new file mode 100644
index 0000000000000000000000000000000000000000..5748dbaf78b6b8a9af784aea7188f4719d1aaf8c
--- /dev/null
+++ b/eval/pab/experiments/latte.py
@@ -0,0 +1,57 @@
+from utils import generate_func, read_prompt_list
+
+import videosys
+from videosys import LatteConfig, LattePipeline
+from videosys.models.latte import LattePABConfig
+
+
+def eval_base(prompt_list):
+ config = LatteConfig()
+ pipeline = LattePipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
+
+
+def eval_pab1(prompt_list):
+ pab_config = LattePABConfig(
+ spatial_gap=2,
+ temporal_gap=3,
+ cross_gap=6,
+ )
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = LattePipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
+
+
+def eval_pab2(prompt_list):
+ pab_config = LattePABConfig(
+ spatial_gap=3,
+ temporal_gap=4,
+ cross_gap=7,
+ )
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = LattePipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
+
+
+def eval_pab3(prompt_list):
+ pab_config = LattePABConfig(
+ spatial_gap=4,
+ temporal_gap=6,
+ cross_gap=9,
+ )
+ config = LatteConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = LattePipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
+
+
+if __name__ == "__main__":
+ videosys.initialize(42)
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
+ eval_base(prompt_list)
+ eval_pab1(prompt_list)
+ eval_pab2(prompt_list)
+ eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/opensora.py b/eval/pab/experiments/opensora.py
new file mode 100644
index 0000000000000000000000000000000000000000..7799c67308704bb2c825996ca8d800f95ba5d2c6
--- /dev/null
+++ b/eval/pab/experiments/opensora.py
@@ -0,0 +1,44 @@
+from utils import generate_func, read_prompt_list
+
+import videosys
+from videosys import OpenSoraConfig, OpenSoraPipeline
+from videosys.models.open_sora import OpenSoraPABConfig
+
+
+def eval_base(prompt_list):
+ config = OpenSoraConfig()
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
+
+
+def eval_pab1(prompt_list):
+ config = OpenSoraConfig(enable_pab=True)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
+
+
+def eval_pab2(prompt_list):
+ pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
+
+
+def eval_pab3(prompt_list):
+ pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
+ config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
+
+
+if __name__ == "__main__":
+ videosys.initialize(42)
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
+ eval_base(prompt_list)
+ eval_pab1(prompt_list)
+ eval_pab2(prompt_list)
+ eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/opensora_plan.py b/eval/pab/experiments/opensora_plan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4e8efc955e6ec469aa4a40d20abab31a8481a42
--- /dev/null
+++ b/eval/pab/experiments/opensora_plan.py
@@ -0,0 +1,57 @@
+from utils import generate_func, read_prompt_list
+
+import videosys
+from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
+from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
+
+
+def eval_base(prompt_list):
+ config = OpenSoraPlanConfig()
+ pipeline = OpenSoraPlanPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
+
+
+def eval_pab1(prompt_list):
+ pab_config = OpenSoraPlanPABConfig(
+ spatial_gap=2,
+ temporal_gap=4,
+ cross_gap=6,
+ )
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPlanPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
+
+
+def eval_pab2(prompt_list):
+ pab_config = OpenSoraPlanPABConfig(
+ spatial_gap=3,
+ temporal_gap=5,
+ cross_gap=7,
+ )
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPlanPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
+
+
+def eval_pab3(prompt_list):
+ pab_config = OpenSoraPlanPABConfig(
+ spatial_gap=5,
+ temporal_gap=7,
+ cross_gap=9,
+ )
+ config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+ pipeline = OpenSoraPlanPipeline(config)
+
+ generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
+
+
+if __name__ == "__main__":
+ videosys.initialize(42)
+ prompt_list = read_prompt_list("vbench/VBench_full_info.json")
+ eval_base(prompt_list)
+ eval_pab1(prompt_list)
+ eval_pab2(prompt_list)
+ eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/utils.py b/eval/pab/experiments/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb52309fda21056b8ba352696e9ca4cf1fe1788e
--- /dev/null
+++ b/eval/pab/experiments/utils.py
@@ -0,0 +1,22 @@
+import json
+import os
+
+import tqdm
+
+from videosys.utils.utils import set_seed
+
+
+def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
+ kwargs["verbose"] = False
+ for prompt in tqdm.tqdm(prompt_list):
+ for l in range(loop):
+ set_seed(l)
+ video = pipeline.generate(prompt, **kwargs).video[0]
+ pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
+
+
+def read_prompt_list(prompt_list_path):
+ with open(prompt_list_path, "r") as f:
+ prompt_list = json.load(f)
+ prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
+ return prompt_list
diff --git a/eval/pab/vbench/VBench_full_info.json b/eval/pab/vbench/VBench_full_info.json
new file mode 100644
index 0000000000000000000000000000000000000000..e60c40eb0050a5304791490972be3b32de309e4a
--- /dev/null
+++ b/eval/pab/vbench/VBench_full_info.json
@@ -0,0 +1,9132 @@
+[
+ {
+ "prompt_en": "In a still frame, a stop sign",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "a toilet, frozen in time",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "a laptop, frozen in time",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of alley",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of bar",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of barn",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of bathroom",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of bedroom",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of cliff",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, courtyard",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, gas station",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of house",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "indoor gymnasium, frozen in time",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of indoor library",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of kitchen",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of palace",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, parking lot",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, phone booth",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of restaurant",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of tower",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a bowl",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of an apple",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a bench",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a bed",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a chair",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a cup",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a dining table",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, a pear",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a bunch of grapes",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a bowl on the kitchen counter",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a beautiful, handcrafted ceramic bowl",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of an antique bowl",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of an exquisite mahogany dining table",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a wooden bench in the park",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a beautiful wrought-iron bench surrounded by blooming flowers",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, a park bench with a view of the lake",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a vintage rocking chair was placed on the porch",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the jail cell was small and dimly lit, with cold, steel bars",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the phone booth was tucked away in a quiet alley",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "a dilapidated phone booth stood as a relic of a bygone era on the sidewalk, frozen in time",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the old red barn stood weathered and iconic against the backdrop of the countryside",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a picturesque barn was painted a warm shade of red and nestled in a picturesque meadow",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, within the desolate desert, an oasis unfolded, characterized by the stoic presence of palm trees and a motionless, glassy pool of water",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, the Parthenon's majestic Doric columns stand in serene solitude atop the Acropolis, framed by the tranquil Athenian landscape",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, the Temple of Hephaestus, with its timeless Doric grace, stands stoically against the backdrop of a quiet Athens",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, the ornate Victorian streetlamp stands solemnly, adorned with intricate ironwork and stained glass panels",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the Stonehenge presented itself as an enigmatic puzzle, each colossal stone meticulously placed against the backdrop of tranquility",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, in the vast desert, an oasis nestled among dunes, featuring tall palm trees and an air of serenity",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "static view on a desert scene with an oasis, palm trees, and a clear, calm pool of water",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of an ornate Victorian streetlamp standing on a cobblestone street corner, illuminating the empty night",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a tranquil lakeside cabin nestled among tall pines, its reflection mirrored perfectly in the calm water",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, a vintage gas lantern, adorned with intricate details, gracing a historic cobblestone square",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, a tranquil Japanese tea ceremony room, with tatami mats, a delicate tea set, and a bonsai tree in the corner",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the Parthenon stands resolute in its classical elegance, a timeless symbol of Athens' cultural legacy",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the heart of Plaka, the neoclassical architecture of the old city harmonizes with the ancient ruins",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the desolate beauty of the American Southwest, Chaco Canyon's ancient ruins whispered tales of an enigmatic civilization that once thrived amidst the arid landscapes",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of at the edge of the Arabian Desert, the ancient city of Petra beckoned with its enigmatic rock-carved fa\u00e7ades",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, amidst the cobblestone streets, an Art Nouveau lamppost stood tall",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the quaint village square, a traditional wrought-iron streetlamp featured delicate filigree patterns and amber-hued glass panels",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of the lampposts were adorned with Art Deco motifs, their geometric shapes and frosted glass creating a sense of vintage glamour",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, in the picturesque square, a Gothic-style lamppost adorned with intricate stone carvings added a touch of medieval charm to the setting",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, in the heart of the old city, a row of ornate lantern-style streetlamps bathed the narrow alleyway in a warm, welcoming light",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the heart of the Utah desert, a massive sandstone arch spanned the horizon",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the Arizona desert, a massive stone bridge arched across a rugged canyon",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the corner of the minimalist tea room, a bonsai tree added a touch of nature's beauty to the otherwise simple and elegant space",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, amidst the hushed ambiance of the traditional tea room, a meticulously arranged tea set awaited, with porcelain cups, a bamboo whisk",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, nestled in the Zen garden, a rustic teahouse featured tatami seating and a traditional charcoal brazier",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a country estate's library featured elegant wooden shelves",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of beneath the shade of a solitary oak tree, an old wooden park bench sat patiently",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of beside a tranquil pond, a weeping willow tree draped its branches gracefully over the water's surface, creating a serene tableau of reflection and calm",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the Zen garden, a perfectly raked gravel path led to a serene rock garden",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, a tranquil pond was fringed by weeping cherry trees, their blossoms drifting lazily onto the glassy surface",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "In a still frame, within the historic library's reading room, rows of antique leather chairs and mahogany tables offered a serene haven for literary contemplation",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of a peaceful orchid garden showcased a variety of delicate blooms",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "A tranquil tableau of in the serene courtyard, a centuries-old stone well stood as a symbol of a bygone era, its mossy stones bearing witness to the passage of time",
+ "dimension": [
+ "temporal_flickering"
+ ]
+ },
+ {
+ "prompt_en": "a bird and a cat",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bird and cat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cat and a dog",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "cat and dog"
+ }
+ }
+ },
+ {
+ "prompt_en": "a dog and a horse",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "dog and horse"
+ }
+ }
+ },
+ {
+ "prompt_en": "a horse and a sheep",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "horse and sheep"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sheep and a cow",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "sheep and cow"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cow and an elephant",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "cow and elephant"
+ }
+ }
+ },
+ {
+ "prompt_en": "an elephant and a bear",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "elephant and bear"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bear and a zebra",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bear and zebra"
+ }
+ }
+ },
+ {
+ "prompt_en": "a zebra and a giraffe",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "zebra and giraffe"
+ }
+ }
+ },
+ {
+ "prompt_en": "a giraffe and a bird",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "giraffe and bird"
+ }
+ }
+ },
+ {
+ "prompt_en": "a chair and a couch",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "chair and couch"
+ }
+ }
+ },
+ {
+ "prompt_en": "a couch and a potted plant",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "couch and potted plant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a potted plant and a tv",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "potted plant and tv"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tv and a laptop",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "tv and laptop"
+ }
+ }
+ },
+ {
+ "prompt_en": "a laptop and a remote",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "laptop and remote"
+ }
+ }
+ },
+ {
+ "prompt_en": "a remote and a keyboard",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "remote and keyboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a keyboard and a cell phone",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "keyboard and cell phone"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cell phone and a book",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "cell phone and book"
+ }
+ }
+ },
+ {
+ "prompt_en": "a book and a clock",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "book and clock"
+ }
+ }
+ },
+ {
+ "prompt_en": "a clock and a backpack",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "clock and backpack"
+ }
+ }
+ },
+ {
+ "prompt_en": "a backpack and an umbrella",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "backpack and umbrella"
+ }
+ }
+ },
+ {
+ "prompt_en": "an umbrella and a handbag",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "umbrella and handbag"
+ }
+ }
+ },
+ {
+ "prompt_en": "a handbag and a tie",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "handbag and tie"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tie and a suitcase",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "tie and suitcase"
+ }
+ }
+ },
+ {
+ "prompt_en": "a suitcase and a vase",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "suitcase and vase"
+ }
+ }
+ },
+ {
+ "prompt_en": "a vase and scissors",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "vase and scissors"
+ }
+ }
+ },
+ {
+ "prompt_en": "scissors and a teddy bear",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "scissors and teddy bear"
+ }
+ }
+ },
+ {
+ "prompt_en": "a teddy bear and a frisbee",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "teddy bear and frisbee"
+ }
+ }
+ },
+ {
+ "prompt_en": "a frisbee and skis",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "frisbee and skis"
+ }
+ }
+ },
+ {
+ "prompt_en": "skis and a snowboard",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "skis and snowboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a snowboard and a sports ball",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "snowboard and sports ball"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sports ball and a kite",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "sports ball and kite"
+ }
+ }
+ },
+ {
+ "prompt_en": "a kite and a baseball bat",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "kite and baseball bat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball bat and a baseball glove",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "baseball bat and baseball glove"
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball glove and a skateboard",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "baseball glove and skateboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a skateboard and a surfboard",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "skateboard and surfboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a surfboard and a tennis racket",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "surfboard and tennis racket"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tennis racket and a bottle",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "tennis racket and bottle"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bottle and a chair",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bottle and chair"
+ }
+ }
+ },
+ {
+ "prompt_en": "an airplane and a train",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "airplane and train"
+ }
+ }
+ },
+ {
+ "prompt_en": "a train and a boat",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "train and boat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a boat and an airplane",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "boat and airplane"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bicycle and a car",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bicycle and car"
+ }
+ }
+ },
+ {
+ "prompt_en": "a car and a motorcycle",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "car and motorcycle"
+ }
+ }
+ },
+ {
+ "prompt_en": "a motorcycle and a bus",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "motorcycle and bus"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bus and a traffic light",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bus and traffic light"
+ }
+ }
+ },
+ {
+ "prompt_en": "a traffic light and a fire hydrant",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "traffic light and fire hydrant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a fire hydrant and a stop sign",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "fire hydrant and stop sign"
+ }
+ }
+ },
+ {
+ "prompt_en": "a stop sign and a parking meter",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "stop sign and parking meter"
+ }
+ }
+ },
+ {
+ "prompt_en": "a parking meter and a truck",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "parking meter and truck"
+ }
+ }
+ },
+ {
+ "prompt_en": "a truck and a bicycle",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "truck and bicycle"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toilet and a hair drier",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "toilet and hair drier"
+ }
+ }
+ },
+ {
+ "prompt_en": "a hair drier and a toothbrush",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "hair drier and toothbrush"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toothbrush and a sink",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "toothbrush and sink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sink and a toilet",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "sink and toilet"
+ }
+ }
+ },
+ {
+ "prompt_en": "a wine glass and a chair",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "wine glass and chair"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cup and a couch",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "cup and couch"
+ }
+ }
+ },
+ {
+ "prompt_en": "a fork and a potted plant",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "fork and potted plant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a knife and a tv",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "knife and tv"
+ }
+ }
+ },
+ {
+ "prompt_en": "a spoon and a laptop",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "spoon and laptop"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bowl and a remote",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bowl and remote"
+ }
+ }
+ },
+ {
+ "prompt_en": "a banana and a keyboard",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "banana and keyboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "an apple and a cell phone",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "apple and cell phone"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sandwich and a book",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "sandwich and book"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange and a clock",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "orange and clock"
+ }
+ }
+ },
+ {
+ "prompt_en": "broccoli and a backpack",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "broccoli and backpack"
+ }
+ }
+ },
+ {
+ "prompt_en": "a carrot and an umbrella",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "carrot and umbrella"
+ }
+ }
+ },
+ {
+ "prompt_en": "a hot dog and a handbag",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "hot dog and handbag"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pizza and a tie",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "pizza and tie"
+ }
+ }
+ },
+ {
+ "prompt_en": "a donut and a suitcase",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "donut and suitcase"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cake and a vase",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "cake and vase"
+ }
+ }
+ },
+ {
+ "prompt_en": "an oven and scissors",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "oven and scissors"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toaster and a teddy bear",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "toaster and teddy bear"
+ }
+ }
+ },
+ {
+ "prompt_en": "a microwave and a frisbee",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "microwave and frisbee"
+ }
+ }
+ },
+ {
+ "prompt_en": "a refrigerator and skis",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "refrigerator and skis"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bicycle and an airplane",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "bicycle and airplane"
+ }
+ }
+ },
+ {
+ "prompt_en": "a car and a train",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "car and train"
+ }
+ }
+ },
+ {
+ "prompt_en": "a motorcycle and a boat",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "motorcycle and boat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a person and a toilet",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "person and toilet"
+ }
+ }
+ },
+ {
+ "prompt_en": "a person and a hair drier",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "person and hair drier"
+ }
+ }
+ },
+ {
+ "prompt_en": "a person and a toothbrush",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "person and toothbrush"
+ }
+ }
+ },
+ {
+ "prompt_en": "a person and a sink",
+ "dimension": [
+ "multiple_objects"
+ ],
+ "auxiliary_info": {
+ "multiple_objects": {
+ "object": "person and sink"
+ }
+ }
+ },
+ {
+ "prompt_en": "A person is riding a bike",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is marching",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is roller skating",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is tasting beer",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is clapping",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is drawing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is petting animal (not cat)",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is eating watermelon",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing harp",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is wrestling",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is riding scooter",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is sweeping floor",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is skateboarding",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is dunking basketball",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing flute",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is stretching leg",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is tying tie",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is skydiving",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shooting goal (soccer)",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing piano",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is finger snapping",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is canoeing or kayaking",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is laughing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is digging",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is clay pottery making",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shooting basketball",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is bending back",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shaking hands",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is bandaging",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is push up",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is catching or throwing frisbee",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing trumpet",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is flying kite",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is filling eyebrows",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shuffling cards",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is folding clothes",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is smoking",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is tai chi",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is squat",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing controller",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is throwing axe",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is giving or receiving award",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is air drumming",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is taking a shower",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is planting trees",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is sharpening knives",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is robot dancing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is rock climbing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is hula hooping",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is writing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is bungee jumping",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is pushing cart",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is cleaning windows",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is cutting watermelon",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is cheerleading",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is washing hands",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is ironing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is cutting nails",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is hugging",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is trimming or shaving beard",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is jogging",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is making bed",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is washing dishes",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is grooming dog",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is doing laundry",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is knitting",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is reading book",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is baby waking up",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is massaging legs",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is brushing teeth",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is crawling baby",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is motorcycling",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is driving car",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is sticking tongue out",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shaking head",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is sword fighting",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is doing aerobics",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is strumming guitar",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is riding or walking with horse",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is archery",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is catching or throwing baseball",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is playing chess",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is rock scissors paper",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is using computer",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is arranging flowers",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is bending metal",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is ice skating",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is climbing a rope",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is crying",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is dancing ballet",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is getting a haircut",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is running on treadmill",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is kissing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is counting money",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is barbequing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is peeling apples",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is milking cow",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is shining shoes",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is making snowman",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "A person is sailing",
+ "dimension": [
+ "human_action"
+ ]
+ },
+ {
+ "prompt_en": "a person swimming in ocean",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person giving a presentation to a room full of colleagues",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person washing the dishes",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person eating a burger",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person walking in the snowstorm",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person drinking coffee in a cafe",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person playing guitar",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bicycle leaning against a tree",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bicycle gliding through a snowy field",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bicycle slowing down to stop",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bicycle accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a car stuck in traffic during rush hour",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a car turning a corner",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a car slowing down to stop",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a car accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a motorcycle cruising along a coastal highway",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a motorcycle turning a corner",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a motorcycle slowing down to stop",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a motorcycle gliding through a snowy field",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a motorcycle accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an airplane soaring through a clear blue sky",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an airplane taking off",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an airplane landing smoothly on a runway",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an airplane accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bus turning a corner",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bus stuck in traffic during rush hour",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bus accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a train speeding down the tracks",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a train crossing over a tall bridge",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a train accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a truck turning a corner",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a truck anchored in a tranquil bay",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a truck stuck in traffic during rush hour",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a truck slowing down to stop",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a truck accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a boat sailing smoothly on a calm lake",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a boat slowing down to stop",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a boat accelerating to gain speed",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bird soaring gracefully in the sky",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bird building a nest from twigs and leaves",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bird flying over a snowy forest",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cat grooming itself meticulously with its tongue",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cat playing in park",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cat drinking water",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cat running happily",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a dog enjoying a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a dog playing in park",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a dog drinking water",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a dog running happily",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a horse bending down to drink water from a river",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a horse galloping across an open field",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a horse taking a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a horse running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a sheep bending down to drink water from a river",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a sheep taking a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a sheep running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cow bending down to drink water from a river",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cow chewing cud while resting in a tranquil barn",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a cow running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an elephant spraying itself with water using its trunk to cool down",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an elephant taking a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "an elephant running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bear catching a salmon in its powerful jaws",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bear sniffing the air for scents of food",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bear climbing a tree",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a bear hunting for prey",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a zebra bending down to drink water from a river",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a zebra running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a zebra taking a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a giraffe bending down to drink water from a river",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a giraffe taking a peaceful walk",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a giraffe running to join a herd of its kind",
+ "dimension": [
+ "subject_consistency",
+ "dynamic_degree",
+ "motion_smoothness"
+ ]
+ },
+ {
+ "prompt_en": "a person",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "person"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bicycle",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bicycle"
+ }
+ }
+ },
+ {
+ "prompt_en": "a car",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "car"
+ }
+ }
+ },
+ {
+ "prompt_en": "a motorcycle",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "motorcycle"
+ }
+ }
+ },
+ {
+ "prompt_en": "an airplane",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "airplane"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bus",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bus"
+ }
+ }
+ },
+ {
+ "prompt_en": "a train",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "train"
+ }
+ }
+ },
+ {
+ "prompt_en": "a truck",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "truck"
+ }
+ }
+ },
+ {
+ "prompt_en": "a boat",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "boat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a traffic light",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "traffic light"
+ }
+ }
+ },
+ {
+ "prompt_en": "a fire hydrant",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "fire hydrant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a stop sign",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "stop sign"
+ }
+ }
+ },
+ {
+ "prompt_en": "a parking meter",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "parking meter"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bench",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bench"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bird",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bird"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cat",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "cat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a dog",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "dog"
+ }
+ }
+ },
+ {
+ "prompt_en": "a horse",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "horse"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sheep",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "sheep"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cow",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "cow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an elephant",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "elephant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bear",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bear"
+ }
+ }
+ },
+ {
+ "prompt_en": "a zebra",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "zebra"
+ }
+ }
+ },
+ {
+ "prompt_en": "a giraffe",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "giraffe"
+ }
+ }
+ },
+ {
+ "prompt_en": "a backpack",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "backpack"
+ }
+ }
+ },
+ {
+ "prompt_en": "an umbrella",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "umbrella"
+ }
+ }
+ },
+ {
+ "prompt_en": "a handbag",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "handbag"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tie",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "tie"
+ }
+ }
+ },
+ {
+ "prompt_en": "a suitcase",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "suitcase"
+ }
+ }
+ },
+ {
+ "prompt_en": "a frisbee",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "frisbee"
+ }
+ }
+ },
+ {
+ "prompt_en": "skis",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "skis"
+ }
+ }
+ },
+ {
+ "prompt_en": "a snowboard",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "snowboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sports ball",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "sports ball"
+ }
+ }
+ },
+ {
+ "prompt_en": "a kite",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "kite"
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball bat",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "baseball bat"
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball glove",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "baseball glove"
+ }
+ }
+ },
+ {
+ "prompt_en": "a skateboard",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "skateboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a surfboard",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "surfboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tennis racket",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "tennis racket"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bottle",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bottle"
+ }
+ }
+ },
+ {
+ "prompt_en": "a wine glass",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "wine glass"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cup",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "cup"
+ }
+ }
+ },
+ {
+ "prompt_en": "a fork",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "fork"
+ }
+ }
+ },
+ {
+ "prompt_en": "a knife",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "knife"
+ }
+ }
+ },
+ {
+ "prompt_en": "a spoon",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "spoon"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bowl",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bowl"
+ }
+ }
+ },
+ {
+ "prompt_en": "a banana",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "banana"
+ }
+ }
+ },
+ {
+ "prompt_en": "an apple",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "apple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sandwich",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "sandwich"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "broccoli",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "broccoli"
+ }
+ }
+ },
+ {
+ "prompt_en": "a carrot",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "carrot"
+ }
+ }
+ },
+ {
+ "prompt_en": "a hot dog",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "hot dog"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pizza",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "pizza"
+ }
+ }
+ },
+ {
+ "prompt_en": "a donut",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "donut"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cake",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "cake"
+ }
+ }
+ },
+ {
+ "prompt_en": "a chair",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "chair"
+ }
+ }
+ },
+ {
+ "prompt_en": "a couch",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "couch"
+ }
+ }
+ },
+ {
+ "prompt_en": "a potted plant",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "potted plant"
+ }
+ }
+ },
+ {
+ "prompt_en": "a bed",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "bed"
+ }
+ }
+ },
+ {
+ "prompt_en": "a dining table",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "dining table"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toilet",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "toilet"
+ }
+ }
+ },
+ {
+ "prompt_en": "a tv",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "tv"
+ }
+ }
+ },
+ {
+ "prompt_en": "a laptop",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "laptop"
+ }
+ }
+ },
+ {
+ "prompt_en": "a remote",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "remote"
+ }
+ }
+ },
+ {
+ "prompt_en": "a keyboard",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "keyboard"
+ }
+ }
+ },
+ {
+ "prompt_en": "a cell phone",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "cell phone"
+ }
+ }
+ },
+ {
+ "prompt_en": "a microwave",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "microwave"
+ }
+ }
+ },
+ {
+ "prompt_en": "an oven",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "oven"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toaster",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "toaster"
+ }
+ }
+ },
+ {
+ "prompt_en": "a sink",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "sink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a refrigerator",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "refrigerator"
+ }
+ }
+ },
+ {
+ "prompt_en": "a book",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "book"
+ }
+ }
+ },
+ {
+ "prompt_en": "a clock",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "clock"
+ }
+ }
+ },
+ {
+ "prompt_en": "a vase",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "vase"
+ }
+ }
+ },
+ {
+ "prompt_en": "scissors",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "scissors"
+ }
+ }
+ },
+ {
+ "prompt_en": "a teddy bear",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "teddy bear"
+ }
+ }
+ },
+ {
+ "prompt_en": "a hair drier",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "hair drier"
+ }
+ }
+ },
+ {
+ "prompt_en": "a toothbrush",
+ "dimension": [
+ "object_class"
+ ],
+ "auxiliary_info": {
+ "object_class": {
+ "object": "toothbrush"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white bicycle",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white car",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white bird",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black cat",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white cat",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange cat",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow cat",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white umbrella",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white suitcase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white bowl",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white chair",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white clock",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a red vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "red"
+ }
+ }
+ },
+ {
+ "prompt_en": "a green vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "green"
+ }
+ }
+ },
+ {
+ "prompt_en": "a blue vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "blue"
+ }
+ }
+ },
+ {
+ "prompt_en": "a yellow vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "yellow"
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "orange"
+ }
+ }
+ },
+ {
+ "prompt_en": "a purple vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "purple"
+ }
+ }
+ },
+ {
+ "prompt_en": "a pink vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "pink"
+ }
+ }
+ },
+ {
+ "prompt_en": "a black vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "black"
+ }
+ }
+ },
+ {
+ "prompt_en": "a white vase",
+ "dimension": [
+ "color"
+ ],
+ "auxiliary_info": {
+ "color": {
+ "color": "white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "The bund Shanghai, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "An astronaut flying in space, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, Van Gogh style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "Van Gogh style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, oil painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "oil painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "by Hokusai, in the style of Ukiyo"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, black and white",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "black and white"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pixel art",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "pixel art"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in cyberpunk style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "in cyberpunk style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, animated style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "animated style"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, watercolor painting",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "watercolor painting"
+ }
+ }
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, surrealism style",
+ "dimension": [
+ "appearance_style"
+ ],
+ "auxiliary_info": {
+ "appearance_style": {
+ "appearance_style": "surrealism style"
+ }
+ }
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in super slow motion",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom in",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom out",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan left",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan right",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt up",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt down",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, with an intense shaking effect",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, featuring a steady and smooth perspective",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, racking focus",
+ "dimension": [
+ "temporal_style"
+ ]
+ },
+ {
+ "prompt_en": "Close up of grapes on a rotating table.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Turtle swimming in ocean.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A storm trooper vacuuming the beach.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A panda standing on a surfboard in the ocean in sunset.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut feeding ducks on a sunny afternoon, reflection from the water.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Two pandas discussing an academic paper.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Sunset time lapse at the beach with moving clouds and colors in the sky.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A fat rabbit wearing a purple robe walking through a fantasy landscape.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A koala bear playing piano in the forest.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut flying in space.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Fireworks.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An animated painting of fluffy white clouds moving in sky.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Flying through fantasy landscapes.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A bigfoot walking in the snowstorm.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A squirrel eating a burger.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cat wearing sunglasses and working as a lifeguard at a pool.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Splash of turquoise water in extreme slow motion, alpha channel included.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "an ice cream is melting on the table.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "a drone flying over a snowy forest.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "a shark is swimming in the ocean.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Aerial panoramic video from a drone of a fantasy land.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "a teddy bear is swimming in the ocean.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "time lapse of sunrise on mars.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "golden fish swimming in the ocean.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An artist brush painting on a canvas close up.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A drone view of celebration with Christmas tree and fireworks, starry sky - background.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "happy dog wearing a yellow turtleneck, studio, portrait, facing camera, dark background",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Origami dancers in white paper, 3D render, on white background, studio shot, dancing modern dance.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Campfire at night in a snowy forest with starry sky in the background.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "a fantasy landscape",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A 3D model of a 1800s victorian house.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "this is how I do makeup in the morning.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A raccoon that looks like a turtle, digital art.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Robot dancing in Times Square.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Busy freeway at night.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Balloon full of water exploding in extreme slow motion.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An astronaut is riding a horse in the space in a photorealistic style.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Macro slo-mo. Slow motion cropped closeup of roasted coffee beans falling into an empty bowl.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Sewing machine, old sewing machine working.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Motion colour drop in water, ink swirling in water, colourful ink in water, abstraction fancy dream cloud of ink.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Few big purple plums rotating on the turntable. water drops appear on the skin during rotation. isolated on the white background. close-up. macro.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Vampire makeup face of beautiful girl, red contact lenses.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Ashtray full of butts on table, smoke flowing on black background, close-up",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Pacific coast, carmel by the sea ocean and waves.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A teddy bear is playing drum kit in NYC Times Square.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A corgi is playing drum kit.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An Iron man is playing the electronic guitar, high electronic guitar.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A raccoon is playing the electronic guitar.",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Vincent van Gogh",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A corgi's head depicted as an explosion of a nebula",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A fantasy landscape",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A future where humans have achieved teleportation technology",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A jellyfish floating through the ocean, with bioluminescent tentacles",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A Mars rover moving on Mars",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A panda drinking coffee in a cafe in Paris",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A space shuttle launching into orbit, with flames and smoke billowing out from the engines",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A steam train moving on a mountainside",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A super cool giant robot in Cyberpunk Beijing",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A tropical beach at sunrise, with palm trees and crystal-clear water in the foreground",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Cinematic shot of Van Gogh's selfie, Van Gogh style",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Gwen Stacy reading a book",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Iron Man flying in the sky",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, oil painting",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Yoda playing guitar on the stage",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Vincent van Gogh",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A car moving slowly on an empty street, rainy evening",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cat eating food out of a bowl",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cat wearing sunglasses at a pool",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A confused panda in calculus class",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cute fluffy panda eating Chinese food in a restaurant",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cute happy Corgi playing in park, sunset",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A cute raccoon playing guitar in a boat on the ocean",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A happy fuzzy panda playing guitar nearby a campfire, snow mountain in the background",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A lightning striking atop of eiffel tower, dark clouds in the sky",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A modern art museum, with colorful paintings",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A panda cooking in the kitchen",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A panda playing on a swing set",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A polar bear is playing guitar",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A raccoon dressed in suit playing the trumpet, stage background",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A shark swimming in clear Caribbean ocean",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A super robot protecting city",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "A teddy bear washing the dishes",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An epic tornado attacking above a glowing city at night, the tornado is made of smoke",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with umbrellas",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Clown fish swimming through the coral reef",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Hyper-realistic spaceship landing on Mars",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "The bund Shanghai, vibrant color",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Vincent van Gogh is painting in the room",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "Yellow flowers swing in the wind",
+ "dimension": [
+ "overall_consistency",
+ "aesthetic_quality",
+ "imaging_quality"
+ ]
+ },
+ {
+ "prompt_en": "alley",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "alley"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "amusement park",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "amusement park"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "aquarium",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "aquarium"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "arch",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "arch"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "art gallery",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "art gallery"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "bathroom",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "bathroom"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "bakery shop",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "bakery shop"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "ballroom",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "ballroom"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "bar",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "bar"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "barn",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "barn"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "basement",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "basement"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "beach",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "beach"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "bedroom",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "bedroom"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "bridge",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "bridge"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "botanical garden",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "botanical garden"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "cafeteria",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "cafeteria"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "campsite",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "campsite"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "campus",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "campus"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "carrousel",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "carrousel"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "castle",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "castle"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "cemetery",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "cemetery"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "classroom",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "classroom"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "cliff",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "cliff"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "crosswalk",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "crosswalk"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "construction site",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "construction site"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "corridor",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "corridor"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "courtyard",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "courtyard"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "desert",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "desert"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "downtown",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "downtown"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "driveway",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "driveway"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "farm",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "farm"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "food court",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "food court"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "football field",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "football field"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "forest road",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "forest road"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "fountain",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "fountain"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "gas station",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "gas station"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "glacier",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "glacier"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "golf course",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "golf course"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "indoor gymnasium",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "indoor gymnasium"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "harbor",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "harbor"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "highway",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "highway"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "hospital",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "hospital"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "house",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "house"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "iceberg",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "iceberg"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "industrial area",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "industrial area"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "jail cell",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "jail cell"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "junkyard",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "junkyard"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "kitchen",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "kitchen"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "indoor library",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "indoor library"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "lighthouse",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "lighthouse"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "laboratory",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "laboratory"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "mansion",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "mansion"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "marsh",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "marsh"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "mountain",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "mountain"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "indoor movie theater",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "indoor movie theater"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "indoor museum",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "indoor museum"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "music studio",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "music studio"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "nursery",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "nursery"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "ocean",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "ocean"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "office",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "office"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "palace",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "palace"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "parking lot",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "parking lot"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "pharmacy",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "pharmacy"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "phone booth",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "phone booth"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "raceway",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "raceway"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "restaurant",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "restaurant"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "river",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "river"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "science museum",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "science museum"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "shower",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "shower"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "ski slope",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "ski slope"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "sky",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "sky"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "skyscraper",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "skyscraper"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "baseball stadium",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "baseball stadium"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "staircase",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "staircase"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "street",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "street"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "supermarket",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "supermarket"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "indoor swimming pool",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "indoor swimming pool"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "tower",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "tower"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "outdoor track",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "outdoor track"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "train railway",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "train railway"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "train station platform",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "train station platform"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "underwater coral reef",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "underwater coral reef"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "valley",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "valley"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "volcano",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "volcano"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "waterfall",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "waterfall"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "windmill",
+ "dimension": [
+ "scene",
+ "background_consistency"
+ ],
+ "auxiliary_info": {
+ "scene": {
+ "scene": {
+ "scene": "windmill"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bicycle on the left of a car, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bicycle",
+ "object_b": "car",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a car on the right of a motorcycle, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "car",
+ "object_b": "motorcycle",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a motorcycle on the left of a bus, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "motorcycle",
+ "object_b": "bus",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bus on the right of a traffic light, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bus",
+ "object_b": "traffic light",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a traffic light on the left of a fire hydrant, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "traffic light",
+ "object_b": "fire hydrant",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a fire hydrant on the right of a stop sign, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "fire hydrant",
+ "object_b": "stop sign",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a stop sign on the left of a parking meter, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "stop sign",
+ "object_b": "parking meter",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a parking meter on the right of a bench, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "parking meter",
+ "object_b": "bench",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bench on the left of a truck, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bench",
+ "object_b": "truck",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a truck on the right of a bicycle, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "truck",
+ "object_b": "bicycle",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bird on the left of a cat, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bird",
+ "object_b": "cat",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a cat on the right of a dog, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "cat",
+ "object_b": "dog",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a dog on the left of a horse, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "dog",
+ "object_b": "horse",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a horse on the right of a sheep, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "horse",
+ "object_b": "sheep",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a sheep on the left of a cow, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "sheep",
+ "object_b": "cow",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a cow on the right of an elephant, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "cow",
+ "object_b": "elephant",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an elephant on the left of a bear, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "elephant",
+ "object_b": "bear",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bear on the right of a zebra, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bear",
+ "object_b": "zebra",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a zebra on the left of a giraffe, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "zebra",
+ "object_b": "giraffe",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a giraffe on the right of a bird, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "giraffe",
+ "object_b": "bird",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bottle on the left of a wine glass, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bottle",
+ "object_b": "wine glass",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a wine glass on the right of a cup, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "wine glass",
+ "object_b": "cup",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a cup on the left of a fork, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "cup",
+ "object_b": "fork",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a fork on the right of a knife, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "fork",
+ "object_b": "knife",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a knife on the left of a spoon, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "knife",
+ "object_b": "spoon",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a spoon on the right of a bowl, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "spoon",
+ "object_b": "bowl",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bowl on the left of a bottle, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bowl",
+ "object_b": "bottle",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a potted plant on the left of a remote, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "potted plant",
+ "object_b": "remote",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a remote on the right of a clock, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "remote",
+ "object_b": "clock",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a clock on the left of a vase, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "clock",
+ "object_b": "vase",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a vase on the right of scissors, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "vase",
+ "object_b": "scissors",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "scissors on the left of a teddy bear, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "scissors",
+ "object_b": "teddy bear",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a teddy bear on the right of a potted plant, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "teddy bear",
+ "object_b": "potted plant",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a frisbee on the left of a sports ball, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "frisbee",
+ "object_b": "sports ball",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a sports ball on the right of a baseball bat, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "sports ball",
+ "object_b": "baseball bat",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball bat on the left of a baseball glove, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "baseball bat",
+ "object_b": "baseball glove",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a baseball glove on the right of a tennis racket, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "baseball glove",
+ "object_b": "tennis racket",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a tennis racket on the left of a frisbee, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "tennis racket",
+ "object_b": "frisbee",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a toilet on the left of a hair drier, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "toilet",
+ "object_b": "hair drier",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a hair drier on the right of a toothbrush, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "hair drier",
+ "object_b": "toothbrush",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a toothbrush on the left of a sink, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "toothbrush",
+ "object_b": "sink",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a sink on the right of a toilet, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "sink",
+ "object_b": "toilet",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a chair on the left of a couch, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "chair",
+ "object_b": "couch",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a couch on the right of a bed, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "couch",
+ "object_b": "bed",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a bed on the left of a tv, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "bed",
+ "object_b": "tv",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a tv on the right of a dining table, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "tv",
+ "object_b": "dining table",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a dining table on the left of a chair, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "dining table",
+ "object_b": "chair",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an airplane on the left of a train, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "airplane",
+ "object_b": "train",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a train on the right of a boat, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "train",
+ "object_b": "boat",
+ "relationship": "on the right of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a boat on the left of an airplane, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "boat",
+ "object_b": "airplane",
+ "relationship": "on the left of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an oven on the top of a toaster, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "oven",
+ "object_b": "toaster",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an oven on the bottom of a toaster, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "oven",
+ "object_b": "toaster",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a toaster on the top of a microwave, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "toaster",
+ "object_b": "microwave",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a toaster on the bottom of a microwave, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "toaster",
+ "object_b": "microwave",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a microwave on the top of an oven, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "microwave",
+ "object_b": "oven",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a microwave on the bottom of an oven, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "microwave",
+ "object_b": "oven",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a banana on the top of an apple, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "banana",
+ "object_b": "apple",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a banana on the bottom of an apple, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "banana",
+ "object_b": "apple",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an apple on the top of a sandwich, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "apple",
+ "object_b": "sandwich",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an apple on the bottom of a sandwich, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "apple",
+ "object_b": "sandwich",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a sandwich on the top of an orange, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "sandwich",
+ "object_b": "orange",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a sandwich on the bottom of an orange, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "sandwich",
+ "object_b": "orange",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange on the top of a carrot, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "orange",
+ "object_b": "carrot",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "an orange on the bottom of a carrot, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "orange",
+ "object_b": "carrot",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a carrot on the top of a hot dog, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "carrot",
+ "object_b": "hot dog",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a carrot on the bottom of a hot dog, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "carrot",
+ "object_b": "hot dog",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a hot dog on the top of a pizza, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "hot dog",
+ "object_b": "pizza",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a hot dog on the bottom of a pizza, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "hot dog",
+ "object_b": "pizza",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a pizza on the top of a donut, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "pizza",
+ "object_b": "donut",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a pizza on the bottom of a donut, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "pizza",
+ "object_b": "donut",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a donut on the top of broccoli, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "donut",
+ "object_b": "broccoli",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a donut on the bottom of broccoli, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "donut",
+ "object_b": "broccoli",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "broccoli on the top of a banana, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "broccoli",
+ "object_b": "banana",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "broccoli on the bottom of a banana, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "broccoli",
+ "object_b": "banana",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "skis on the top of a snowboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "skis",
+ "object_b": "snowboard",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "skis on the bottom of a snowboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "skis",
+ "object_b": "snowboard",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a snowboard on the top of a kite, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "snowboard",
+ "object_b": "kite",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a snowboard on the bottom of a kite, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "snowboard",
+ "object_b": "kite",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a kite on the top of a skateboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "kite",
+ "object_b": "skateboard",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a kite on the bottom of a skateboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "kite",
+ "object_b": "skateboard",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a skateboard on the top of a surfboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "skateboard",
+ "object_b": "surfboard",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a skateboard on the bottom of a surfboard, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "skateboard",
+ "object_b": "surfboard",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a surfboard on the top of skis, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "surfboard",
+ "object_b": "skis",
+ "relationship": "on the top of"
+ }
+ }
+ }
+ },
+ {
+ "prompt_en": "a surfboard on the bottom of skis, front view",
+ "dimension": [
+ "spatial_relationship"
+ ],
+ "auxiliary_info": {
+ "spatial_relationship": {
+ "spatial_relationship": {
+ "object_a": "surfboard",
+ "object_b": "skis",
+ "relationship": "on the bottom of"
+ }
+ }
+ }
+ }
+]
diff --git a/eval/pab/vbench/cal_vbench.py b/eval/pab/vbench/cal_vbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec1cbbab64e9977983ae8c3349df1d8e0f03bdb0
--- /dev/null
+++ b/eval/pab/vbench/cal_vbench.py
@@ -0,0 +1,154 @@
+import argparse
+import json
+import os
+
+SEMANTIC_WEIGHT = 1
+QUALITY_WEIGHT = 4
+
+QUALITY_LIST = [
+ "subject consistency",
+ "background consistency",
+ "temporal flickering",
+ "motion smoothness",
+ "aesthetic quality",
+ "imaging quality",
+ "dynamic degree",
+]
+
+SEMANTIC_LIST = [
+ "object class",
+ "multiple objects",
+ "human action",
+ "color",
+ "spatial relationship",
+ "scene",
+ "appearance style",
+ "temporal style",
+ "overall consistency",
+]
+
+NORMALIZE_DIC = {
+ "subject consistency": {"Min": 0.1462, "Max": 1.0},
+ "background consistency": {"Min": 0.2615, "Max": 1.0},
+ "temporal flickering": {"Min": 0.6293, "Max": 1.0},
+ "motion smoothness": {"Min": 0.706, "Max": 0.9975},
+ "dynamic degree": {"Min": 0.0, "Max": 1.0},
+ "aesthetic quality": {"Min": 0.0, "Max": 1.0},
+ "imaging quality": {"Min": 0.0, "Max": 1.0},
+ "object class": {"Min": 0.0, "Max": 1.0},
+ "multiple objects": {"Min": 0.0, "Max": 1.0},
+ "human action": {"Min": 0.0, "Max": 1.0},
+ "color": {"Min": 0.0, "Max": 1.0},
+ "spatial relationship": {"Min": 0.0, "Max": 1.0},
+ "scene": {"Min": 0.0, "Max": 0.8222},
+ "appearance style": {"Min": 0.0009, "Max": 0.2855},
+ "temporal style": {"Min": 0.0, "Max": 0.364},
+ "overall consistency": {"Min": 0.0, "Max": 0.364},
+}
+
+DIM_WEIGHT = {
+ "subject consistency": 1,
+ "background consistency": 1,
+ "temporal flickering": 1,
+ "motion smoothness": 1,
+ "aesthetic quality": 1,
+ "imaging quality": 1,
+ "dynamic degree": 0.5,
+ "object class": 1,
+ "multiple objects": 1,
+ "human action": 1,
+ "color": 1,
+ "spatial relationship": 1,
+ "scene": 1,
+ "appearance style": 1,
+ "temporal style": 1,
+ "overall consistency": 1,
+}
+
+ordered_scaled_res = [
+ "total score",
+ "quality score",
+ "semantic score",
+ "subject consistency",
+ "background consistency",
+ "temporal flickering",
+ "motion smoothness",
+ "dynamic degree",
+ "aesthetic quality",
+ "imaging quality",
+ "object class",
+ "multiple objects",
+ "human action",
+ "color",
+ "spatial relationship",
+ "scene",
+ "appearance style",
+ "temporal style",
+ "overall consistency",
+]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--score_dir", required=True, type=str)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ res_postfix = "_eval_results.json"
+ info_postfix = "_full_info.json"
+ files = os.listdir(args.score_dir)
+ res_files = [x for x in files if res_postfix in x]
+ info_files = [x for x in files if info_postfix in x]
+ assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
+
+ full_results = {}
+ for res_file in res_files:
+ # first check if results is normal
+ info_file = res_file.split(res_postfix)[0] + info_postfix
+ with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
+ info = json.load(f)
+ assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
+ # read results
+ with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
+ data = json.load(f)
+ for key, val in data.items():
+ full_results[key] = format(val[0], ".4f")
+
+ scaled_results = {}
+ dims = set()
+ for key, val in full_results.items():
+ dim = key.replace("_", " ") if "_" in key else key
+ scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
+ NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
+ )
+ scaled_score *= DIM_WEIGHT[dim]
+ scaled_results[dim] = scaled_score
+ dims.add(dim)
+
+ assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
+
+ quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
+ semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
+ scaled_results["quality score"] = quality_score
+ scaled_results["semantic score"] = semantic_score
+ scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
+ QUALITY_WEIGHT + SEMANTIC_WEIGHT
+ )
+
+ formated_scaled_results = {"items": []}
+ for key in ordered_scaled_res:
+ formated_score = format(scaled_results[key] * 100, ".2f") + "%"
+ formated_scaled_results["items"].append({key: formated_score})
+
+ output_file_path = os.path.join(args.score_dir, "all_results.json")
+ with open(output_file_path, "w") as outfile:
+ json.dump(full_results, outfile, indent=4, sort_keys=True)
+ print(f"results saved to: {output_file_path}")
+
+ scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
+ with open(scaled_file_path, "w") as outfile:
+ json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
+ print(f"results saved to: {scaled_file_path}")
diff --git a/eval/pab/vbench/run_vbench.py b/eval/pab/vbench/run_vbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..32df0825502614fc3b2a1f7f56e3b2082ccb207c
--- /dev/null
+++ b/eval/pab/vbench/run_vbench.py
@@ -0,0 +1,52 @@
+import argparse
+
+import torch
+from vbench import VBench
+
+full_info_path = "./vbench/VBench_full_info.json"
+
+dimensions = [
+ "subject_consistency",
+ "imaging_quality",
+ "background_consistency",
+ "motion_smoothness",
+ "overall_consistency",
+ "human_action",
+ "multiple_objects",
+ "spatial_relationship",
+ "object_class",
+ "color",
+ "aesthetic_quality",
+ "appearance_style",
+ "temporal_flickering",
+ "scene",
+ "temporal_style",
+ "dynamic_degree",
+]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--video_path", required=True, type=str)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ save_path = args.video_path.replace("/samples/", "/vbench_out/")
+
+ kwargs = {}
+ kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
+
+ for dimension in dimensions:
+ my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
+ my_VBench.evaluate(
+ videos_path=args.video_path,
+ name=dimension,
+ local=False,
+ read_frame=False,
+ dimension_list=[dimension],
+ mode="vbench_standard",
+ **kwargs,
+ )
diff --git a/examples/cogvideo/sample.py b/examples/cogvideo/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9a394c2882eaf9debabdd3184d5a29e651e04cc
--- /dev/null
+++ b/examples/cogvideo/sample.py
@@ -0,0 +1,14 @@
+from videosys import CogVideoConfig, VideoSysEngine
+
+
+def run_base():
+ config = CogVideoConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+if __name__ == "__main__":
+ run_base()
diff --git a/examples/latte/sample.py b/examples/latte/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f421d831402611f457ec73fe8739162fbe113b
--- /dev/null
+++ b/examples/latte/sample.py
@@ -0,0 +1,24 @@
+from videosys import LatteConfig, VideoSysEngine
+
+
+def run_base():
+ config = LatteConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+def run_pab():
+ config = LatteConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+if __name__ == "__main__":
+ run_base()
+ # run_pab()
diff --git a/examples/open_sora/sample.py b/examples/open_sora/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..17a89f921d09ec5b71aaf98e210afc664aaa2385
--- /dev/null
+++ b/examples/open_sora/sample.py
@@ -0,0 +1,24 @@
+from videosys import OpenSoraConfig, VideoSysEngine
+
+
+def run_base():
+ config = OpenSoraConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+def run_pab():
+ config = OpenSoraConfig(world_size=1, enable_pab=True)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+if __name__ == "__main__":
+ run_base()
+ run_pab()
diff --git a/examples/open_sora_plan/sample.py b/examples/open_sora_plan/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3f3e9681a08906b28329db6a46c98c0a9ce2684
--- /dev/null
+++ b/examples/open_sora_plan/sample.py
@@ -0,0 +1,24 @@
+from videosys import OpenSoraPlanConfig, VideoSysEngine
+
+
+def run_base():
+ config = OpenSoraPlanConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+def run_pab():
+ config = OpenSoraPlanConfig(world_size=1)
+ engine = VideoSysEngine(config)
+
+ prompt = "Sunset over the sea."
+ video = engine.generate(prompt).video[0]
+ engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+if __name__ == "__main__":
+ run_base()
+ # run_pab()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8d1f44e862c19e004a135972a2c356cfdb36681d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,25 @@
+gradio
+click
+colossalai
+contexttimer
+diffusers==0.30.0
+einops
+fabric
+ftfy
+imageio
+imageio-ffmpeg
+matplotlib
+ninja
+numpy
+omegaconf
+packaging
+psutil
+pydantic
+ray
+rich
+safetensors
+timm
+torch>=1.13
+tqdm
+transformers
+openai
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..50b95376560e467ab5976f0dbedae64e0590b152
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,55 @@
+from typing import List
+
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path) -> List[str]:
+ """
+ This function reads the requirements file.
+
+ Args:
+ path (str): the path to the requirements file.
+
+ Returns:
+ The lines in the requirements file.
+ """
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme() -> str:
+ """
+ This function reads the README.md file in the current directory.
+
+ Returns:
+ The lines in the README file.
+ """
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+setup(
+ name="videosys",
+ version="2.0.0",
+ packages=find_packages(
+ exclude=(
+ "videos",
+ "tests",
+ "figure",
+ "*.egg-info",
+ )
+ ),
+ description="VideoSys",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
+ ],
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/__init__.py b/videosys/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd86b4acbb1c9e577d3d6b9298b2d9824695e3c
--- /dev/null
+++ b/videosys/__init__.py
@@ -0,0 +1,19 @@
+from .core.engine import VideoSysEngine
+from .core.parallel_mgr import initialize
+from .models.cogvideo.pipeline import CogVideoConfig, CogVideoPipeline
+from .models.latte.pipeline import LatteConfig, LattePipeline
+from .models.open_sora.pipeline import OpenSoraConfig, OpenSoraPipeline
+from .models.open_sora_plan.pipeline import OpenSoraPlanConfig, OpenSoraPlanPipeline
+
+__all__ = [
+ "initialize",
+ "VideoSysEngine",
+ "LattePipeline",
+ "LatteConfig",
+ "OpenSoraPlanPipeline",
+ "OpenSoraPlanConfig",
+ "OpenSoraPipeline",
+ "OpenSoraConfig",
+ "CogVideoConfig",
+ "CogVideoPipeline",
+]
diff --git a/videosys/core/__init__.py b/videosys/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/core/comm.py b/videosys/core/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..175fba59882544a7a4a6365dd855970b39fa042d
--- /dev/null
+++ b/videosys/core/comm.py
@@ -0,0 +1,420 @@
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from videosys.core.parallel_mgr import get_sequence_parallel_size
+
+# ======================================================
+# Model
+# ======================================================
+
+
+def model_sharding(model: torch.nn.Module):
+ global_rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ for _, param in model.named_parameters():
+ padding_size = (world_size - param.numel() % world_size) % world_size
+ if padding_size > 0:
+ padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+ else:
+ padding_param = param.data.view(-1)
+ splited_params = padding_param.split(padding_param.numel() // world_size)
+ splited_params = splited_params[global_rank]
+ param.data = splited_params
+
+
+# ======================================================
+# AllGather & ReduceScatter
+# ======================================================
+
+
+class AsyncAllGatherForTwo(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ sp_rank: int,
+ sp_size: int,
+ group: Optional[ProcessGroup] = None,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ from torch.distributed._functional_collectives import all_gather_tensor
+
+ ctx.group = group
+ ctx.sp_rank = sp_rank
+ ctx.sp_size = sp_size
+
+ # all gather inputs
+ all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
+ # compute local qkv
+ local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
+
+ # remote compute
+ remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
+ # compute remote qkv
+ remote_qkv = F.linear(remote_inputs, weight, bias)
+
+ # concat local and remote qkv
+ if sp_rank == 0:
+ qkv = torch.cat([local_qkv, remote_qkv], dim=0)
+ else:
+ qkv = torch.cat([remote_qkv, local_qkv], dim=0)
+ qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
+
+ ctx.save_for_backward(inputs, weight, remote_inputs)
+ return qkv
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+ from torch.distributed._functional_collectives import reduce_scatter_tensor
+
+ group = ctx.group
+ sp_rank = ctx.sp_rank
+ sp_size = ctx.sp_size
+ inputs, weight, remote_inputs = ctx.saved_tensors
+
+ # split qkv_grad
+ qkv_grad = grad_outputs[0]
+ qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
+ qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
+ if sp_rank == 0:
+ local_qkv_grad, remote_qkv_grad = qkv_grad
+ else:
+ remote_qkv_grad, local_qkv_grad = qkv_grad
+
+ # compute remote grad
+ remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
+ weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
+ bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
+
+ # launch async reduce scatter
+ remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
+ if sp_rank == 0:
+ remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
+ else:
+ remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
+ remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
+
+ # compute local grad and wait for reduce scatter
+ local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
+ weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
+ bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
+
+ # sum remote and local grad
+ inputs_grad = remote_inputs_grad + local_input_grad
+ return inputs_grad, weight_grad, bias_grad, None, None, None
+
+
+class AllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ group: Optional[ProcessGroup] = None,
+ overlap: bool = False,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ assert ctx is not None or not overlap
+
+ if ctx is not None:
+ ctx.comm_grp = group
+
+ comm_size = dist.get_world_size(group)
+ if comm_size == 1:
+ return inputs.unsqueeze(0), None
+
+ buffer_shape = (comm_size,) + inputs.shape
+ outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
+ buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
+ if not overlap:
+ dist.all_gather(buffer_list, inputs, group=group)
+ return outputs, None
+ else:
+ handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
+ return outputs, handle
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+ return (
+ ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
+ None,
+ None,
+ )
+
+
+class ReduceScatter(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ group: ProcessGroup,
+ overlap: bool = False,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ assert ctx is not None or not overlap
+
+ if ctx is not None:
+ ctx.comm_grp = group
+
+ comm_size = dist.get_world_size(group)
+ if comm_size == 1:
+ return inputs.squeeze(0), None
+
+ if not inputs.is_contiguous():
+ inputs = inputs.contiguous()
+
+ output_shape = inputs.shape[1:]
+ outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
+ buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
+ if not overlap:
+ dist.reduce_scatter(outputs, buffer_list, group=group)
+ return outputs, None
+ else:
+ handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
+ return outputs, handle
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+ # TODO: support async backward
+ return (
+ AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
+ None,
+ None,
+ )
+
+
+# ======================================================
+# AlltoAll
+# ======================================================
+
+
+def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+ dist.all_to_all(output_list, input_list, group=group)
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+class _AllToAll(torch.autograd.Function):
+ """All-to-all communication.
+
+ Args:
+ input_: input matrix
+ process_group: communication group
+ scatter_dim: scatter dimension
+ gather_dim: gather dimension
+ """
+
+ @staticmethod
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+ ctx.process_group = process_group
+ ctx.scatter_dim = scatter_dim
+ ctx.gather_dim = gather_dim
+ world_size = dist.get_world_size(process_group)
+
+ return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ process_group = ctx.process_group
+ scatter_dim = ctx.gather_dim
+ gather_dim = ctx.scatter_dim
+ return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+ return (return_grad, None, None, None)
+
+
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+ return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+
+# ======================================================
+# Sequence Gather & Split
+# ======================================================
+
+
+def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
+ # skip if only one rank involved
+ world_size = dist.get_world_size(pg)
+ rank = dist.get_rank(pg)
+ if world_size == 1:
+ return input_
+
+ if pad > 0:
+ pad_size = list(input_.shape)
+ pad_size[dim] = pad
+ input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
+
+ dim_size = input_.size(dim)
+ assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
+
+ tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
+ output = tensor_list[rank].contiguous()
+ return output
+
+
+def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
+ # skip if only one rank involved
+ input_ = input_.contiguous()
+ world_size = dist.get_world_size(pg)
+ dist.get_rank(pg)
+
+ if world_size == 1:
+ return input_
+
+ # all gather
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ assert input_.device.type == "cuda"
+ torch.distributed.all_gather(tensor_list, input_, group=pg)
+
+ # concat
+ output = torch.cat(tensor_list, dim=dim)
+
+ if pad > 0:
+ output = output.narrow(dim, 0, output.size(dim) - pad)
+
+ return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """
+ Gather the input sequence.
+
+ Args:
+ input_: input matrix.
+ process_group: process group.
+ dim: dimension
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _gather_sequence_func(input_)
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim, grad_scale, pad):
+ ctx.process_group = process_group
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+ ctx.pad = pad
+ return _gather_sequence_func(input_, process_group, dim, pad)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.grad_scale == "up":
+ grad_output = grad_output * dist.get_world_size(ctx.process_group)
+ elif ctx.grad_scale == "down":
+ grad_output = grad_output / dist.get_world_size(ctx.process_group)
+
+ return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
+
+
+class _SplitForwardGatherBackward(torch.autograd.Function):
+ """
+ Split sequence.
+
+ Args:
+ input_: input matrix.
+ process_group: parallel mode.
+ dim: dimension
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _split_sequence_func(input_)
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim, grad_scale, pad):
+ ctx.process_group = process_group
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+ ctx.pad = pad
+ return _split_sequence_func(input_, process_group, dim, pad)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.grad_scale == "up":
+ grad_output = grad_output * dist.get_world_size(ctx.process_group)
+ elif ctx.grad_scale == "down":
+ grad_output = grad_output / dist.get_world_size(ctx.process_group)
+ return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
+
+
+def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
+ return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
+
+
+def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
+ return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
+
+
+# ==============================
+# Pad
+# ==============================
+
+SPTIAL_PAD = 0
+TEMPORAL_PAD = 0
+
+
+def set_spatial_pad(dim_size: int):
+ sp_size = get_sequence_parallel_size()
+ pad = (sp_size - (dim_size % sp_size)) % sp_size
+ global SPTIAL_PAD
+ SPTIAL_PAD = pad
+
+
+def get_spatial_pad() -> int:
+ return SPTIAL_PAD
+
+
+def set_temporal_pad(dim_size: int):
+ sp_size = get_sequence_parallel_size()
+ pad = (sp_size - (dim_size % sp_size)) % sp_size
+ global TEMPORAL_PAD
+ TEMPORAL_PAD = pad
+
+
+def get_temporal_pad() -> int:
+ return TEMPORAL_PAD
+
+
+def all_to_all_with_pad(
+ input_: torch.Tensor,
+ process_group: dist.ProcessGroup,
+ scatter_dim: int = 2,
+ gather_dim: int = 1,
+ scatter_pad: int = 0,
+ gather_pad: int = 0,
+):
+ if scatter_pad > 0:
+ pad_shape = list(input_.shape)
+ pad_shape[scatter_dim] = scatter_pad
+ pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
+ input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
+
+ assert (
+ input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
+ ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
+ input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+ if gather_pad > 0:
+ input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
+
+ return input_
diff --git a/videosys/core/engine.py b/videosys/core/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0976159e51a9b74330e2d7b0879d54efaa6ece
--- /dev/null
+++ b/videosys/core/engine.py
@@ -0,0 +1,132 @@
+import os
+from functools import partial
+from typing import Any, Optional
+
+import imageio
+import torch
+
+import videosys
+
+from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
+
+
+class VideoSysEngine:
+ """
+ this is partly inspired by vllm
+ """
+
+ def __init__(self, config):
+ self.config = config
+ self.parallel_worker_tasks = None
+ self._init_worker(config.pipeline_cls)
+
+ def _init_worker(self, pipeline_cls):
+ world_size = self.config.world_size
+
+ if "CUDA_VISIBLE_DEVICES" not in os.environ:
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))
+
+ # Disable torch async compiling which won't work with daemonic processes
+ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
+
+ # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
+ # contention amongst the shards
+ if "OMP_NUM_THREADS" not in os.environ:
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ # NOTE: The two following lines need adaption for multi-node
+ assert world_size <= torch.cuda.device_count()
+
+ # change addr for multi-node
+ distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
+
+ if world_size == 1:
+ self.workers = []
+ self.worker_monitor = None
+ else:
+ result_handler = ResultHandler()
+ self.workers = [
+ ProcessWorkerWrapper(
+ result_handler,
+ partial(
+ self._create_pipeline,
+ pipeline_cls=pipeline_cls,
+ rank=rank,
+ local_rank=rank,
+ distributed_init_method=distributed_init_method,
+ ),
+ )
+ for rank in range(1, world_size)
+ ]
+
+ self.worker_monitor = WorkerMonitor(self.workers, result_handler)
+ result_handler.start()
+ self.worker_monitor.start()
+
+ self.driver_worker = self._create_pipeline(
+ pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
+ )
+
+ # TODO: add more options here for pipeline, or wrap all options into config
+ def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
+ videosys.initialize(rank=rank, world_size=self.config.world_size, init_method=distributed_init_method, seed=42)
+
+ pipeline = pipeline_cls(self.config)
+ return pipeline
+
+ def _run_workers(
+ self,
+ method: str,
+ *args,
+ async_run_tensor_parallel_workers_only: bool = False,
+ max_concurrent_workers: Optional[int] = None,
+ **kwargs,
+ ) -> Any:
+ """Runs the given method on all workers."""
+
+ # Start the workers first.
+ worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
+
+ if async_run_tensor_parallel_workers_only:
+ # Just return futures
+ return worker_outputs
+
+ driver_worker_method = getattr(self.driver_worker, method)
+ driver_worker_output = driver_worker_method(*args, **kwargs)
+
+ # Get the results of the workers.
+ return [driver_worker_output] + [output.get() for output in worker_outputs]
+
+ def _driver_execute_model(self, *args, **kwargs):
+ return self.driver_worker.generate(*args, **kwargs)
+
+ def generate(self, *args, **kwargs):
+ return self._run_workers("generate", *args, **kwargs)[0]
+
+ def stop_remote_worker_execution_loop(self) -> None:
+ if self.parallel_worker_tasks is None:
+ return
+
+ parallel_worker_tasks = self.parallel_worker_tasks
+ self.parallel_worker_tasks = None
+ # Ensure that workers exit model loop cleanly
+ # (this will raise otherwise)
+ self._wait_for_tasks_completion(parallel_worker_tasks)
+
+ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+ """Wait for futures returned from _run_workers() with
+ async_run_remote_workers_only to complete."""
+ for result in parallel_worker_tasks:
+ result.get()
+
+ def save_video(self, video, output_path):
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ imageio.mimwrite(output_path, video, fps=24)
+
+ def shutdown(self):
+ if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
+ worker_monitor.close()
+ torch.distributed.destroy_process_group()
+
+ def __del__(self):
+ self.shutdown()
\ No newline at end of file
diff --git a/videosys/core/mp_utils.py b/videosys/core/mp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03beb0b0c9b1875ca7dc2b91b12c37c9a886c8f
--- /dev/null
+++ b/videosys/core/mp_utils.py
@@ -0,0 +1,270 @@
+# adapted from vllm
+# https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
+
+import asyncio
+import multiprocessing
+import os
+import socket
+import sys
+import threading
+import traceback
+import uuid
+from dataclasses import dataclass
+from multiprocessing import Queue
+from multiprocessing.connection import wait
+from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
+
+from videosys.utils.logging import create_logger
+
+T = TypeVar("T")
+_TERMINATE = "TERMINATE" # sentinel
+# ANSI color codes
+CYAN = "\033[1;36m"
+RESET = "\033[0;0m"
+JOIN_TIMEOUT_S = 2
+
+mp_method = "spawn" # fork cann't work
+mp = multiprocessing.get_context(mp_method)
+
+logger = create_logger()
+
+
+def get_distributed_init_method(ip: str, port: int) -> str:
+ # Brackets are not permitted in ipv4 addresses,
+ # see https://github.com/python/cpython/issues/103848
+ return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
+
+
+def get_open_port() -> int:
+ # try ipv4
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+ except OSError:
+ # try ipv6
+ with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+@dataclass
+class Result(Generic[T]):
+ """Result of task dispatched to worker"""
+
+ task_id: uuid.UUID
+ value: Optional[T] = None
+ exception: Optional[BaseException] = None
+
+
+class ResultFuture(threading.Event, Generic[T]):
+ """Synchronous future for non-async case"""
+
+ def __init__(self):
+ super().__init__()
+ self.result: Optional[Result[T]] = None
+
+ def set_result(self, result: Result[T]):
+ self.result = result
+ self.set()
+
+ def get(self) -> T:
+ self.wait()
+ assert self.result is not None
+ if self.result.exception is not None:
+ raise self.result.exception
+ return self.result.value # type: ignore[return-value]
+
+
+def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
+ if isinstance(future, ResultFuture):
+ future.set_result(result)
+ return
+ loop = future.get_loop()
+ if not loop.is_closed():
+ if result.exception is not None:
+ loop.call_soon_threadsafe(future.set_exception, result.exception)
+ else:
+ loop.call_soon_threadsafe(future.set_result, result.value)
+
+
+class ResultHandler(threading.Thread):
+ """Handle results from all workers (in background thread)"""
+
+ def __init__(self) -> None:
+ super().__init__(daemon=True)
+ self.result_queue = mp.Queue()
+ self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
+
+ def run(self):
+ for result in iter(self.result_queue.get, _TERMINATE):
+ future = self.tasks.pop(result.task_id)
+ _set_future_result(future, result)
+ # Ensure that all waiters will receive an exception
+ for task_id, future in self.tasks.items():
+ _set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
+
+ def close(self):
+ self.result_queue.put(_TERMINATE)
+
+
+class WorkerMonitor(threading.Thread):
+ """Monitor worker status (in background thread)"""
+
+ def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
+ super().__init__(daemon=True)
+ self.workers = workers
+ self.result_handler = result_handler
+ self._close = False
+
+ def run(self) -> None:
+ # Blocks until any worker exits
+ dead_sentinels = wait([w.process.sentinel for w in self.workers])
+ if not self._close:
+ self._close = True
+
+ # Kill / cleanup all workers
+ for worker in self.workers:
+ process = worker.process
+ if process.sentinel in dead_sentinels:
+ process.join(JOIN_TIMEOUT_S)
+ if process.exitcode is not None and process.exitcode != 0:
+ logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
+ # Cleanup any remaining workers
+ logger.info("Killing local worker processes")
+ for worker in self.workers:
+ worker.kill_worker()
+ # Must be done after worker task queues are all closed
+ self.result_handler.close()
+
+ for worker in self.workers:
+ worker.process.join(JOIN_TIMEOUT_S)
+
+ def close(self):
+ if self._close:
+ return
+ self._close = True
+ logger.info("Terminating local worker processes")
+ for worker in self.workers:
+ worker.terminate_worker()
+ # Must be done after worker task queues are all closed
+ self.result_handler.close()
+
+
+def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
+ """Prepend each output line with process-specific prefix"""
+
+ prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
+ file_write = file.write
+
+ def write_with_prefix(s: str):
+ if not s:
+ return
+ if file.start_new_line: # type: ignore[attr-defined]
+ file_write(prefix)
+ idx = 0
+ while (next_idx := s.find("\n", idx)) != -1:
+ next_idx += 1
+ file_write(s[idx:next_idx])
+ if next_idx == len(s):
+ file.start_new_line = True # type: ignore[attr-defined]
+ return
+ file_write(prefix)
+ idx = next_idx
+ file_write(s[idx:])
+ file.start_new_line = False # type: ignore[attr-defined]
+
+ file.start_new_line = True # type: ignore[attr-defined]
+ file.write = write_with_prefix # type: ignore[method-assign]
+
+
+def _run_worker_process(
+ worker_factory: Callable[[], Any],
+ task_queue: Queue,
+ result_queue: Queue,
+) -> None:
+ """Worker process event loop"""
+
+ # Add process-specific prefix to stdout and stderr
+ process_name = mp.current_process().name
+ pid = os.getpid()
+ _add_prefix(sys.stdout, process_name, pid)
+ _add_prefix(sys.stderr, process_name, pid)
+
+ # Initialize worker
+ worker = worker_factory()
+ del worker_factory
+
+ # Accept tasks from the engine in task_queue
+ # and return task output in result_queue
+ logger.info("Worker ready; awaiting tasks")
+ try:
+ for items in iter(task_queue.get, _TERMINATE):
+ output = None
+ exception = None
+ task_id, method, args, kwargs = items
+ try:
+ executor = getattr(worker, method)
+ output = executor(*args, **kwargs)
+ except BaseException as e:
+ tb = traceback.format_exc()
+ logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
+ exception = e
+ result_queue.put(Result(task_id=task_id, value=output, exception=exception))
+ except KeyboardInterrupt:
+ pass
+ except Exception:
+ logger.exception("Worker failed")
+
+ logger.info("Worker exiting")
+
+
+class ProcessWorkerWrapper:
+ """Local process wrapper for handling single-node multi-GPU."""
+
+ def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
+ self._task_queue = mp.Queue()
+ self.result_queue = result_handler.result_queue
+ self.tasks = result_handler.tasks
+ self.process = mp.Process( # type: ignore[attr-defined]
+ target=_run_worker_process,
+ name="VideoSysWorkerProcess",
+ kwargs=dict(
+ worker_factory=worker_factory,
+ task_queue=self._task_queue,
+ result_queue=self.result_queue,
+ ),
+ daemon=True,
+ )
+
+ self.process.start()
+
+ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
+ task_id = uuid.uuid4()
+ self.tasks[task_id] = future
+ try:
+ self._task_queue.put((task_id, method, args, kwargs))
+ except BaseException as e:
+ del self.tasks[task_id]
+ raise ChildProcessError("worker died") from e
+
+ def execute_method(self, method: str, *args, **kwargs):
+ future: ResultFuture = ResultFuture()
+ self._enqueue_task(future, method, args, kwargs)
+ return future
+
+ async def execute_method_async(self, method: str, *args, **kwargs):
+ future = asyncio.get_running_loop().create_future()
+ self._enqueue_task(future, method, args, kwargs)
+ return await future
+
+ def terminate_worker(self):
+ try:
+ self._task_queue.put(_TERMINATE)
+ except ValueError:
+ self.process.kill()
+ self._task_queue.close()
+
+ def kill_worker(self):
+ self._task_queue.close()
+ self.process.kill()
diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ce857c85efe6769795d4044bc4bad6c42f61c7
--- /dev/null
+++ b/videosys/core/pab_mgr.py
@@ -0,0 +1,364 @@
+import random
+
+import numpy as np
+import torch
+
+from videosys.utils.logging import logger
+
+PAB_MANAGER = None
+
+
+class PABConfig:
+ def __init__(
+ self,
+ steps: int,
+ cross_broadcast: bool,
+ cross_threshold: list,
+ cross_gap: int,
+ spatial_broadcast: bool,
+ spatial_threshold: list,
+ spatial_gap: int,
+ temporal_broadcast: bool,
+ temporal_threshold: list,
+ temporal_gap: int,
+ diffusion_skip: bool,
+ diffusion_timestep_respacing: list,
+ diffusion_skip_timestep: list,
+ mlp_skip: bool,
+ mlp_spatial_skip_config: dict,
+ mlp_temporal_skip_config: dict,
+ full_broadcast: bool = False,
+ full_threshold: list = None,
+ full_gap: int = 1,
+ ):
+ self.steps = steps
+
+ self.cross_broadcast = cross_broadcast
+ self.cross_threshold = cross_threshold
+ self.cross_gap = cross_gap
+
+ self.spatial_broadcast = spatial_broadcast
+ self.spatial_threshold = spatial_threshold
+ self.spatial_gap = spatial_gap
+
+ self.temporal_broadcast = temporal_broadcast
+ self.temporal_threshold = temporal_threshold
+ self.temporal_gap = temporal_gap
+
+ self.diffusion_skip = diffusion_skip
+ self.diffusion_timestep_respacing = diffusion_timestep_respacing
+ self.diffusion_skip_timestep = diffusion_skip_timestep
+
+ self.mlp_skip = mlp_skip
+ self.mlp_spatial_skip_config = mlp_spatial_skip_config
+ self.mlp_temporal_skip_config = mlp_temporal_skip_config
+
+ self.temporal_mlp_outputs = {}
+ self.spatial_mlp_outputs = {}
+
+ self.full_broadcast = full_broadcast
+ self.full_threshold = full_threshold
+ self.full_gap = full_gap
+
+
+class PABManager:
+ def __init__(self, config: PABConfig):
+ self.config: PABConfig = config
+
+ init_prompt = f"Init PABManager. steps: {config.steps}."
+ init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
+ init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
+ init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
+ init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
+ logger.info(init_prompt)
+
+ def if_broadcast_cross(self, timestep: int, count: int):
+ if (
+ self.config.cross_broadcast
+ and (timestep is not None)
+ and (count % self.config.cross_gap != 0)
+ and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
+ ):
+ flag = True
+ else:
+ flag = False
+ count = (count + 1) % self.config.steps
+ return flag, count
+
+ def if_broadcast_temporal(self, timestep: int, count: int):
+ if (
+ self.config.temporal_broadcast
+ and (timestep is not None)
+ and (count % self.config.temporal_gap != 0)
+ and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
+ ):
+ flag = True
+ else:
+ flag = False
+ count = (count + 1) % self.config.steps
+ return flag, count
+
+ def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
+ if (
+ self.config.spatial_broadcast
+ and (timestep is not None)
+ and (count % self.config.spatial_gap != 0)
+ and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
+ ):
+ flag = True
+ else:
+ flag = False
+ count = (count + 1) % self.config.steps
+ return flag, count
+
+ def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
+ if (
+ self.config.full_broadcast
+ and (timestep is not None)
+ and (count % self.config.full_gap != 0)
+ and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
+ ):
+ flag = True
+ else:
+ flag = False
+ count = (count + 1) % self.config.steps
+ return flag, count
+
+ @staticmethod
+ def _is_t_in_skip_config(all_timesteps, timestep, config):
+ is_t_in_skip_config = False
+ for key in config:
+ if key not in all_timesteps:
+ continue
+ index = all_timesteps.index(key)
+ skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
+ if timestep in skip_range:
+ is_t_in_skip_config = True
+ skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
+ break
+ return is_t_in_skip_config, skip_range
+
+ def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
+ if not self.config.mlp_skip:
+ return False, None, False, None
+
+ if is_temporal:
+ cur_config = self.config.mlp_temporal_skip_config
+ else:
+ cur_config = self.config.mlp_spatial_skip_config
+
+ is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
+ next_flag = False
+ if (
+ self.config.mlp_skip
+ and (timestep is not None)
+ and (timestep in cur_config)
+ and (block_idx in cur_config[timestep]["block"])
+ ):
+ flag = False
+ next_flag = True
+ count = count + 1
+ elif (
+ self.config.mlp_skip
+ and (timestep is not None)
+ and (is_t_in_skip_config)
+ and (block_idx in cur_config[skip_range[0]]["block"])
+ ):
+ flag = True
+ count = 0
+ else:
+ flag = False
+
+ return flag, count, next_flag, skip_range
+
+ def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
+ if is_temporal:
+ self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
+ else:
+ self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
+
+ def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
+ skip_start_t = skip_range[0]
+ if is_temporal:
+ skip_output = (
+ self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
+ if self.config.temporal_mlp_outputs is not None
+ else None
+ )
+ else:
+ skip_output = (
+ self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
+ if self.config.spatial_mlp_outputs is not None
+ else None
+ )
+
+ if skip_output is not None:
+ if timestep == skip_range[-1]:
+ # TODO: save memory
+ if is_temporal:
+ del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
+ else:
+ del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
+ else:
+ raise ValueError(
+ f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
+ )
+
+ return skip_output
+
+ def get_spatial_mlp_outputs(self):
+ return self.config.spatial_mlp_outputs
+
+ def get_temporal_mlp_outputs(self):
+ return self.config.temporal_mlp_outputs
+
+
+def set_pab_manager(config: PABConfig):
+ global PAB_MANAGER
+ PAB_MANAGER = PABManager(config)
+
+
+def enable_pab():
+ if PAB_MANAGER is None:
+ return False
+ return (
+ PAB_MANAGER.config.cross_broadcast
+ or PAB_MANAGER.config.spatial_broadcast
+ or PAB_MANAGER.config.temporal_broadcast
+ )
+
+
+def update_steps(steps: int):
+ if PAB_MANAGER is not None:
+ PAB_MANAGER.config.steps = steps
+
+
+def if_broadcast_cross(timestep: int, count: int):
+ if not enable_pab():
+ return False, count
+ return PAB_MANAGER.if_broadcast_cross(timestep, count)
+
+
+def if_broadcast_temporal(timestep: int, count: int):
+ if not enable_pab():
+ return False, count
+ return PAB_MANAGER.if_broadcast_temporal(timestep, count)
+
+
+def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
+ if not enable_pab():
+ return False, count
+ return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
+
+def if_broadcast_full(timestep: int, count: int, block_idx: int):
+ if not enable_pab():
+ return False, count
+ return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
+
+
+def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
+ if not enable_pab():
+ return False, count
+ return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
+
+
+def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
+ return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
+
+
+def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
+ return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
+
+
+def get_diffusion_skip():
+ return enable_pab() and PAB_MANAGER.config.diffusion_skip
+
+
+def get_diffusion_timestep_respacing():
+ return PAB_MANAGER.config.diffusion_timestep_respacing
+
+
+def get_diffusion_skip_timestep():
+ return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
+
+
+def space_timesteps(time_steps, time_bins):
+ num_bins = len(time_bins)
+ bin_size = time_steps // num_bins
+
+ result = []
+
+ for i, bin_count in enumerate(time_bins):
+ start = i * bin_size
+ end = start + bin_size
+
+ bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
+ result.extend(bin_steps)
+
+ result_tensor = torch.tensor(result, dtype=torch.int32)
+ sorted_tensor = torch.sort(result_tensor, descending=True).values
+
+ return sorted_tensor
+
+
+def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
+ if isinstance(timesteps, list):
+ # If timesteps is a list, we assume each element is a tensor
+ timesteps_np = [t.cpu().numpy() for t in timesteps]
+ device = timesteps[0].device
+ else:
+ # If timesteps is a tensor
+ timesteps_np = timesteps.cpu().numpy()
+ device = timesteps.device
+
+ num_bins = len(diffusion_skip_timestep)
+
+ if isinstance(timesteps_np, list):
+ bin_size = len(timesteps_np) // num_bins
+ new_timesteps = []
+
+ for i in range(num_bins):
+ bin_start = i * bin_size
+ bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
+ bin_timesteps = timesteps_np[bin_start:bin_end]
+
+ if diffusion_skip_timestep[i] == 0:
+ # If the bin is marked with 0, keep all timesteps
+ new_timesteps.extend(bin_timesteps)
+ elif diffusion_skip_timestep[i] == 1:
+ # If the bin is marked with 1, omit the last timestep in the bin
+ new_timesteps.extend(bin_timesteps[1:])
+
+ new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
+ else:
+ bin_size = len(timesteps_np) // num_bins
+ new_timesteps = []
+
+ for i in range(num_bins):
+ bin_start = i * bin_size
+ bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
+ bin_timesteps = timesteps_np[bin_start:bin_end]
+
+ if diffusion_skip_timestep[i] == 0:
+ # If the bin is marked with 0, keep all timesteps
+ new_timesteps.extend(bin_timesteps)
+ elif diffusion_skip_timestep[i] == 1:
+ # If the bin is marked with 1, omit the last timestep in the bin
+ new_timesteps.extend(bin_timesteps[1:])
+ elif diffusion_skip_timestep[i] != 0:
+ # If the bin is marked with a non-zero value, randomly omit n timesteps
+ if len(bin_timesteps) > diffusion_skip_timestep[i]:
+ indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
+ timesteps_to_keep = [
+ timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
+ ]
+ else:
+ timesteps_to_keep = bin_timesteps # 如果bin_timesteps的长度小于等于n,则不删除任何元素
+ new_timesteps.extend(timesteps_to_keep)
+
+ new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
+
+ if isinstance(timesteps, list):
+ return new_timesteps_tensor
+ else:
+ return new_timesteps_tensor
diff --git a/videosys/core/parallel_mgr.py b/videosys/core/parallel_mgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..733933eb8a8e8e7f3e455dcc17bd90176b5ef6cf
--- /dev/null
+++ b/videosys/core/parallel_mgr.py
@@ -0,0 +1,119 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from colossalai.cluster.process_group_mesh import ProcessGroupMesh
+from torch.distributed import ProcessGroup
+
+from videosys.utils.logging import init_dist_logger, logger
+from videosys.utils.utils import set_seed
+
+PARALLEL_MANAGER = None
+
+
+class ParallelManager(ProcessGroupMesh):
+ def __init__(self, dp_size, cp_size, sp_size):
+ super().__init__(dp_size, cp_size, sp_size)
+ dp_axis, cp_axis, sp_axis = 0, 1, 2
+
+ self.dp_size = dp_size
+ self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
+ self.dp_rank = dist.get_rank(self.dp_group)
+
+ self.cp_size = cp_size
+ self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
+ self.cp_rank = dist.get_rank(self.cp_group)
+
+ self.sp_size = sp_size
+ self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
+ self.sp_rank = dist.get_rank(self.sp_group)
+ self.enable_sp = sp_size > 1
+
+ logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
+
+
+def set_parallel_manager(dp_size, cp_size, sp_size):
+ global PARALLEL_MANAGER
+ PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
+
+
+def get_data_parallel_group():
+ return PARALLEL_MANAGER.dp_group
+
+
+def get_data_parallel_size():
+ return PARALLEL_MANAGER.dp_size
+
+
+def get_data_parallel_rank():
+ return PARALLEL_MANAGER.dp_rank
+
+
+def get_sequence_parallel_group():
+ return PARALLEL_MANAGER.sp_group
+
+
+def get_sequence_parallel_size():
+ return PARALLEL_MANAGER.sp_size
+
+
+def get_sequence_parallel_rank():
+ return PARALLEL_MANAGER.sp_rank
+
+
+def get_cfg_parallel_group():
+ return PARALLEL_MANAGER.cp_group
+
+
+def get_cfg_parallel_size():
+ return PARALLEL_MANAGER.cp_size
+
+
+def enable_sequence_parallel():
+ if PARALLEL_MANAGER is None:
+ return False
+ return PARALLEL_MANAGER.enable_sp
+
+
+def get_parallel_manager():
+ return PARALLEL_MANAGER
+
+
+def initialize(
+ rank=0,
+ world_size=1,
+ init_method=None,
+ seed: Optional[int] = None,
+ sp_size: Optional[int] = None,
+ enable_cp: bool = True,
+):
+ if not dist.is_initialized():
+ try:
+ dist.destroy_process_group()
+ except Exception:
+ pass
+ dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
+ torch.cuda.set_device(rank)
+ init_dist_logger()
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ # init sequence parallel
+ if sp_size is None:
+ sp_size = dist.get_world_size()
+ dp_size = 1
+ else:
+ assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
+ dp_size = dist.get_world_size() // sp_size
+
+ # update cfg parallel
+ if enable_cp and sp_size % 2 == 0:
+ sp_size = sp_size // 2
+ cp_size = 2
+ else:
+ cp_size = 1
+
+ set_parallel_manager(dp_size, cp_size, sp_size)
+
+ if seed is not None:
+ set_seed(seed + get_data_parallel_rank())
diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c59b734688a68eae97ffb980c8d8747efb0bb2e
--- /dev/null
+++ b/videosys/core/pipeline.py
@@ -0,0 +1,34 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import torch
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput
+
+
+class VideoSysPipeline(DiffusionPipeline):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def set_eval_and_device(device: torch.device, *modules):
+ for module in modules:
+ module.eval()
+ module.to(device)
+
+ @abstractmethod
+ def generate(self, *args, **kwargs):
+ pass
+
+ def __call__(self, *args, **kwargs):
+ """
+ In diffusers, it is a convention to call the pipeline object.
+ But in VideoSys, we will use the generate method for better prompt.
+ This is a wrapper for the generate method to support the diffusers usage.
+ """
+ return self.generate(*args, **kwargs)
+
+
+@dataclass
+class VideoSysPipelineOutput(BaseOutput):
+ video: torch.Tensor
diff --git a/videosys/core/shardformer/__init__.py b/videosys/core/shardformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/core/shardformer/t5/__init__.py b/videosys/core/shardformer/t5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/core/shardformer/t5/modeling.py b/videosys/core/shardformer/t5/modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfb80841c92a57628fba81425627053afc76a3b
--- /dev/null
+++ b/videosys/core/shardformer/t5/modeling.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+
+
+class T5LayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+ @staticmethod
+ def from_native_module(module, *args, **kwargs):
+ assert module.__class__.__name__ == "FusedRMSNorm", (
+ "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
+ "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
+ )
+
+ layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
+ layer_norm.weight.data.copy_(module.weight.data)
+ layer_norm = layer_norm.to(module.weight.device)
+ return layer_norm
diff --git a/videosys/core/shardformer/t5/policy.py b/videosys/core/shardformer/t5/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8071b17d5fb741b6d5736131bdeffdf1154c5b8
--- /dev/null
+++ b/videosys/core/shardformer/t5/policy.py
@@ -0,0 +1,68 @@
+from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
+from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
+from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
+
+
+class T5EncoderPolicy(Policy):
+ def config_sanity_check(self):
+ assert not self.shard_config.enable_tensor_parallelism
+ assert not self.shard_config.enable_flash_attention
+
+ def preprocess(self):
+ return self.model
+
+ def module_policy(self):
+ from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
+
+ policy = {}
+
+ # check whether apex is installed
+ try:
+ from apex.normalization import FusedRMSNorm # noqa
+ from videosys.core.shardformer.t5.modeling import T5LayerNorm
+
+ # recover hf from fused rms norm to T5 norm which is faster
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=T5LayerNorm,
+ ),
+ policy=policy,
+ target_key=T5LayerFF,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
+ policy=policy,
+ target_key=T5LayerSelfAttention,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
+ policy=policy,
+ target_key=T5Stack,
+ )
+ except (ImportError, ModuleNotFoundError):
+ pass
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_T5_layer_ff_forward(),
+ "dropout_add": get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=T5LayerFF,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_T5_layer_self_attention_forward(),
+ "dropout_add": get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=T5LayerSelfAttention,
+ )
+
+ return policy
+
+ def postprocess(self):
+ return self.model
diff --git a/videosys/datasets/dataloader.py b/videosys/datasets/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..22a7be3dd188cb425910571510634ea697ab6550
--- /dev/null
+++ b/videosys/datasets/dataloader.py
@@ -0,0 +1,94 @@
+import random
+from typing import Iterator, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from videosys.core.parallel_mgr import ParallelManager
+
+
+class StatefulDistributedSampler(DistributedSampler):
+ def __init__(
+ self,
+ dataset: Dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ ) -> None:
+ super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
+ self.start_index: int = 0
+
+ def __iter__(self) -> Iterator:
+ iterator = super().__iter__()
+ indices = list(iterator)
+ indices = indices[self.start_index :]
+ return iter(indices)
+
+ def __len__(self) -> int:
+ return self.num_samples - self.start_index
+
+ def set_start_index(self, start_index: int) -> None:
+ self.start_index = start_index
+
+
+def prepare_dataloader(
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ pg_manager: Optional[ParallelManager] = None,
+ **kwargs,
+):
+ r"""
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
+
+
+ Args:
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
+ add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
+ the batch size, then the last batch will be smaller, defaults to False.
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+ num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+ `DataLoader `_.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """
+ _kwargs = kwargs.copy()
+ sampler = StatefulDistributedSampler(
+ dataset,
+ num_replicas=pg_manager.size(pg_manager.dp_axis),
+ rank=pg_manager.coordinate(pg_manager.dp_axis),
+ shuffle=shuffle,
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
diff --git a/videosys/datasets/image_transform.py b/videosys/datasets/image_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efa8bb45c5b51adf072c5bc5b710f7e2e272409
--- /dev/null
+++ b/videosys/datasets/image_transform.py
@@ -0,0 +1,42 @@
+# Adapted from DiT
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DiT: https://github.com/facebookresearch/DiT
+# --------------------------------------------------------
+
+
+import numpy as np
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
+
+
+def get_transforms_image(image_size=256):
+ transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+ return transform
diff --git a/videosys/datasets/video_transform.py b/videosys/datasets/video_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f0fb440026d078835a19f5389a86930e697010
--- /dev/null
+++ b/videosys/datasets/video_transform.py
@@ -0,0 +1,441 @@
+# Adapted from OpenSora and Latte
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# Latte: https://github.com/Vchitect/Latte
+# --------------------------------------------------------
+
+import numbers
+import random
+
+import numpy as np
+import torch
+from PIL import Image
+
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+
+def resize_scale(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size[0] / min(H, W)
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+ """
+ Do spatial cropping and resizing to the video clip
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
+ h (int): Height of the cropped region.
+ w (int): Width of the cropped region.
+ size (tuple(int, int)): height and width of resized clip
+ Returns:
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ clip = crop(clip, i, j, h, w)
+ clip = resize(clip, size, interpolation_mode)
+ return clip
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ th, tw = crop_size
+ if h < th or w < tw:
+ raise ValueError("height and width must be no smaller than crop_size")
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def center_crop_using_short_edge(clip):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ if h < w:
+ th, tw = h, h
+ i = 0
+ j = int(round((w - tw) / 2.0))
+ else:
+ th, tw = w, w
+ i = int(round((h - th) / 2.0))
+ j = 0
+ return crop(clip, i, j, th, tw)
+
+
+def random_shift_crop(clip):
+ """
+ Slide along the long edge, with the short edge as crop size
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+
+ if h <= w:
+ short_edge = h
+ else:
+ short_edge = w
+
+ th, tw = short_edge, short_edge
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+ return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
+ return clip.float() / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ mean (tuple): pixel RGB mean. Size is (3)
+ std (tuple): pixel standard deviation. Size is (3)
+ Returns:
+ normalized clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ if not inplace:
+ clip = clip.clone()
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ # print(mean)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+ return clip
+
+
+def hflip(clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
+ Returns:
+ flipped clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ return clip.flip(-1)
+
+
+class RandomCropVideo:
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: randomly cropped video clip.
+ size is (T, C, OH, OW)
+ """
+ i, j, h, w = self.get_params(clip)
+ return crop(clip, i, j, h, w)
+
+ def get_params(self, clip):
+ h, w = clip.shape[-2:]
+ th, tw = self.size
+
+ if h < th or w < tw:
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
+
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
+
+ return i, j, th, tw
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size})"
+
+
+class CenterCropResizeVideo:
+ """
+ First use the short side for cropping length,
+ center crop video, then resize to the specified size
+ """
+
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_center_crop = center_crop_using_short_edge(clip)
+ clip_center_crop_resize = resize(
+ clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
+ )
+ return clip_center_crop_resize
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+class UCFCenterCropVideo:
+ """
+ First scale to the specified size in equal proportion to the short edge,
+ then center cropping
+ """
+
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ clip_center_crop = center_crop(clip_resize, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+class KineticsRandomCropResizeVideo:
+ """
+ Slide along the long edge, with the short edge as crop size. And resie to the desired size.
+ """
+
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ clip_random_crop = random_shift_crop(clip)
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
+ return clip_resize
+
+
+class CenterCropVideo:
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_center_crop = center_crop(clip, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+class NormalizeVideo:
+ """
+ Normalize the video clip by mean subtraction and division by standard deviation
+ Args:
+ mean (3-tuple): pixel RGB mean
+ std (3-tuple): pixel RGB standard deviation
+ inplace (boolean): whether do in-place normalization
+ """
+
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
+ """
+ return normalize(clip, self.mean, self.std, self.inplace)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ return to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class RandomHorizontalFlipVideo:
+ """
+ Flip the video clip along the horizontal direction with a given probability
+ Args:
+ p (float): probability of the clip being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor): Size is (T, C, H, W)
+ """
+ if random.random() < self.p:
+ clip = hflip(clip)
+ return clip
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(p={self.p})"
+
+
+# ------------------------------------------------------------
+# --------------------- Sampling ---------------------------
+# ------------------------------------------------------------
+class TemporalRandomCrop(object):
+ """Temporally crop the given frame indices at a random location.
+
+ Args:
+ size (int): Desired length of frames will be seen in the model.
+ """
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, total_frames):
+ rand_end = max(0, total_frames - self.size - 1)
+ begin_index = random.randint(0, rand_end)
+ end_index = min(begin_index + self.size, total_frames)
+ return begin_index, end_index
diff --git a/videosys/diffusion/__init__.py b/videosys/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d16b842bf5a1bdc145b923693143e25a7e2ce81
--- /dev/null
+++ b/videosys/diffusion/__init__.py
@@ -0,0 +1,41 @@
+# Modified from OpenAI's diffusion repos and Meta DiT
+# DiT: https://github.com/facebookresearch/DiT/tree/main
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+ timestep_respacing,
+ noise_schedule="linear",
+ use_kl=False,
+ sigma_small=False,
+ predict_xstart=False,
+ learn_sigma=True,
+ rescale_learned_sigmas=False,
+ diffusion_steps=1000,
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
+ model_var_type=(
+ (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
diff --git a/videosys/diffusion/diffusion_utils.py b/videosys/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..056471c0b0b560d17d18b95f9b8ef3dbc1b8317e
--- /dev/null
+++ b/videosys/diffusion/diffusion_utils.py
@@ -0,0 +1,79 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
+
+ return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/videosys/diffusion/gaussian_diffusion.py b/videosys/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf734a641690a1c1f4f5256eea1792afd71b800c
--- /dev/null
+++ b/videosys/diffusion/gaussian_diffusion.py
@@ -0,0 +1,829 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = (
+ np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+ if len(self.posterior_variance) > 1
+ else np.array([])
+ )
+
+ self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
+ out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, t, **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diff --git a/videosys/diffusion/respace.py b/videosys/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5754aa70a9f221a2320ba7a56ab0cb5f4ed9188
--- /dev/null
+++ b/videosys/diffusion/respace.py
@@ -0,0 +1,119 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(f"cannot divide section of {size} steps into {section_count}")
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(model, self.timestep_map, self.original_num_steps)
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
diff --git a/videosys/diffusion/timestep_sampler.py b/videosys/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdaa45acfcf239d7b6aaf5a83ee12fd553bc06b8
--- /dev/null
+++ b/videosys/diffusion/timestep_sampler.py
@@ -0,0 +1,143 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/videosys/models/__init__.py b/videosys/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/models/cogvideo/__init__.py b/videosys/models/cogvideo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e80e1830a6f3a5e53b9c315e377dce38fca82ed
--- /dev/null
+++ b/videosys/models/cogvideo/__init__.py
@@ -0,0 +1,6 @@
+from .pipeline import CogVideoConfig, CogVideoPipeline
+
+__all__ = [
+ "CogVideoConfig",
+ "CogVideoPipeline",
+]
diff --git a/videosys/models/cogvideo/autoencoder_kl.py b/videosys/models/cogvideo/autoencoder_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e52a2b80b50346adcd95a4dca9693884a5a3d3
--- /dev/null
+++ b/videosys/models/cogvideo/autoencoder_kl.py
@@ -0,0 +1,1024 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.models.activations import get_activation
+from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
+from diffusers.utils.accelerate_utils import apply_forward_hook
+
+from .modules import CogVideoXDownsample3D, CogVideoXUpsample3D
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogVideoXSafeConv3d(nn.Conv3d):
+ """
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
+ """
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+
+ # Set to 2GB, suitable for CuDNN
+ if memory_count > 2:
+ kernel_size = self.kernel_size[0]
+ part_num = int(memory_count / 2) + 1
+ input_chunks = torch.chunk(input, part_num, dim=2)
+
+ if kernel_size > 1:
+ input_chunks = [input_chunks[0]] + [
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+ for i in range(1, len(input_chunks))
+ ]
+
+ output_chunks = []
+ for input_chunk in input_chunks:
+ output_chunks.append(super().forward(input_chunk))
+ output = torch.cat(output_chunks, dim=2)
+ return output
+ else:
+ return super().forward(input)
+
+
+class CogVideoXCausalConv3d(nn.Module):
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of output channels.
+ kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
+ stride (int, optional): Stride of the convolution. Default is 1.
+ dilation (int, optional): Dilation rate of the convolution. Default is 1.
+ pad_mode (str, optional): Padding mode. Default is "constant".
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: int = 1,
+ dilation: int = 1,
+ pad_mode: str = "constant",
+ ):
+ super().__init__()
+
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size,) * 3
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ self.pad_mode = pad_mode
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+ self.temporal_dim = 2
+ self.time_kernel_size = time_kernel_size
+
+ stride = (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = CogVideoXSafeConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ )
+
+ self.conv_cache = None
+
+ def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ dim = self.temporal_dim
+ kernel_size = self.time_kernel_size
+ if kernel_size == 1:
+ return inputs
+
+ inputs = inputs.transpose(0, dim)
+
+ if self.conv_cache is not None:
+ inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
+ else:
+ inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
+
+ inputs = inputs.transpose(0, dim).contiguous()
+ return inputs
+
+ def _clear_fake_context_parallel_cache(self):
+ del self.conv_cache
+ self.conv_cache = None
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ input_parallel = self.fake_context_parallel_forward(inputs)
+
+ self._clear_fake_context_parallel_cache()
+ self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+
+ output_parallel = self.conv(input_parallel)
+ output = output_parallel
+ return output
+
+
+class CogVideoXSpatialNorm3D(nn.Module):
+ r"""
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
+ to 3D-video like data.
+
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ groups: int = 32,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
+ z_first = F.interpolate(z_first, size=f_first_size)
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
+ zq = torch.cat([z_first, z_rest], dim=2)
+ else:
+ zq = F.interpolate(zq, size=f.shape[-3:])
+
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+class CogVideoXResnetBlock3D(nn.Module):
+ r"""
+ A 3D ResNet block used in the CogVideoX model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (Optional[int], optional):
+ Number of output channels. If None, defaults to `in_channels`. Default is None.
+ dropout (float, optional): Dropout rate. Default is 0.0.
+ temb_channels (int, optional): Number of time embedding channels. Default is 512.
+ groups (int, optional): Number of groups for group normalization. Default is 32.
+ eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
+ non_linearity (str, optional): Activation function to use. Default is "swish".
+ conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
+ spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
+ pad_mode (str, optional): Padding mode. Default is "first".
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ groups: int = 32,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ conv_shortcut: bool = False,
+ spatial_norm_dim: Optional[int] = None,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(non_linearity)
+ self.use_conv_shortcut = conv_shortcut
+
+ if spatial_norm_dim is None:
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
+ else:
+ self.norm1 = CogVideoXSpatialNorm3D(
+ f_channels=in_channels,
+ zq_channels=spatial_norm_dim,
+ groups=groups,
+ )
+ self.norm2 = CogVideoXSpatialNorm3D(
+ f_channels=out_channels,
+ zq_channels=spatial_norm_dim,
+ groups=groups,
+ )
+
+ self.conv1 = CogVideoXCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ if temb_channels > 0:
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
+
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = CogVideoXCausalConv3d(
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CogVideoXCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+ else:
+ self.conv_shortcut = CogVideoXSafeConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = inputs
+
+ if zq is not None:
+ hidden_states = self.norm1(hidden_states, zq)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if zq is not None:
+ hidden_states = self.norm2(hidden_states, zq)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels:
+ inputs = self.conv_shortcut(inputs)
+
+ hidden_states = hidden_states + inputs
+ return hidden_states
+
+
+class CogVideoXDownBlock3D(nn.Module):
+ r"""
+ A downsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ temb_channels (int): Number of time embedding channels.
+ dropout (float, optional): Dropout rate. Default is 0.0.
+ num_layers (int, optional): Number of layers in the block. Default is 1.
+ resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
+ resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
+ resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
+ add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
+ downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
+ compress_time (bool, optional): If True, apply temporal compression. Default is False.
+ pad_mode (str, optional): Padding mode. Default is "first".
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ add_downsample: bool = True,
+ downsample_padding: int = 0,
+ compress_time: bool = False,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channel,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ pad_mode=pad_mode,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.downsamplers = None
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ CogVideoXDownsample3D(
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
+ )
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class CogVideoXMidBlock3D(nn.Module):
+ r"""
+ A middle block used in the CogVideoX model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ temb_channels (int): Number of time embedding channels.
+ dropout (float, optional): Dropout rate. Default is 0.0.
+ num_layers (int, optional): Number of layers in the block. Default is 1.
+ resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
+ resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
+ resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
+ spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
+ pad_mode (str, optional): Padding mode. Default is "first".
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ spatial_norm_dim: Optional[int] = None,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ spatial_norm_dim=spatial_norm_dim,
+ non_linearity=resnet_act_fn,
+ pad_mode=pad_mode,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ return hidden_states
+
+
+class CogVideoXUpBlock3D(nn.Module):
+ r"""
+ An upsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ temb_channels (int): Number of time embedding channels.
+ dropout (float, optional): Dropout rate. Default is 0.0.
+ num_layers (int, optional): Number of layers in the block. Default is 1.
+ resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
+ resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
+ resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
+ spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
+ add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
+ upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
+ compress_time (bool, optional): If True, apply temporal compression. Default is False.
+ pad_mode (str, optional): Padding mode. Default is "first".
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ spatial_norm_dim: int = 16,
+ add_upsample: bool = True,
+ upsample_padding: int = 1,
+ compress_time: bool = False,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channel,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ spatial_norm_dim=spatial_norm_dim,
+ pad_mode=pad_mode,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.upsamplers = None
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [CogVideoXUpsample3D(out_channels, out_channels, padding=upsample_padding, compress_time=compress_time)]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CogVideoXEncoder3D(nn.Module):
+ r"""
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+ options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ double_z (`bool`, *optional*, defaults to `True`):
+ Whether to double the number of output channels for the last block.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 16,
+ down_block_types: Tuple[str, ...] = (
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ dropout: float = 0.0,
+ pad_mode: str = "first",
+ temporal_compression_ratio: float = 4,
+ ):
+ super().__init__()
+
+ # log2 of temporal_compress_times
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+ self.down_blocks = nn.ModuleList([])
+
+ # down blocks
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ if down_block_type == "CogVideoXDownBlock3D":
+ down_block = CogVideoXDownBlock3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=layers_per_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ add_downsample=not is_final_block,
+ compress_time=compress_time,
+ )
+ else:
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
+
+ self.down_blocks.append(down_block)
+
+ # mid block
+ self.mid_block = CogVideoXMidBlock3D(
+ in_channels=block_out_channels[-1],
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=2,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ pad_mode=pad_mode,
+ )
+
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CogVideoXCausalConv3d(
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
+ hidden_states = self.conv_in(sample)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # 1. Down
+ for down_block in self.down_blocks:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block), hidden_states, temb, None
+ )
+
+ # 2. Mid
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), hidden_states, temb, None
+ )
+ else:
+ # 1. Down
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states, temb, None)
+
+ # 2. Mid
+ hidden_states = self.mid_block(hidden_states, temb, None)
+
+ # 3. Post-process
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class CogVideoXDecoder3D(nn.Module):
+ r"""
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ norm_type (`str`, *optional*, defaults to `"group"`):
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = (
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ dropout: float = 0.0,
+ pad_mode: str = "first",
+ temporal_compression_ratio: float = 4,
+ ):
+ super().__init__()
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ self.conv_in = CogVideoXCausalConv3d(
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
+ )
+
+ # mid block
+ self.mid_block = CogVideoXMidBlock3D(
+ in_channels=reversed_block_out_channels[0],
+ temb_channels=0,
+ num_layers=2,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ spatial_norm_dim=in_channels,
+ pad_mode=pad_mode,
+ )
+
+ # up blocks
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = reversed_block_out_channels[0]
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ if up_block_type == "CogVideoXUpBlock3D":
+ up_block = CogVideoXUpBlock3D(
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=layers_per_block + 1,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ spatial_norm_dim=in_channels,
+ add_upsample=not is_final_block,
+ compress_time=compress_time,
+ pad_mode=pad_mode,
+ )
+ prev_output_channel = output_channel
+ else:
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
+
+ self.up_blocks.append(up_block)
+
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CogVideoXCausalConv3d(
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
+ hidden_states = self.conv_in(sample)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # 1. Mid
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), hidden_states, temb, sample
+ )
+
+ # 2. Up
+ for up_block in self.up_blocks:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block), hidden_states, temb, sample
+ )
+ else:
+ # 1. Mid
+ hidden_states = self.mid_block(hidden_states, temb, sample)
+
+ # 2. Up
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states, temb, sample)
+
+ # 3. Post-process
+ hidden_states = self.norm_out(hidden_states, sample)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = (
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ up_block_types: Tuple[str] = (
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
+ latent_channels: int = 16,
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ temporal_compression_ratio: float = 4,
+ sample_size: int = 256,
+ scaling_factor: float = 1.15258426,
+ shift_factor: Optional[float] = None,
+ latents_mean: Optional[Tuple[float]] = None,
+ latents_std: Optional[Tuple[float]] = None,
+ force_upcast: float = True,
+ use_quant_conv: bool = False,
+ use_post_quant_conv: bool = False,
+ ):
+ super().__init__()
+
+ self.encoder = CogVideoXEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_eps=norm_eps,
+ norm_num_groups=norm_num_groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.decoder = CogVideoXDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_eps=norm_eps,
+ norm_num_groups=norm_num_groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
+ module.gradient_checkpointing = value
+
+ def clear_fake_context_parallel_cache(self):
+ for name, module in self.named_modules():
+ if isinstance(module, CogVideoXCausalConv3d):
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
+ module._clear_fake_context_parallel_cache()
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ h = self.encoder(x)
+ if self.quant_conv is not None:
+ h = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(h)
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ @apply_forward_hook
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[torch.Tensor, torch.Tensor]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ if not return_dict:
+ return (dec,)
+ return dec
diff --git a/videosys/models/cogvideo/cogvideox_transformer_3d.py b/videosys/models/cogvideo/cogvideox_transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..975e86b568aa0efb864d014f0ec698f597d87901
--- /dev/null
+++ b/videosys/models/cogvideo/cogvideox_transformer_3d.py
@@ -0,0 +1,339 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+from typing import Any, Dict, Optional, Union
+
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import is_torch_version, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from torch import nn
+
+from .modules import AdaLayerNorm, CogVideoXLayerNormZero, CogVideoXPatchEmbed, get_3d_sincos_pos_embed
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ ) -> torch.Tensor:
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ text_length = norm_encoder_hidden_states.size(1)
+
+ # CogVideoX uses concatenated text + video embeddings with self-attention instead of using
+ # them in cross-attention individually
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ attn_output = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input.
+ out_channels (`int`, *optional*):
+ The number of channels in the output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ patch_size (`int`, *optional*):
+ The size of the patches to use in the patch embedding layer.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states. During inference, you can denoise for up to but not more steps than
+ `num_embeds_ada_norm`.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
+ caption_channels (`int`, *optional*):
+ The number of channels in the caption embeddings.
+ video_length (`int`, *optional*):
+ The number of frames in the video-like data.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: Optional[int] = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ post_patch_height = sample_height // patch_size
+ post_patch_width = sample_width // patch_size
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. 3D positional embeddings
+ spatial_pos_embedding = get_3d_sincos_pos_embed(
+ inner_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ spatial_interpolation_scale,
+ temporal_interpolation_scale,
+ )
+ spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+ pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
+ pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+ # 3. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 4. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 5. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+
+ # 3. Position embedding
+ seq_length = height * width * num_frames // (self.config.patch_size**2)
+
+ pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
+ hidden_states = hidden_states + pos_embeds
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
+ hidden_states = hidden_states[:, self.config.max_text_seq_length :]
+
+ # 5. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ )
+
+ hidden_states = self.norm_final(hidden_states)
+
+ # 6. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 7. Unpatchify
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/videosys/models/cogvideo/modules.py b/videosys/models/cogvideo/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d5dc49515ce5423026c130581fc16f4155333d4
--- /dev/null
+++ b/videosys/models/cogvideo/modules.py
@@ -0,0 +1,317 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
+
+
+class CogVideoXDownsample3D(nn.Module):
+ # Todo: Wait for paper relase.
+ r"""
+ A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `2`):
+ Stride of the convolution.
+ padding (`int`, defaults to `0`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 2,
+ padding: int = 0,
+ compress_time: bool = False,
+ ):
+ super().__init__()
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ batch_size, channels, frames, height, width = x.shape
+
+ # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+ x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+ if x.shape[-1] % 2 == 1:
+ x_first, x_rest = x[..., 0], x[..., 1:]
+ if x_rest.shape[-1] > 0:
+ # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+ else:
+ # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
+ # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+ # Pad the tensor
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+ batch_size, channels, frames, height, width = x.shape
+ # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+ x = self.conv(x)
+ # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+ return x
+
+
+class CogVideoXUpsample3D(nn.Module):
+ r"""
+ A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `1`):
+ Stride of the convolution.
+ padding (`int`, defaults to `1`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ compress_time: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
+ # split first frame
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+
+ x_first = F.interpolate(x_first, scale_factor=2.0)
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
+ x_first = x_first[:, :, None, :, :]
+ inputs = torch.cat([x_first, x_rest], dim=2)
+ elif inputs.shape[2] > 1:
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ else:
+ inputs = inputs.squeeze(2)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs[:, :, None, :, :]
+ else:
+ # only interpolate 2D
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = self.conv(inputs)
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ return inputs
+
+
+def get_3d_sincos_pos_embed(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+ r"""
+ Args:
+ embed_dim (`int`):
+ spatial_size (`int` or `Tuple[int, int]`):
+ temporal_size (`int`):
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ """
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+ grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+ # 2. Temporal
+ grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
+
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
+ return pos_embed
+
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+ return embeds
+
+
+class CogVideoXLayerNormZero(nn.Module):
+ def __init__(
+ self,
+ conditioning_dim: int,
+ embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
+
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
+ output_dim (`int`, *optional*):
+ norm_elementwise_affine (`bool`, defaults to `False):
+ norm_eps (`bool`, defaults to `False`):
+ chunk_dim (`int`, defaults to `0`):
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_embeddings: Optional[int] = None,
+ output_dim: Optional[int] = None,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
+ super().__init__()
+
+ self.chunk_dim = chunk_dim
+ output_dim = output_dim or embedding_dim * 2
+
+ if num_embeddings is not None:
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+ def forward(
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if self.emb is not None:
+ temb = self.emb(timestep)
+
+ temb = self.linear(self.silu(temb))
+
+ if self.chunk_dim == 1:
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
+ # other if-branch. This branch is specific to CogVideoX for now.
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ else:
+ scale, shift = temb.chunk(2, dim=0)
+
+ x = self.norm(x) * (1 + scale) + shift
+ return x
diff --git a/videosys/models/cogvideo/pipeline.py b/videosys/models/cogvideo/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5334d9289a8df333dce27608a659b137d2b798eb
--- /dev/null
+++ b/videosys/models/cogvideo/pipeline.py
@@ -0,0 +1,692 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+import inspect
+import math
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from transformers import T5EncoderModel, T5Tokenizer
+
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.utils.utils import save_video
+
+from .autoencoder_kl import AutoencoderKLCogVideoX
+from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from .retrieve_timesteps import retrieve_timesteps
+from .scheduling import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+from videosys.core.pab_mgr import (
+ PABConfig,
+ get_diffusion_skip,
+ get_diffusion_skip_timestep,
+ set_pab_manager,
+ skip_diffusion_timestep,
+ update_steps,
+)
+
+
+
+class CogVideoPABConfig(PABConfig):
+ def __init__(
+ self,
+ steps: int = 150,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [100, 850],
+ spatial_gap: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [100, 850],
+ temporal_gap: int = 4,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [100, 850],
+ cross_gap: int = 6,
+ diffusion_skip: bool = False,
+ diffusion_timestep_respacing: list = None,
+ diffusion_skip_timestep: list = None,
+ mlp_skip: bool = True,
+ mlp_spatial_skip_config: dict = {
+ 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ },
+ mlp_temporal_skip_config: dict = {
+ 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ },
+ full_broadcast: bool = True,
+ full_threshold: list = [100, 850],
+ full_gap: int = 3,
+ ):
+ super().__init__(
+ steps=steps,
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_gap=spatial_gap,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_gap=temporal_gap,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_gap=cross_gap,
+ diffusion_skip=diffusion_skip,
+ diffusion_timestep_respacing=diffusion_timestep_respacing,
+ diffusion_skip_timestep=diffusion_skip_timestep,
+ mlp_skip=mlp_skip,
+ mlp_spatial_skip_config=mlp_spatial_skip_config,
+ mlp_temporal_skip_config=mlp_temporal_skip_config,
+ full_broadcast=full_broadcast,
+ full_threshold=full_threshold,
+ full_gap=full_gap,
+ )
+
+
+
+class CogVideoConfig:
+ def __init__(
+ self,
+ world_size: int = 1,
+ model_path: str = "THUDM/CogVideoX-2b",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ enable_pab: bool = False,
+ pab_config = CogVideoPABConfig()
+ ):
+ # ======= engine ========
+ self.world_size = world_size
+
+ # ======= pipeline ========
+ self.pipeline_cls = CogVideoPipeline
+
+ # ======= model ========
+ self.model_path = model_path
+ self.num_inference_steps = num_inference_steps
+ self.guidance_scale = guidance_scale
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class CogVideoPipeline(VideoSysPipeline):
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ config: CogVideoConfig,
+ tokenizer: Optional[T5Tokenizer] = None,
+ text_encoder: Optional[T5EncoderModel] = None,
+ vae: Optional[AutoencoderKLCogVideoX] = None,
+ transformer: Optional[CogVideoXTransformer3DModel] = None,
+ scheduler: Optional[CogVideoXDDIMScheduler] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.float16,
+ ):
+ super().__init__()
+ self._config = config
+ self._device = device
+ self._dtype = dtype
+
+ if transformer is None:
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ config.model_path, subfolder="transformer", torch_dtype=self._dtype
+ )
+ if vae is None:
+ vae = AutoencoderKLCogVideoX.from_pretrained(config.model_path, subfolder="vae", torch_dtype=self._dtype)
+ if tokenizer is None:
+ tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
+ if text_encoder is None:
+ text_encoder = T5EncoderModel.from_pretrained(
+ config.model_path, subfolder="text_encoder", torch_dtype=self._dtype
+ )
+ if scheduler is None:
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(
+ config.model_path,
+ subfolder="scheduler",
+ )
+
+ # set eval and device
+ self.set_eval_and_device(self._device, text_encoder, vae, transformer)
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ torch.cuda.empty_cache()
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor, num_seconds: int):
+ print("hhhhhhhh")
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ frames = []
+ num_frames = latents.size(2)
+ segment_size = num_frames // num_frames # 每段处理的帧数
+
+ for i in range(num_frames): # 显存问题,逐帧处理
+ start_frame = i * segment_size
+ end_frame = start_frame + segment_size if i < num_frames-1 else num_frames
+
+ current_latents = latents[:, :, start_frame:end_frame, :, :]
+ try:
+ current_frames = self.vae.decode(current_latents).sample
+ frames.append(current_frames)
+ except RuntimeError as e:
+ logger.error(f"CUDA out of memory error: {str(e)}")
+ raise e
+
+ # 清理缓存
+ torch.cuda.empty_cache()
+
+ self.vae.clear_fake_context_parallel_cache()
+
+ frames = torch.cat(frames, dim=2)
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 48,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[VideoSysPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_frames (`int`, defaults to `48`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ fps = 8
+ assert (
+ num_frames <= 48 and num_frames % fps == 0 and fps == 8
+ ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ num_frames += 1
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents, num_frames // fps)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ if not return_dict:
+ return (video,)
+
+ return VideoSysPipelineOutput(video=video)
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=8)
diff --git a/videosys/models/cogvideo/retrieve_timesteps.py b/videosys/models/cogvideo/retrieve_timesteps.py
new file mode 100644
index 0000000000000000000000000000000000000000..9702ec47a610e7f3f778a98572ffcac6cfb7a6d0
--- /dev/null
+++ b/videosys/models/cogvideo/retrieve_timesteps.py
@@ -0,0 +1,74 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+import inspect
+from typing import List, Optional, Union
+
+import torch
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
diff --git a/videosys/models/cogvideo/scheduling.py b/videosys/models/cogvideo/scheduling.py
new file mode 100644
index 0000000000000000000000000000000000000000..06a4e0f01f250060a5791a494a50c7186908b55b
--- /dev/null
+++ b/videosys/models/cogvideo/scheduling.py
@@ -0,0 +1,813 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+# --------------------------------------------------------
+
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(alphas_cumprod):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+
+ return alphas_bar
+
+
+class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.0120,
+ beta_schedule: str = "scaled_linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ snr_shift_scale: float = 3.0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # Modify: SNR shift following SD3
+ self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ # pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
+ b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
+
+ prev_sample = a_t * sample + b_t * pred_original_sample
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+
+class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.0120,
+ beta_schedule: str = "scaled_linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ snr_shift_scale: float = 3.0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # Modify: SNR shift following SD3
+ self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
+ lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
+ lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
+ h = lamb_next - lamb
+
+ if alpha_prod_t_back is not None:
+ lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
+ h_last = lamb - lamb_previous
+ r = h_last / h
+ return h, r, lamb, lamb_next
+ else:
+ return h, None, lamb, lamb_next
+
+ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
+ mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
+ mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
+
+ if alpha_prod_t_back is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ old_pred_original_sample: torch.Tensor,
+ timestep: int,
+ timestep_back: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = False,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ # pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
+ mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
+ mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5
+
+ noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+ prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise
+
+ if old_pred_original_sample is None or prev_timestep < 0:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return prev_sample, pred_original_sample
+ else:
+ denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
+ noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+ x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise
+
+ prev_sample = x_advanced
+
+ if not return_dict:
+ return (prev_sample, pred_original_sample)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/videosys/models/latte/__init__.py b/videosys/models/latte/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e277da72ca2bcc598f1b4334bc6303e6207f175
--- /dev/null
+++ b/videosys/models/latte/__init__.py
@@ -0,0 +1,7 @@
+from .pipeline import LatteConfig, LattePABConfig, LattePipeline
+
+__all__ = [
+ "LattePipeline",
+ "LattePABConfig",
+ "LatteConfig",
+]
diff --git a/videosys/models/latte/latte_t2v.py b/videosys/models/latte/latte_t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0075902216c4351c96c8376deb1bab68e53024
--- /dev/null
+++ b/videosys/models/latte/latte_t2v.py
@@ -0,0 +1,1477 @@
+# Adapted from Latte
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Latte: https://github.com/Vchitect/Latte
+# --------------------------------------------------------
+
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import (
+ ImagePositionalEmbeddings,
+ PatchEmbed,
+ PixArtAlphaCombinedTimestepSizeEmbeddings,
+ PixArtAlphaTextProjection,
+ SinusoidalPositionalEmbedding,
+ get_1d_sincos_pos_embed_from_grid,
+)
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange, repeat
+from torch import nn
+
+from videosys.core.comm import (
+ all_to_all_with_pad,
+ gather_sequence,
+ get_spatial_pad,
+ get_temporal_pad,
+ set_spatial_pad,
+ set_temporal_pad,
+ split_sequence,
+)
+from videosys.core.pab_mgr import (
+ enable_pab,
+ get_mlp_output,
+ if_broadcast_cross,
+ if_broadcast_mlp,
+ if_broadcast_spatial,
+ if_broadcast_temporal,
+ save_mlp_output,
+)
+from videosys.core.parallel_mgr import (
+ enable_sequence_parallel,
+ get_cfg_parallel_group,
+ get_cfg_parallel_size,
+ get_sequence_parallel_group,
+)
+from videosys.utils.utils import batch_func
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+ for module in self.net:
+ if isinstance(module, compatible_cls):
+ hidden_states = module(hidden_states, scale)
+ else:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+ ada_norm_bias: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_zero":
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm1 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ if norm_type == "ada_norm":
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm2 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if norm_type == "ada_norm_continuous":
+ self.norm3 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "layer_norm",
+ )
+
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ elif norm_type == "layer_norm_i2vgen":
+ self.norm3 = None
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if norm_type == "ada_norm_single":
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.cross_last = None
+ self.cross_count = 0
+ self.spatial_last = None
+ self.spatial_count = 0
+ self.block_idx = block_idx
+ self.mlp_count = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def set_cross_last(self, last_out: torch.Tensor):
+ self.cross_last = last_out
+
+ def set_spatial_last(self, last_out: torch.Tensor):
+ self.spatial_last = last_out
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+ # 1. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ if enable_pab():
+ broadcast_spatial, self.spatial_count = if_broadcast_spatial(
+ int(org_timestep[0]), self.spatial_count, self.block_idx
+ )
+
+ if enable_pab() and broadcast_spatial:
+ attn_output = self.spatial_last
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.norm_type == "ada_norm_zero":
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.set_spatial_last(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1.2 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ broadcast_cross, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count)
+ if broadcast_cross:
+ hidden_states = hidden_states + self.cross_last
+ else:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.norm_type == "ada_norm_single":
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if enable_pab():
+ self.set_cross_last(attn_output)
+
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # i2vgen doesn't have this norm 🤷♂️
+ if enable_pab():
+ broadcast_mlp, self.mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=False,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=False,
+ )
+ else:
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ # spatial
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=False,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock_(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) # go here
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # # 2. Cross-Attn
+ # if cross_attention_dim is not None or double_self_attention:
+ # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # # the second cross attention block.
+ # self.norm2 = (
+ # AdaLayerNorm(dim, num_embeds_ada_norm)
+ # if self.use_ada_layer_norm
+ # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ # )
+ # self.attn2 = Attention(
+ # query_dim=dim,
+ # cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # upcast_attention=upcast_attention,
+ # ) # is self-attn if encoder_hidden_states is none
+ # else:
+ # self.norm2 = None
+ # self.attn2 = None
+
+ # 3. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.last_out = None
+ self.mlp_count = 0
+ self.block_idx = block_idx
+ self.count = 0
+
+ def set_last_out(self, last_out: torch.Tensor):
+ self.last_out = last_out
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ if enable_pab():
+ broadcast_temporal, self.count = if_broadcast_temporal(int(org_timestep[0]), self.count)
+
+ if enable_pab() and broadcast_temporal:
+ attn_output = self.last_out
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single: # go here
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ # norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ if enable_sequence_parallel():
+ norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if enable_sequence_parallel():
+ attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.last_out = attn_output
+
+ hidden_states = attn_output + hidden_states
+
+ if enable_pab():
+ broadcast_mlp, self.mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=True,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=True,
+ )
+ else:
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=True,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+ def dynamic_switch(self, x, to_spatial_shard: bool):
+ if to_spatial_shard:
+ scatter_dim, gather_dim = 0, 1
+ scatter_pad = get_spatial_pad()
+ gather_pad = get_temporal_pad()
+ else:
+ scatter_dim, gather_dim = 1, 0
+ scatter_pad = get_temporal_pad()
+ gather_pad = get_spatial_pad()
+ x = all_to_all_with_pad(
+ x,
+ get_sequence_parallel_group(),
+ scatter_dim=scatter_dim,
+ gather_dim=gather_dim,
+ scatter_pad=scatter_pad,
+ gather_pad=gather_pad,
+ )
+ return x
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ batch_size: int = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
+ )
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class LatteT2V(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ video_length: int = 16,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.video_length = video_length
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = linear_cls(in_channels, inner_dim)
+ else:
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+ elif self.is_input_patches:
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale = max(interpolation_scale, 1)
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # Define temporal transformers blocks
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock_( # one attention
+ inner_dim,
+ num_attention_heads, # num_attention_heads
+ attention_head_dim, # attention_head_dim 72
+ dropout=dropout,
+ cross_attention_dim=None,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=False,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = linear_cls(inner_dim, in_channels)
+ else:
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches and norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128 # False, 128 -> 1024
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ # define temporal positional embedding
+ temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ enable_temporal_attentions: bool = True,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+
+ # 0. Split batch for data parallelism
+ if get_cfg_parallel_size() > 1:
+ (
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ added_cond_kwargs,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ ) = batch_func(
+ partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ added_cond_kwargs,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ )
+
+ input_batch_size, c, frame, h, w = hidden_states.shape
+ frame = frame - use_image_num
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ org_timestep = timestep
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
+ elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
+ encoder_attention_mask_video = repeat(
+ encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
+ ).contiguous()
+ encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
+ encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
+ encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
+
+ # Retrieve lora scale.
+ cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 1. Input
+ if self.is_input_patches: # here
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ num_patches = height * width
+
+ hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ # batch_size = hidden_states.shape[0]
+ batch_size = input_batch_size
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
+
+ if use_image_num != 0 and self.training:
+ encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
+ encoder_hidden_states_video = repeat(
+ encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
+ ).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
+ else:
+ encoder_hidden_states_spatial = repeat(
+ encoder_hidden_states, "b t d -> (b f) t d", f=frame
+ ).contiguous()
+
+ # prepare timesteps for spatial and temporal block
+ timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
+
+ if enable_sequence_parallel():
+ set_temporal_pad(frame + use_image_num)
+ set_spatial_pad(num_patches)
+ hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
+ encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
+ timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
+ temp_pos_embed = split_sequence(
+ self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ )
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ spatial_block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ if enable_temporal_attentions:
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0: # image-video joitn training
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ if i == 0:
+ hidden_states_video = hidden_states_video + temp_pos_embed
+
+ hidden_states_video = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ if i == 0:
+ hidden_states = hidden_states + temp_pos_embed
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+ else:
+ hidden_states = spatial_block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ None,
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ if enable_temporal_attentions:
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0 and self.training:
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ hidden_states_video = temp_block(
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ org_timestep,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ if i == 0 and frame > 1:
+ hidden_states = hidden_states + temp_pos_embed
+ hidden_states = temp_block(
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ if enable_sequence_parallel():
+ hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
+
+ # 3. Gather batch for data parallelism
+ if get_cfg_parallel_size() > 1:
+ output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+ def get_1d_sincos_temp_embed(self, embed_dim, length):
+ pos = torch.arange(0, length).unsqueeze(1)
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+ def split_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+ def gather_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+ x = x.reshape(-1, *x.shape[2:])
+ return x
diff --git a/videosys/models/latte/pipeline.py b/videosys/models/latte/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad742a89db778539d53c9f20fa73f6c715e0c9e3
--- /dev/null
+++ b/videosys/models/latte/pipeline.py
@@ -0,0 +1,915 @@
+# Adapted from Latte
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Latte: https://github.com/Vchitect/Latte
+# --------------------------------------------------------
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import einops
+import ftfy
+import torch
+import torch.distributed as dist
+import tqdm
+from bs4 import BeautifulSoup
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import T5EncoderModel, T5Tokenizer
+
+from videosys.core.pab_mgr import (
+ PABConfig,
+ get_diffusion_skip,
+ get_diffusion_skip_timestep,
+ set_pab_manager,
+ skip_diffusion_timestep,
+ update_steps,
+)
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.utils.logging import logger
+from videosys.utils.utils import save_video
+
+from .latte_t2v import LatteT2V
+
+
+class LattePABConfig(PABConfig):
+ def __init__(
+ self,
+ steps: int = 50,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [100, 800],
+ spatial_gap: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [100, 800],
+ temporal_gap: int = 3,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [100, 800],
+ cross_gap: int = 6,
+ diffusion_skip: bool = False,
+ diffusion_timestep_respacing: list = None,
+ diffusion_skip_timestep: list = None,
+ mlp_skip: bool = True,
+ mlp_spatial_skip_config: dict = {
+ 720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ mlp_temporal_skip_config: dict = {
+ 720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ ):
+ super().__init__(
+ steps=steps,
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_gap=spatial_gap,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_gap=temporal_gap,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_gap=cross_gap,
+ diffusion_skip=diffusion_skip,
+ diffusion_timestep_respacing=diffusion_timestep_respacing,
+ diffusion_skip_timestep=diffusion_skip_timestep,
+ mlp_skip=mlp_skip,
+ mlp_spatial_skip_config=mlp_spatial_skip_config,
+ mlp_temporal_skip_config=mlp_temporal_skip_config,
+ )
+
+
+class LatteConfig:
+ def __init__(
+ self,
+ world_size: int = 1,
+ model_path: str = "maxin-cn/Latte-1",
+ enable_vae_temporal_decoder: bool = True,
+ # ======= scheduler ========
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ variance_type: str = "learned_range",
+ # ======= pab ========
+ enable_pab: bool = False,
+ pab_config: PABConfig = LattePABConfig(),
+ ):
+ # ======= engine ========
+ self.world_size = world_size
+
+ # ======= pipeline ========
+ self.pipeline_cls = LattePipeline
+
+ # ======= model ========
+ self.model_path = model_path
+ self.enable_vae_temporal_decoder = enable_vae_temporal_decoder
+
+ # ======= scheduler ========
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.beta_schedule = beta_schedule
+ self.variance_type = variance_type
+
+ # ======= pab ========
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class LattePipeline(VideoSysPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ config: LatteConfig,
+ tokenizer: Optional[T5Tokenizer] = None,
+ text_encoder: Optional[T5EncoderModel] = None,
+ vae: Optional[AutoencoderKL] = None,
+ transformer: Optional[LatteT2V] = None,
+ scheduler: Optional[DDIMScheduler] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.float16,
+ ):
+ super().__init__()
+ self._config = config
+
+ # initialize the model if not provided
+ if transformer is None:
+ transformer = LatteT2V.from_pretrained(config.model_path, subfolder="transformer", video_length=16).to(
+ dtype=dtype
+ )
+ if vae is None:
+ if config.enable_vae_temporal_decoder:
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ config.model_path, subfolder="vae_temporal_decoder", torch_dtype=dtype
+ )
+ else:
+ vae = AutoencoderKL.from_pretrained(config.model_path, subfolder="vae", torch_dtype=dtype)
+ if tokenizer is None:
+ tokenizer = T5Tokenizer.from_pretrained(config.model_path, subfolder="tokenizer")
+ if text_encoder is None:
+ text_encoder = T5EncoderModel.from_pretrained(
+ config.model_path, subfolder="text_encoder", torch_dtype=dtype
+ )
+ if scheduler is None:
+ scheduler = DDIMScheduler.from_pretrained(
+ config.model_path,
+ subfolder="scheduler",
+ beta_start=config.beta_start,
+ beta_end=config.beta_end,
+ beta_schedule=config.beta_schedule,
+ variance_type=config.variance_type,
+ clip_sample=False,
+ )
+
+ # pab
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+ # set eval and device
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
+ else:
+ masked_feature = emb * mask[:, None, :, None] # 1 120 4096
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ mask_feature: bool = True,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (bool, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ mask_feature: (bool, defaults to `True`):
+ If `True`, the function will mask the text embeddings.
+ """
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = 120
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds_attention_mask = attention_mask
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ # Perform additional masking.
+ if mask_feature and not embeds_initially_provided:
+ prompt_embeds = prompt_embeds.unsqueeze(1)
+ masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
+ masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
+ masked_negative_prompt_embeds = (
+ negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
+ )
+
+ # import torch.nn.functional as F
+
+ # padding = (0, 0, 0, 113) # (左, 右, 下, 上)
+ # masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
+ # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
+
+ # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
+
+ return masked_prompt_embeds, masked_negative_prompt_embeds
+ # return masked_prompt_embeds_, masked_negative_prompt_embeds_
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ video_length,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt: str = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ mask_feature: bool = True,
+ enable_temporal_attentions: bool = True,
+ verbose: bool = True,
+ ) -> Union[VideoSysPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Latte can only generate video of 16 frames 512x512.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+ enable_temporal_attentions (`bool`, defaults to `True`):
+ If `True`, the model will use temporal attentions to generate the video.
+ verbose (`bool`, *optional*, defaults to `True`):
+ Whether to print progress bars and other information during inference.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ video_length = 16
+ height = 512
+ width = 512
+ update_steps(num_inference_steps)
+ self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self.text_encoder.device or self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ mask_feature=mask_feature,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ # timesteps = self.scheduler.timesteps # NOTE change timestep_respacing here
+
+ if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
+ # TODO add assertion for timestep_respacing
+ # timestep_respacing = get_diffusion_skip_timestep()
+ # timesteps = space_timesteps(1000, timestep_respacing)
+
+ diffusion_skip_timestep = get_diffusion_skip_timestep()
+ timesteps = skip_diffusion_timestep(self.scheduler.timesteps, diffusion_skip_timestep)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ orignal_timesteps = self.scheduler.timesteps
+
+ if verbose and dist.get_rank() == 0:
+ print("============================")
+ print("skip diffusion steps!!!")
+ print("============================")
+ print(f"orignal sample timesteps: {orignal_timesteps}")
+ print(f"orignal diffusion steps: {len(orignal_timesteps)}")
+ print("============================")
+ print(f"skip diffusion steps: {get_diffusion_skip_timestep()}")
+ print(f"sample timesteps: {timesteps}")
+ print(f"num_inference_steps: {len(timesteps)}")
+ print("============================")
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ video_length,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
+ for i, t in progress_wrap(list(enumerate(timesteps))):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ all_timesteps=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ enable_temporal_attentions=enable_temporal_attentions,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latents":
+ if latents.shape[2] == 1: # image
+ video = self.decode_latents_image(latents)
+ else: # video
+ if self._config.enable_vae_temporal_decoder:
+ video = self.decode_latents_with_temporal_decoder(latents)
+ else:
+ video = self.decode_latents(latents)
+ else:
+ video = latents
+ return VideoSysPipelineOutput(video=video)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return VideoSysPipelineOutput(video=video)
+
+ def decode_latents_image(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for frame_idx in range(latents.shape[0]):
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
+ video = torch.cat(video)
+ video = einops.rearrange(video, "(b f) c h w -> b f c h w", f=video_length)
+ video = (video / 2.0 + 0.5).clamp(0, 1)
+ return video
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for frame_idx in range(latents.shape[0]):
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
+ video = torch.cat(video)
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ return video
+
+ def decode_latents_with_temporal_decoder(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+
+ decode_chunk_size = 14
+ for frame_idx in range(0, latents.shape[0], decode_chunk_size):
+ num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]
+
+ decode_kwargs = {}
+ decode_kwargs["num_frames"] = num_frames_in
+
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + decode_chunk_size], **decode_kwargs).sample)
+
+ video = torch.cat(video)
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ return video
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=8)
diff --git a/videosys/models/open_sora/__init__.py b/videosys/models/open_sora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8d92196e09ed1e3707a96162b26e40c751a6d4a
--- /dev/null
+++ b/videosys/models/open_sora/__init__.py
@@ -0,0 +1,7 @@
+from .pipeline import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
+
+__all__ = [
+ "OpenSoraConfig",
+ "OpenSoraPABConfig",
+ "OpenSoraPipeline",
+]
diff --git a/videosys/models/open_sora/datasets.py b/videosys/models/open_sora/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75a711cfb55ff7b3802b7a596101712dc266de5
--- /dev/null
+++ b/videosys/models/open_sora/datasets.py
@@ -0,0 +1,788 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+
+import numbers
+import os
+import re
+
+import numpy as np
+import requests
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from PIL import Image
+from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
+from torchvision.io import write_video
+from torchvision.utils import save_image
+
+IMG_FPS = 120
+VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
+
+regex = re.compile(
+ r"^(?:http|ftp)s?://" # http:// or https://
+ r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
+ r"localhost|" # localhost...
+ r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
+ r"(?::\d+)?" # optional port
+ r"(?:/?|[/?]\S+)$",
+ re.IGNORECASE,
+)
+
+# H:W
+ASPECT_RATIO_MAP = {
+ "3:8": "0.38",
+ "9:21": "0.43",
+ "12:25": "0.48",
+ "1:2": "0.50",
+ "9:17": "0.53",
+ "27:50": "0.54",
+ "9:16": "0.56",
+ "5:8": "0.62",
+ "2:3": "0.67",
+ "3:4": "0.75",
+ "1:1": "1.00",
+ "4:3": "1.33",
+ "3:2": "1.50",
+ "16:9": "1.78",
+ "17:9": "1.89",
+ "2:1": "2.00",
+ "50:27": "2.08",
+}
+
+
+# computed from above code
+# S = 8294400
+ASPECT_RATIO_4K = {
+ "0.38": (1764, 4704),
+ "0.43": (1886, 4400),
+ "0.48": (1996, 4158),
+ "0.50": (2036, 4072),
+ "0.53": (2096, 3960),
+ "0.54": (2118, 3918),
+ "0.62": (2276, 3642),
+ "0.56": (2160, 3840), # base
+ "0.67": (2352, 3528),
+ "0.75": (2494, 3326),
+ "1.00": (2880, 2880),
+ "1.33": (3326, 2494),
+ "1.50": (3528, 2352),
+ "1.78": (3840, 2160),
+ "1.89": (3958, 2096),
+ "2.00": (4072, 2036),
+ "2.08": (4156, 1994),
+}
+
+# S = 3686400
+ASPECT_RATIO_2K = {
+ "0.38": (1176, 3136),
+ "0.43": (1256, 2930),
+ "0.48": (1330, 2770),
+ "0.50": (1358, 2716),
+ "0.53": (1398, 2640),
+ "0.54": (1412, 2612),
+ "0.56": (1440, 2560), # base
+ "0.62": (1518, 2428),
+ "0.67": (1568, 2352),
+ "0.75": (1662, 2216),
+ "1.00": (1920, 1920),
+ "1.33": (2218, 1664),
+ "1.50": (2352, 1568),
+ "1.78": (2560, 1440),
+ "1.89": (2638, 1396),
+ "2.00": (2716, 1358),
+ "2.08": (2772, 1330),
+}
+
+# S = 2073600
+ASPECT_RATIO_1080P = {
+ "0.38": (882, 2352),
+ "0.43": (942, 2198),
+ "0.48": (998, 2080),
+ "0.50": (1018, 2036),
+ "0.53": (1048, 1980),
+ "0.54": (1058, 1958),
+ "0.56": (1080, 1920), # base
+ "0.62": (1138, 1820),
+ "0.67": (1176, 1764),
+ "0.75": (1248, 1664),
+ "1.00": (1440, 1440),
+ "1.33": (1662, 1246),
+ "1.50": (1764, 1176),
+ "1.78": (1920, 1080),
+ "1.89": (1980, 1048),
+ "2.00": (2036, 1018),
+ "2.08": (2078, 998),
+}
+
+# S = 921600
+ASPECT_RATIO_720P = {
+ "0.38": (588, 1568),
+ "0.43": (628, 1466),
+ "0.48": (666, 1388),
+ "0.50": (678, 1356),
+ "0.53": (698, 1318),
+ "0.54": (706, 1306),
+ "0.56": (720, 1280), # base
+ "0.62": (758, 1212),
+ "0.67": (784, 1176),
+ "0.75": (832, 1110),
+ "1.00": (960, 960),
+ "1.33": (1108, 832),
+ "1.50": (1176, 784),
+ "1.78": (1280, 720),
+ "1.89": (1320, 698),
+ "2.00": (1358, 680),
+ "2.08": (1386, 666),
+}
+
+# S = 409920
+ASPECT_RATIO_480P = {
+ "0.38": (392, 1046),
+ "0.43": (420, 980),
+ "0.48": (444, 925),
+ "0.50": (452, 904),
+ "0.53": (466, 880),
+ "0.54": (470, 870),
+ "0.56": (480, 854), # base
+ "0.62": (506, 810),
+ "0.67": (522, 784),
+ "0.75": (554, 738),
+ "1.00": (640, 640),
+ "1.33": (740, 555),
+ "1.50": (784, 522),
+ "1.78": (854, 480),
+ "1.89": (880, 466),
+ "2.00": (906, 454),
+ "2.08": (924, 444),
+}
+
+# S = 230400
+ASPECT_RATIO_360P = {
+ "0.38": (294, 784),
+ "0.43": (314, 732),
+ "0.48": (332, 692),
+ "0.50": (340, 680),
+ "0.53": (350, 662),
+ "0.54": (352, 652),
+ "0.56": (360, 640), # base
+ "0.62": (380, 608),
+ "0.67": (392, 588),
+ "0.75": (416, 554),
+ "1.00": (480, 480),
+ "1.33": (554, 416),
+ "1.50": (588, 392),
+ "1.78": (640, 360),
+ "1.89": (660, 350),
+ "2.00": (678, 340),
+ "2.08": (692, 332),
+}
+
+# S = 102240
+ASPECT_RATIO_240P = {
+ "0.38": (196, 522),
+ "0.43": (210, 490),
+ "0.48": (222, 462),
+ "0.50": (226, 452),
+ "0.53": (232, 438),
+ "0.54": (236, 436),
+ "0.56": (240, 426), # base
+ "0.62": (252, 404),
+ "0.67": (262, 393),
+ "0.75": (276, 368),
+ "1.00": (320, 320),
+ "1.33": (370, 278),
+ "1.50": (392, 262),
+ "1.78": (426, 240),
+ "1.89": (440, 232),
+ "2.00": (452, 226),
+ "2.08": (462, 222),
+}
+
+# S = 36864
+ASPECT_RATIO_144P = {
+ "0.38": (117, 312),
+ "0.43": (125, 291),
+ "0.48": (133, 277),
+ "0.50": (135, 270),
+ "0.53": (139, 262),
+ "0.54": (141, 260),
+ "0.56": (144, 256), # base
+ "0.62": (151, 241),
+ "0.67": (156, 234),
+ "0.75": (166, 221),
+ "1.00": (192, 192),
+ "1.33": (221, 165),
+ "1.50": (235, 156),
+ "1.78": (256, 144),
+ "1.89": (263, 139),
+ "2.00": (271, 135),
+ "2.08": (277, 132),
+}
+
+# from PixArt
+# S = 8294400
+ASPECT_RATIO_2880 = {
+ "0.25": (1408, 5760),
+ "0.26": (1408, 5568),
+ "0.27": (1408, 5376),
+ "0.28": (1408, 5184),
+ "0.32": (1600, 4992),
+ "0.33": (1600, 4800),
+ "0.34": (1600, 4672),
+ "0.4": (1792, 4480),
+ "0.42": (1792, 4288),
+ "0.47": (1920, 4096),
+ "0.49": (1920, 3904),
+ "0.51": (1920, 3776),
+ "0.55": (2112, 3840),
+ "0.59": (2112, 3584),
+ "0.68": (2304, 3392),
+ "0.72": (2304, 3200),
+ "0.78": (2496, 3200),
+ "0.83": (2496, 3008),
+ "0.89": (2688, 3008),
+ "0.93": (2688, 2880),
+ "1.0": (2880, 2880),
+ "1.07": (2880, 2688),
+ "1.12": (3008, 2688),
+ "1.21": (3008, 2496),
+ "1.28": (3200, 2496),
+ "1.39": (3200, 2304),
+ "1.47": (3392, 2304),
+ "1.7": (3584, 2112),
+ "1.82": (3840, 2112),
+ "2.03": (3904, 1920),
+ "2.13": (4096, 1920),
+ "2.39": (4288, 1792),
+ "2.5": (4480, 1792),
+ "2.92": (4672, 1600),
+ "3.0": (4800, 1600),
+ "3.12": (4992, 1600),
+ "3.68": (5184, 1408),
+ "3.82": (5376, 1408),
+ "3.95": (5568, 1408),
+ "4.0": (5760, 1408),
+}
+
+# S = 4194304
+ASPECT_RATIO_2048 = {
+ "0.25": (1024, 4096),
+ "0.26": (1024, 3968),
+ "0.27": (1024, 3840),
+ "0.28": (1024, 3712),
+ "0.32": (1152, 3584),
+ "0.33": (1152, 3456),
+ "0.35": (1152, 3328),
+ "0.4": (1280, 3200),
+ "0.42": (1280, 3072),
+ "0.48": (1408, 2944),
+ "0.5": (1408, 2816),
+ "0.52": (1408, 2688),
+ "0.57": (1536, 2688),
+ "0.6": (1536, 2560),
+ "0.68": (1664, 2432),
+ "0.72": (1664, 2304),
+ "0.78": (1792, 2304),
+ "0.82": (1792, 2176),
+ "0.88": (1920, 2176),
+ "0.94": (1920, 2048),
+ "1.0": (2048, 2048),
+ "1.07": (2048, 1920),
+ "1.13": (2176, 1920),
+ "1.21": (2176, 1792),
+ "1.29": (2304, 1792),
+ "1.38": (2304, 1664),
+ "1.46": (2432, 1664),
+ "1.67": (2560, 1536),
+ "1.75": (2688, 1536),
+ "2.0": (2816, 1408),
+ "2.09": (2944, 1408),
+ "2.4": (3072, 1280),
+ "2.5": (3200, 1280),
+ "2.89": (3328, 1152),
+ "3.0": (3456, 1152),
+ "3.11": (3584, 1152),
+ "3.62": (3712, 1024),
+ "3.75": (3840, 1024),
+ "3.88": (3968, 1024),
+ "4.0": (4096, 1024),
+}
+
+# S = 1048576
+ASPECT_RATIO_1024 = {
+ "0.25": (512, 2048),
+ "0.26": (512, 1984),
+ "0.27": (512, 1920),
+ "0.28": (512, 1856),
+ "0.32": (576, 1792),
+ "0.33": (576, 1728),
+ "0.35": (576, 1664),
+ "0.4": (640, 1600),
+ "0.42": (640, 1536),
+ "0.48": (704, 1472),
+ "0.5": (704, 1408),
+ "0.52": (704, 1344),
+ "0.57": (768, 1344),
+ "0.6": (768, 1280),
+ "0.68": (832, 1216),
+ "0.72": (832, 1152),
+ "0.78": (896, 1152),
+ "0.82": (896, 1088),
+ "0.88": (960, 1088),
+ "0.94": (960, 1024),
+ "1.0": (1024, 1024),
+ "1.07": (1024, 960),
+ "1.13": (1088, 960),
+ "1.21": (1088, 896),
+ "1.29": (1152, 896),
+ "1.38": (1152, 832),
+ "1.46": (1216, 832),
+ "1.67": (1280, 768),
+ "1.75": (1344, 768),
+ "2.0": (1408, 704),
+ "2.09": (1472, 704),
+ "2.4": (1536, 640),
+ "2.5": (1600, 640),
+ "2.89": (1664, 576),
+ "3.0": (1728, 576),
+ "3.11": (1792, 576),
+ "3.62": (1856, 512),
+ "3.75": (1920, 512),
+ "3.88": (1984, 512),
+ "4.0": (2048, 512),
+}
+
+# S = 262144
+ASPECT_RATIO_512 = {
+ "0.25": (256, 1024),
+ "0.26": (256, 992),
+ "0.27": (256, 960),
+ "0.28": (256, 928),
+ "0.32": (288, 896),
+ "0.33": (288, 864),
+ "0.35": (288, 832),
+ "0.4": (320, 800),
+ "0.42": (320, 768),
+ "0.48": (352, 736),
+ "0.5": (352, 704),
+ "0.52": (352, 672),
+ "0.57": (384, 672),
+ "0.6": (384, 640),
+ "0.68": (416, 608),
+ "0.72": (416, 576),
+ "0.78": (448, 576),
+ "0.82": (448, 544),
+ "0.88": (480, 544),
+ "0.94": (480, 512),
+ "1.0": (512, 512),
+ "1.07": (512, 480),
+ "1.13": (544, 480),
+ "1.21": (544, 448),
+ "1.29": (576, 448),
+ "1.38": (576, 416),
+ "1.46": (608, 416),
+ "1.67": (640, 384),
+ "1.75": (672, 384),
+ "2.0": (704, 352),
+ "2.09": (736, 352),
+ "2.4": (768, 320),
+ "2.5": (800, 320),
+ "2.89": (832, 288),
+ "3.0": (864, 288),
+ "3.11": (896, 288),
+ "3.62": (928, 256),
+ "3.75": (960, 256),
+ "3.88": (992, 256),
+ "4.0": (1024, 256),
+}
+
+# S = 65536
+ASPECT_RATIO_256 = {
+ "0.25": (128, 512),
+ "0.26": (128, 496),
+ "0.27": (128, 480),
+ "0.28": (128, 464),
+ "0.32": (144, 448),
+ "0.33": (144, 432),
+ "0.35": (144, 416),
+ "0.4": (160, 400),
+ "0.42": (160, 384),
+ "0.48": (176, 368),
+ "0.5": (176, 352),
+ "0.52": (176, 336),
+ "0.57": (192, 336),
+ "0.6": (192, 320),
+ "0.68": (208, 304),
+ "0.72": (208, 288),
+ "0.78": (224, 288),
+ "0.82": (224, 272),
+ "0.88": (240, 272),
+ "0.94": (240, 256),
+ "1.0": (256, 256),
+ "1.07": (256, 240),
+ "1.13": (272, 240),
+ "1.21": (272, 224),
+ "1.29": (288, 224),
+ "1.38": (288, 208),
+ "1.46": (304, 208),
+ "1.67": (320, 192),
+ "1.75": (336, 192),
+ "2.0": (352, 176),
+ "2.09": (368, 176),
+ "2.4": (384, 160),
+ "2.5": (400, 160),
+ "2.89": (416, 144),
+ "3.0": (432, 144),
+ "3.11": (448, 144),
+ "3.62": (464, 128),
+ "3.75": (480, 128),
+ "3.88": (496, 128),
+ "4.0": (512, 128),
+}
+
+
+def get_closest_ratio(height: float, width: float, ratios: dict):
+ aspect_ratio = height / width
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
+ return closest_ratio
+
+
+ASPECT_RATIOS = {
+ "144p": (36864, ASPECT_RATIO_144P),
+ "256": (65536, ASPECT_RATIO_256),
+ "240p": (102240, ASPECT_RATIO_240P),
+ "360p": (230400, ASPECT_RATIO_360P),
+ "512": (262144, ASPECT_RATIO_512),
+ "480p": (409920, ASPECT_RATIO_480P),
+ "720p": (921600, ASPECT_RATIO_720P),
+ "1024": (1048576, ASPECT_RATIO_1024),
+ "1080p": (2073600, ASPECT_RATIO_1080P),
+ "2k": (3686400, ASPECT_RATIO_2K),
+ "2048": (4194304, ASPECT_RATIO_2048),
+ "2880": (8294400, ASPECT_RATIO_2880),
+ "4k": (8294400, ASPECT_RATIO_4K),
+}
+
+
+def get_image_size(resolution, ar_ratio):
+ ar_key = ASPECT_RATIO_MAP[ar_ratio]
+ rs_dict = ASPECT_RATIOS[resolution][1]
+ assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
+ return rs_dict[ar_key]
+
+
+NUM_FRAMES_MAP = {
+ "1x": 51,
+ "2x": 102,
+ "4x": 204,
+ "8x": 408,
+ "16x": 816,
+ "2s": 51,
+ "4s": 102,
+ "8s": 204,
+ "16s": 408,
+ "32s": 816,
+}
+
+
+def get_num_frames(num_frames):
+ if num_frames in NUM_FRAMES_MAP:
+ return NUM_FRAMES_MAP[num_frames]
+ else:
+ return int(num_frames)
+
+
+def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
+ """
+ Args:
+ x (Tensor): shape [C, T, H, W]
+ """
+ assert x.ndim == 4
+
+ if not force_video and x.shape[1] == 1: # T = 1: save as image
+ save_path += ".png"
+ x = x.squeeze(1)
+ save_image([x], save_path, normalize=normalize, value_range=value_range)
+ else:
+ save_path += ".mp4"
+ if normalize:
+ low, high = value_range
+ x.clamp_(min=low, max=high)
+ x.sub_(low).div_(max(high - low, 1e-5))
+
+ x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
+ write_video(save_path, x, fps=fps, video_codec="h264")
+ if verbose:
+ print(f"Saved to {save_path}")
+ return save_path
+
+
+def is_url(url):
+ return re.match(regex, url) is not None
+
+
+def download_url(input_path):
+ output_dir = "cache"
+ os.makedirs(output_dir, exist_ok=True)
+ base_name = os.path.basename(input_path)
+ output_path = os.path.join(output_dir, base_name)
+ img_data = requests.get(input_path).content
+ with open(output_path, "wb") as handler:
+ handler.write(img_data)
+ print(f"URL {input_path} downloaded to {output_path}")
+ return output_path
+
+
+def get_transforms_video(name="center", image_size=(256, 256)):
+ if name is None:
+ return None
+ elif name == "center":
+ assert image_size[0] == image_size[1], "image_size must be square for center crop"
+ transform_video = transforms.Compose(
+ [
+ ToTensorVideo(), # TCHW
+ # video_transforms.RandomHorizontalFlipVideo(),
+ UCFCenterCropVideo(image_size[0]),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+ elif name == "resize_crop":
+ transform_video = transforms.Compose(
+ [
+ ToTensorVideo(), # TCHW
+ ResizeCrop(image_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+ else:
+ raise NotImplementedError(f"Transform {name} not implemented")
+ return transform_video
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ th, tw = crop_size
+ if h < th or w < tw:
+ raise ValueError("height and width must be no smaller than crop_size")
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def resize_scale(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ H, W = clip.size(-2), clip.size(-1)
+ scale_ = target_size[0] / min(H, W)
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
+
+
+class UCFCenterCropVideo:
+ """
+ First scale to the specified size in equal proportion to the short edge,
+ then center cropping
+ """
+
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: scale resized / center cropped video clip.
+ size is (T, C, crop_size, crop_size)
+ """
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ clip_center_crop = center_crop(clip_resize, self.size)
+ return clip_center_crop
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
+ return clip.float() / 255.0
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ return to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class ResizeCrop:
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, clip):
+ clip = resize_crop_to_fill(clip, self.size)
+ return clip
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size})"
+
+
+def get_transforms_image(name="center", image_size=(256, 256)):
+ if name is None:
+ return None
+ elif name == "center":
+ assert image_size[0] == image_size[1], "Image size must be square for center crop"
+ transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
+ # transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+ elif name == "resize_crop":
+ transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+ else:
+ raise NotImplementedError(f"Transform {name} not implemented")
+ return transform
+
+
+def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
+
+
+def resize_crop_to_fill(pil_image, image_size):
+ w, h = pil_image.size # PIL is (W, H)
+ th, tw = image_size
+ rh, rw = th / h, tw / w
+ if rh > rw:
+ sh, sw = th, round(w * rh)
+ image = pil_image.resize((sw, sh), Image.BICUBIC)
+ i = 0
+ j = int(round((sw - tw) / 2.0))
+ else:
+ sh, sw = round(h * rw), tw
+ image = pil_image.resize((sw, sh), Image.BICUBIC)
+ i = int(round((sh - th) / 2.0))
+ j = 0
+ arr = np.array(image)
+ assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
+ return Image.fromarray(arr[i : i + th, j : j + tw])
+
+
+def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+ if transform is None:
+ transform = get_transforms_video(image_size=image_size, name=transform_name)
+ video = transform(vframes) # T C H W
+ video = video.permute(1, 0, 2, 3)
+ return video
+
+
+def read_from_path(path, image_size, transform_name="center"):
+ if is_url(path):
+ path = download_url(path)
+ ext = os.path.splitext(path)[-1].lower()
+ if ext.lower() in VID_EXTENSIONS:
+ return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
+ else:
+ assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+ return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
+
+
+def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
+ image = pil_loader(path)
+ if transform is None:
+ transform = get_transforms_image(image_size=image_size, name=transform_name)
+ image = transform(image)
+ video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
+ video = video.permute(1, 0, 2, 3)
+ return video
diff --git a/videosys/models/open_sora/embed.py b/videosys/models/open_sora/embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c238bc714f0a8ffb62edb36743e488060d120a
--- /dev/null
+++ b/videosys/models/open_sora/embed.py
@@ -0,0 +1,585 @@
+# Adapted from OpenSora and DiT
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DiT: https://github.com/facebookresearch/DiT
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import html
+import math
+import re
+
+import ftfy
+import numpy
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import transformers
+from timm.models.vision_transformer import Mlp
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from videosys.modules.embed import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
+
+transformers.logging.set_verbosity_error()
+
+
+# ===============================================
+# Text Embed
+# ===============================================
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
+ def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(path)
+ self.transformer = CLIPTextModel.from_pretrained(path)
+ self.device = device
+ self.max_length = max_length
+ self._freeze()
+
+ def _freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ pooled_z = outputs.pooler_output
+ return z, pooled_z
+
+ def encode(self, text):
+ return self(text)
+
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+
+ def __init__(self, path, hidden_size, dropout_prob=0.1):
+ super().__init__()
+ self.text_encoder = FrozenCLIPEmbedder(path=path)
+ self.dropout_prob = dropout_prob
+
+ output_dim = self.text_encoder.transformer.config.hidden_size
+ self.output_projection = nn.Linear(output_dim, hidden_size)
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ # TODO
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings, pooled_embeddings = self.text_encoder(text_prompts)
+ # return embeddings, pooled_embeddings
+ text_embeddings = self.output_projection(pooled_embeddings)
+ return text_embeddings
+
+
+class CaptionEmbedder(nn.Module):
+ """
+ copied from https://github.com/hpcaitech/Open-Sora
+
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
+ super().__init__()
+
+ self.y_proj = Mlp(
+ in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
+ )
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
+ self.uncond_prob = uncond_prob
+
+ def token_drop(self, caption, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
+ return caption
+
+ def forward(self, caption, train, force_drop_ids=None):
+ if train:
+ assert caption.shape[2:] == self.y_embedding.shape
+ use_dropout = self.uncond_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ caption = self.token_drop(caption, force_drop_ids)
+ caption = self.y_proj(caption)
+ return caption
+
+
+class T5Embedder:
+ available_models = ["DeepFloyd/t5-v1_1-xxl"]
+
+ def __init__(
+ self,
+ device,
+ from_pretrained=None,
+ *,
+ cache_dir=None,
+ hf_token=None,
+ use_text_preprocessing=True,
+ t5_model_kwargs=None,
+ torch_dtype=None,
+ use_offload_folder=None,
+ model_max_length=120,
+ local_files_only=False,
+ ):
+ self.device = torch.device(device)
+ self.torch_dtype = torch_dtype or torch.bfloat16
+ self.cache_dir = cache_dir
+
+ if t5_model_kwargs is None:
+ t5_model_kwargs = {
+ "low_cpu_mem_usage": True,
+ "torch_dtype": self.torch_dtype,
+ }
+
+ if use_offload_folder is not None:
+ t5_model_kwargs["offload_folder"] = use_offload_folder
+ t5_model_kwargs["device_map"] = {
+ "shared": self.device,
+ "encoder.embed_tokens": self.device,
+ "encoder.block.0": self.device,
+ "encoder.block.1": self.device,
+ "encoder.block.2": self.device,
+ "encoder.block.3": self.device,
+ "encoder.block.4": self.device,
+ "encoder.block.5": self.device,
+ "encoder.block.6": self.device,
+ "encoder.block.7": self.device,
+ "encoder.block.8": self.device,
+ "encoder.block.9": self.device,
+ "encoder.block.10": self.device,
+ "encoder.block.11": self.device,
+ "encoder.block.12": "disk",
+ "encoder.block.13": "disk",
+ "encoder.block.14": "disk",
+ "encoder.block.15": "disk",
+ "encoder.block.16": "disk",
+ "encoder.block.17": "disk",
+ "encoder.block.18": "disk",
+ "encoder.block.19": "disk",
+ "encoder.block.20": "disk",
+ "encoder.block.21": "disk",
+ "encoder.block.22": "disk",
+ "encoder.block.23": "disk",
+ "encoder.final_layer_norm": "disk",
+ "encoder.dropout": "disk",
+ }
+ else:
+ t5_model_kwargs["device_map"] = {
+ "shared": self.device,
+ "encoder": self.device,
+ }
+
+ self.use_text_preprocessing = use_text_preprocessing
+ self.hf_token = hf_token
+
+ assert from_pretrained in self.available_models
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ from_pretrained,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ )
+ self.model = T5EncoderModel.from_pretrained(
+ from_pretrained,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ **t5_model_kwargs,
+ ).eval()
+ self.model_max_length = model_max_length
+
+ def get_text_embeddings(self, texts):
+ text_tokens_and_mask = self.tokenizer(
+ texts,
+ max_length=self.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ input_ids = text_tokens_and_mask["input_ids"].to(self.device)
+ attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
+ with torch.no_grad():
+ text_encoder_embs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["last_hidden_state"].detach()
+ return text_encoder_embs, attention_mask
+
+
+class T5Encoder:
+ def __init__(
+ self,
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
+ model_max_length=120,
+ device="cuda",
+ dtype=torch.float,
+ shardformer=False,
+ ):
+ assert from_pretrained is not None, "Please specify the path to the T5 model"
+
+ self.t5 = T5Embedder(
+ device=device,
+ torch_dtype=dtype,
+ from_pretrained=from_pretrained,
+ model_max_length=model_max_length,
+ )
+ self.t5.model.to(dtype=dtype)
+ self.y_embedder = None
+
+ self.model_max_length = model_max_length
+ self.output_dim = self.t5.model.config.d_model
+
+ if shardformer:
+ self.shardformer_t5()
+
+ def shardformer_t5(self):
+ from colossalai.shardformer import ShardConfig, ShardFormer
+
+ from videosys.core.shardformer.t5.policy import T5EncoderPolicy
+ from videosys.utils.utils import requires_grad
+
+ shard_config = ShardConfig(
+ tensor_parallel_process_group=None,
+ pipeline_stage_manager=None,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=True,
+ enable_sequence_parallelism=False,
+ enable_sequence_overlap=False,
+ )
+ shard_former = ShardFormer(shard_config=shard_config)
+ optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
+ self.t5.model = optim_model.half()
+
+ # ensure the weights are frozen
+ requires_grad(self.t5.model, False)
+
+ def encode(self, text):
+ caption_embs, emb_masks = self.t5.get_text_embeddings(text)
+ caption_embs = caption_embs[:, None]
+ return dict(y=caption_embs, mask=emb_masks)
+
+ def null(self, n):
+ null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
+ return null_y
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+BAD_PUNCT_REGEX = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+) # noqa
+
+
+def clean_caption(caption):
+ import urllib.parse as ul
+
+ from bs4 import BeautifulSoup
+
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = basic_clean(caption)
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+
+def text_preprocessing(text, use_text_preprocessing: bool = True):
+ if use_text_preprocessing:
+ # The exact text cleaning as was in the training stage:
+ text = clean_caption(text)
+ text = clean_caption(text)
+ return text
+ else:
+ return text.lower().strip()
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
+ freqs = freqs.to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, dtype):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ if t_freq.dtype != dtype:
+ t_freq = t_freq.to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+# ===============================================
+# Sine/Cosine Positional Embedding Functions
+# ===============================================
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if not isinstance(grid_size, tuple):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
+ if base_size is not None:
+ grid_h *= base_size / grid_size[0]
+ grid_w *= base_size / grid_size[1]
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
+ pos = np.arange(0, length)[..., None] / scale
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+
+
+# ===============================================
+# Patch Embed
+# ===============================================
+
+
+class PatchEmbed3D(nn.Module):
+ """Video to Patch Embedding.
+
+ Args:
+ patch_size (int): Patch token size. Default: (2,4,4).
+ in_chans (int): Number of input video channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(
+ self,
+ patch_size=(2, 4, 4),
+ in_chans=3,
+ embed_dim=96,
+ norm_layer=None,
+ flatten=True,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.flatten = flatten
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, D, H, W = x.size()
+ if W % self.patch_size[2] != 0:
+ x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+ if H % self.patch_size[1] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+ if D % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+ x = self.proj(x) # (B C T H W)
+ if self.norm is not None:
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
+ return x
diff --git a/videosys/models/open_sora/inference_utils.py b/videosys/models/open_sora/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de95fcf717902b4f2c432bc3302829cf719ec980
--- /dev/null
+++ b/videosys/models/open_sora/inference_utils.py
@@ -0,0 +1,348 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import json
+import os
+import re
+
+import torch
+
+from .datasets import IMG_FPS, read_from_path
+
+
+def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
+ if info_type is None:
+ return dict()
+ elif info_type == "PixArtMS":
+ hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
+ ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
+ return dict(ar=ar, hw=hw)
+ elif info_type in ["STDiT2", "OpenSora"]:
+ fps = fps if num_frames > 1 else IMG_FPS
+ fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
+ height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
+ width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
+ num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
+ ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
+ return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
+ else:
+ raise NotImplementedError
+
+
+def load_prompts(prompt_path, start_idx=None, end_idx=None):
+ with open(prompt_path, "r") as f:
+ prompts = [line.strip() for line in f.readlines()]
+ prompts = prompts[start_idx:end_idx]
+ return prompts
+
+
+def get_save_path_name(
+ save_dir,
+ sample_name=None, # prefix
+ sample_idx=None, # sample index
+ prompt=None, # used prompt
+ prompt_as_path=False, # use prompt as path
+ num_sample=1, # number of samples to generate for one prompt
+ k=None, # kth sample
+):
+ if sample_name is None:
+ sample_name = "" if prompt_as_path else "sample"
+ sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
+ save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix[:50]}")
+ if num_sample != 1:
+ save_path = f"{save_path}-{k}"
+ return save_path
+
+
+def get_eval_save_path_name(
+ save_dir,
+ id, # add id parameter
+ sample_name=None, # prefix
+ sample_idx=None, # sample index
+ prompt=None, # used prompt
+ prompt_as_path=False, # use prompt as path
+ num_sample=1, # number of samples to generate for one prompt
+ k=None, # kth sample
+):
+ if sample_name is None:
+ sample_name = "" if prompt_as_path else "sample"
+ save_path = os.path.join(save_dir, f"{id}")
+ if num_sample != 1:
+ save_path = f"{save_path}-{k}"
+ return save_path
+
+
+def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
+ new_prompts = []
+ for prompt in prompts:
+ new_prompt = prompt
+ if aes is not None and "aesthetic score:" not in prompt:
+ new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
+ if flow is not None and "motion score:" not in prompt:
+ new_prompt = f"{new_prompt} motion score: {flow:.1f}."
+ if camera_motion is not None and "camera motion:" not in prompt:
+ new_prompt = f"{new_prompt} camera motion: {camera_motion}."
+ new_prompts.append(new_prompt)
+ return new_prompts
+
+
+def extract_json_from_prompts(prompts, reference, mask_strategy):
+ ret_prompts = []
+ for i, prompt in enumerate(prompts):
+ parts = re.split(r"(?=[{])", prompt)
+ assert len(parts) <= 2, f"Invalid prompt: {prompt}"
+ ret_prompts.append(parts[0])
+ if len(parts) > 1:
+ additional_info = json.loads(parts[1])
+ for key in additional_info:
+ assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
+ if key == "reference_path":
+ reference[i] = additional_info[key]
+ elif key == "mask_strategy":
+ mask_strategy[i] = additional_info[key]
+ return ret_prompts, reference, mask_strategy
+
+
+def collect_references_batch(reference_paths, vae, image_size):
+ refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
+ for reference_path in reference_paths:
+ if reference_path == "":
+ refs_x.append([])
+ continue
+ ref_path = reference_path.split(";")
+ ref = []
+ for r_path in ref_path:
+ r = read_from_path(r_path, image_size, transform_name="resize_crop")
+ r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
+ r_x = r_x.squeeze(0)
+ ref.append(r_x)
+ refs_x.append(ref)
+ return refs_x
+
+
+def extract_prompts_loop(prompts, num_loop):
+ ret_prompts = []
+ for prompt in prompts:
+ if prompt.startswith("|0|"):
+ prompt_list = prompt.split("|")[1:]
+ text_list = []
+ for i in range(0, len(prompt_list), 2):
+ start_loop = int(prompt_list[i])
+ text = prompt_list[i + 1]
+ end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1
+ text_list.extend([text] * (end_loop - start_loop))
+ prompt = text_list[num_loop]
+ ret_prompts.append(prompt)
+ return ret_prompts
+
+
+def split_prompt(prompt_text):
+ if prompt_text.startswith("|0|"):
+ # this is for prompts which look like
+ # |0| a beautiful day |1| a sunny day |2| a rainy day
+ # we want to parse it into a list of prompts with the loop index
+ prompt_list = prompt_text.split("|")[1:]
+ text_list = []
+ loop_idx = []
+ for i in range(0, len(prompt_list), 2):
+ start_loop = int(prompt_list[i])
+ text = prompt_list[i + 1].strip()
+ text_list.append(text)
+ loop_idx.append(start_loop)
+ return text_list, loop_idx
+ else:
+ return [prompt_text], None
+
+
+def merge_prompt(text_list, loop_idx_list=None):
+ if loop_idx_list is None:
+ return text_list[0]
+ else:
+ prompt = ""
+ for i, text in enumerate(text_list):
+ prompt += f"|{loop_idx_list[i]}|{text}"
+ return prompt
+
+
+MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
+
+
+def parse_mask_strategy(mask_strategy):
+ mask_batch = []
+ if mask_strategy == "" or mask_strategy is None:
+ return mask_batch
+
+ mask_strategy = mask_strategy.split(";")
+ for mask in mask_strategy:
+ mask_group = mask.split(",")
+ num_group = len(mask_group)
+ assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
+ mask_group.extend(MASK_DEFAULT[num_group:])
+ for i in range(5):
+ mask_group[i] = int(mask_group[i])
+ mask_group[5] = float(mask_group[5])
+ mask_batch.append(mask_group)
+ return mask_batch
+
+
+def find_nearest_point(value, point, max_value):
+ t = value // point
+ if value % point > point / 2 and t < max_value // point - 1:
+ t += 1
+ return t * point
+
+
+def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
+ masks = []
+ no_mask = True
+ for i, mask_strategy in enumerate(mask_strategys):
+ no_mask = False
+ mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
+ mask_strategy = parse_mask_strategy(mask_strategy)
+ for mst in mask_strategy:
+ loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
+ if loop_id != loop_i:
+ continue
+ ref = refs_x[i][m_id]
+
+ if m_ref_start < 0:
+ # ref: [C, T, H, W]
+ m_ref_start = ref.shape[1] + m_ref_start
+ if m_target_start < 0:
+ # z: [B, C, T, H, W]
+ m_target_start = z.shape[2] + m_target_start
+ if align is not None:
+ m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
+ m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
+ m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
+ z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
+ mask[m_target_start : m_target_start + m_length] = edit_ratio
+ masks.append(mask)
+ if no_mask:
+ return None
+ masks = torch.stack(masks)
+ return masks
+
+
+def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit):
+ ref_x = vae.encode(generated_video)
+ for j, refs in enumerate(refs_x):
+ if refs is None:
+ refs_x[j] = [ref_x[j]]
+ else:
+ refs.append(ref_x[j])
+ if mask_strategy[j] is None or mask_strategy[j] == "":
+ mask_strategy[j] = ""
+ else:
+ mask_strategy[j] += ";"
+ mask_strategy[
+ j
+ ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
+ return refs_x, mask_strategy
+
+
+def dframe_to_frame(num):
+ assert num % 5 == 0, f"Invalid num: {num}"
+ return num // 5 * 17
+
+
+OPENAI_CLIENT = None
+REFINE_PROMPTS = None
+REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
+REFINE_PROMPTS_TEMPLATE = """
+You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
+{}
+
+The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English.
+"""
+RANDOM_PROMPTS = None
+RANDOM_PROMPTS_TEMPLATE = """
+You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
+{}
+
+The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English.
+"""
+
+
+def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
+ global OPENAI_CLIENT
+ if OPENAI_CLIENT is None:
+ from openai import OpenAI
+
+ OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+ completion = OPENAI_CLIENT.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": sys_prompt,
+ }, # <-- This is the system message that provides context to the model
+ {
+ "role": "user",
+ "content": usr_prompt,
+ }, # <-- This is the user message for which the model will generate a response
+ ],
+ )
+
+ return completion.choices[0].message.content
+
+
+def get_random_prompt_by_openai():
+ global RANDOM_PROMPTS
+ if RANDOM_PROMPTS is None:
+ examples = load_prompts(REFINE_PROMPTS_PATH)
+ RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
+
+ response = get_openai_response(RANDOM_PROMPTS, "Generate one example.")
+ return response
+
+
+def refine_prompt_by_openai(prompt):
+ global REFINE_PROMPTS
+ if REFINE_PROMPTS is None:
+ examples = load_prompts(REFINE_PROMPTS_PATH)
+ REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
+
+ response = get_openai_response(REFINE_PROMPTS, prompt)
+ return response
+
+
+def has_openai_key():
+ return "OPENAI_API_KEY" in os.environ
+
+
+def refine_prompts_by_openai(prompts):
+ new_prompts = []
+ for prompt in prompts:
+ try:
+ if prompt.strip() == "":
+ new_prompt = get_random_prompt_by_openai()
+ print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}")
+ else:
+ new_prompt = refine_prompt_by_openai(prompt)
+ print(f"[Info] Refine prompt: {prompt} -> {new_prompt}")
+ new_prompts.append(new_prompt)
+ except Exception as e:
+ print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
+ new_prompts.append(prompt)
+ return new_prompts
+
+
+def add_watermark(
+ input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
+):
+ # execute this command in terminal with subprocess
+ # return if the process is successful
+ if output_video_path is None:
+ output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
+ cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
+ exit_code = os.system(cmd)
+ is_success = exit_code == 0
+ return is_success
diff --git a/videosys/models/open_sora/modules.py b/videosys/models/open_sora/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c127a6cdd5bca058b11b840248b8195f9e47713
--- /dev/null
+++ b/videosys/models/open_sora/modules.py
@@ -0,0 +1,450 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import functools
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from timm.models.vision_transformer import Mlp
+
+approx_gelu = lambda: nn.GELU(approximate="tanh")
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool):
+ return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
+
+
+def t2i_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+# ===============================================
+# General-purpose Layers
+# ===============================================
+
+
+class PatchEmbed3D(nn.Module):
+ """Video to Patch Embedding.
+
+ Args:
+ patch_size (int): Patch token size. Default: (2,4,4).
+ in_chans (int): Number of input video channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(
+ self,
+ patch_size=(2, 4, 4),
+ in_chans=3,
+ embed_dim=96,
+ norm_layer=None,
+ flatten=True,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.flatten = flatten
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, D, H, W = x.size()
+ if W % self.patch_size[2] != 0:
+ x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
+ if H % self.patch_size[1] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
+ if D % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
+
+ x = self.proj(x) # (B C T H W)
+ if self.norm is not None:
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = LlamaRMSNorm,
+ enable_flash_attn: bool = False,
+ rope=None,
+ qk_norm_legacy: bool = False,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.enable_flash_attn = enable_flash_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.qk_norm_legacy = qk_norm_legacy
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = False
+ if rope is not None:
+ self.rope = True
+ self.rotary_emb = rope
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ # flash attn is not memory efficient for small sequences, this is empirical
+ enable_flash_attn = self.enable_flash_attn and (N > B)
+ qkv = self.qkv(x)
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
+
+ qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ if self.qk_norm_legacy:
+ # WARNING: this may be a bug
+ if self.rope:
+ q = self.rotary_emb(q)
+ k = self.rotary_emb(k)
+ q, k = self.q_norm(q), self.k_norm(k)
+ else:
+ q, k = self.q_norm(q), self.k_norm(k)
+ if self.rope:
+ q = self.rotary_emb(q)
+ k = self.rotary_emb(k)
+
+ if enable_flash_attn:
+ from flash_attn import flash_attn_func
+
+ # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+ x = flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ softmax_scale=self.scale,
+ )
+ else:
+ dtype = q.dtype
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1) # translate attn to float32
+ attn = attn.to(torch.float32)
+ attn = attn.softmax(dim=-1)
+ attn = attn.to(dtype) # cast back attn to original dtype
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x_output_shape = (B, N, C)
+ if not enable_flash_attn:
+ x = x.transpose(1, 2)
+ x = x.reshape(x_output_shape)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MultiHeadCrossAttention(nn.Module):
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
+ super(MultiHeadCrossAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+ self.d_model = d_model
+ self.num_heads = num_heads
+ self.head_dim = d_model // num_heads
+
+ self.q_linear = nn.Linear(d_model, d_model)
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(d_model, d_model)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, cond, mask=None):
+ # query/value: img tokens; key: condition; mask: if padding tokens
+ B, N, C = x.shape
+
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
+ k, v = kv.unbind(2)
+
+ attn_bias = None
+ # TODO: support torch computation
+ import xformers.ops
+
+ if mask is not None:
+ attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+
+ x = x.view(B, -1, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class T2IFinalLayer(nn.Module):
+ """
+ The final layer of PixArt.
+ """
+
+ def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
+ self.out_channels = out_channels
+ self.d_t = d_t
+ self.d_s = d_s
+
+ def t_mask_select(self, x_mask, x, masked_x, T, S):
+ # x: [B, (T, S), C]
+ # mased_x: [B, (T, S), C]
+ # x_mask: [B, T]
+ x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
+ masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
+ x = torch.where(x_mask[:, :, None, None], x, masked_x)
+ x = rearrange(x, "B T S C -> B (T S) C")
+ return x
+
+ def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
+ if T is None:
+ T = self.d_t
+ if S is None:
+ S = self.d_s
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+ x = t2i_modulate(self.norm_final(x), shift, scale)
+ if x_mask is not None:
+ shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
+ x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
+ x = self.t_mask_select(x_mask, x, x_zero, T, S)
+ x = self.linear(x)
+ return x
+
+
+# ===============================================
+# Embedding Layers for Timesteps and Class Labels
+# ===============================================
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
+ freqs = freqs.to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, dtype):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ if t_freq.dtype != dtype:
+ t_freq = t_freq.to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class SizeEmbedder(TimestepEmbedder):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+ self.outdim = hidden_size
+
+ def forward(self, s, bs):
+ if s.ndim == 1:
+ s = s[:, None]
+ assert s.ndim == 2
+ if s.shape[0] != bs:
+ s = s.repeat(bs // s.shape[0], 1)
+ assert s.shape[0] == bs
+ b, dims = s.shape[0], s.shape[1]
+ s = rearrange(s, "b d -> (b d)")
+ s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
+ s_emb = self.mlp(s_freq)
+ s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+ return s_emb
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class CaptionEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ uncond_prob,
+ act_layer=nn.GELU(approximate="tanh"),
+ token_num=120,
+ ):
+ super().__init__()
+ self.y_proj = Mlp(
+ in_features=in_channels,
+ hidden_features=hidden_size,
+ out_features=hidden_size,
+ act_layer=act_layer,
+ drop=0,
+ )
+ self.register_buffer(
+ "y_embedding",
+ torch.randn(token_num, in_channels) / in_channels**0.5,
+ )
+ self.uncond_prob = uncond_prob
+
+ def token_drop(self, caption, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
+ return caption
+
+ def forward(self, caption, train, force_drop_ids=None):
+ if train:
+ assert caption.shape[2:] == self.y_embedding.shape
+ use_dropout = self.uncond_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ caption = self.token_drop(caption, force_drop_ids)
+ caption = self.y_proj(caption)
+ return caption
+
+
+class PositionEmbedding2D(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.dim = dim
+ assert dim % 4 == 0, "dim must be divisible by 4"
+ half_dim = dim // 2
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def _get_sin_cos_emb(self, t: torch.Tensor):
+ out = torch.einsum("i,d->id", t, self.inv_freq)
+ emb_cos = torch.cos(out)
+ emb_sin = torch.sin(out)
+ return torch.cat((emb_sin, emb_cos), dim=-1)
+
+ @functools.lru_cache(maxsize=512)
+ def _get_cached_emb(
+ self,
+ device: torch.device,
+ dtype: torch.dtype,
+ h: int,
+ w: int,
+ scale: float = 1.0,
+ base_size: Optional[int] = None,
+ ):
+ grid_h = torch.arange(h, device=device) / scale
+ grid_w = torch.arange(w, device=device) / scale
+ if base_size is not None:
+ grid_h *= base_size / h
+ grid_w *= base_size / w
+ grid_h, grid_w = torch.meshgrid(
+ grid_w,
+ grid_h,
+ indexing="ij",
+ ) # here w goes first
+ grid_h = grid_h.t().reshape(-1)
+ grid_w = grid_w.t().reshape(-1)
+ emb_h = self._get_sin_cos_emb(grid_h)
+ emb_w = self._get_sin_cos_emb(grid_w)
+ return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ h: int,
+ w: int,
+ scale: Optional[float] = 1.0,
+ base_size: Optional[int] = None,
+ ) -> torch.Tensor:
+ return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
diff --git a/videosys/models/open_sora/pipeline.py b/videosys/models/open_sora/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eba025dcaee1a0f61d03e7946871417d7a32088
--- /dev/null
+++ b/videosys/models/open_sora/pipeline.py
@@ -0,0 +1,427 @@
+import re
+from typing import Optional, Tuple, Union
+
+import torch
+from diffusers.models import AutoencoderKL
+
+from videosys.core.pab_mgr import PABConfig, set_pab_manager
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.utils.utils import save_video
+
+from .datasets import get_image_size, get_num_frames
+from .inference_utils import (
+ append_generated,
+ append_score_to_prompts,
+ apply_mask_strategy,
+ collect_references_batch,
+ dframe_to_frame,
+ extract_json_from_prompts,
+ extract_prompts_loop,
+ merge_prompt,
+ prepare_multi_resolution_info,
+ split_prompt,
+)
+from .rflow import RFLOW
+from .stdit3 import STDiT3_XL_2
+from .text_encoder import T5Encoder, text_preprocessing
+from .vae import OpenSoraVAE_V1_2
+
+
+class OpenSoraPABConfig(PABConfig):
+ def __init__(
+ self,
+ steps: int = 50,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [450, 930],
+ spatial_gap: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [450, 930],
+ temporal_gap: int = 4,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [450, 930],
+ cross_gap: int = 6,
+ diffusion_skip: bool = False,
+ diffusion_timestep_respacing: list = None,
+ diffusion_skip_timestep: list = None,
+ mlp_skip: bool = True,
+ mlp_spatial_skip_config: dict = {
+ 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ mlp_temporal_skip_config: dict = {
+ 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ ):
+ super().__init__(
+ steps=steps,
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_gap=spatial_gap,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_gap=temporal_gap,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_gap=cross_gap,
+ diffusion_skip=diffusion_skip,
+ diffusion_timestep_respacing=diffusion_timestep_respacing,
+ diffusion_skip_timestep=diffusion_skip_timestep,
+ mlp_skip=mlp_skip,
+ mlp_spatial_skip_config=mlp_spatial_skip_config,
+ mlp_temporal_skip_config=mlp_temporal_skip_config,
+ )
+
+
+class OpenSoraConfig:
+ def __init__(
+ self,
+ world_size: int = 1,
+ transformer: str = "hpcai-tech/OpenSora-STDiT-v3",
+ vae: str = "hpcai-tech/OpenSora-VAE-v1.2",
+ text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
+ # ======= scheduler =======
+ num_sampling_steps: int = 30,
+ cfg_scale: float = 7.0,
+ # ======= vae ========
+ tiling_size: int = 4,
+ # ======= pab ========
+ enable_pab: bool = False,
+ pab_config: PABConfig = OpenSoraPABConfig(),
+ ):
+ # ======= engine ========
+ self.world_size = world_size
+
+ # ======= pipeline ========
+ self.pipeline_cls = OpenSoraPipeline
+ self.transformer = transformer
+ self.vae = vae
+ self.text_encoder = text_encoder
+
+ # ======= scheduler ========
+ self.num_sampling_steps = num_sampling_steps
+ self.cfg_scale = cfg_scale
+
+ # ======= vae ========
+ self.tiling_size = tiling_size
+
+ # ======= pab ========
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class OpenSoraPipeline(VideoSysPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ config: OpenSoraConfig,
+ text_encoder: Optional[T5Encoder] = None,
+ vae: Optional[AutoencoderKL] = None,
+ transformer: Optional[STDiT3_XL_2] = None,
+ scheduler: Optional[RFLOW] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.bfloat16,
+ ):
+ super().__init__()
+ self._config = config
+ self._device = device
+ self._dtype = dtype
+
+ # initialize the model if not provided
+ if text_encoder is None:
+ text_encoder = T5Encoder(
+ from_pretrained=config.text_encoder, model_max_length=300, device=device, dtype=dtype
+ )
+ if vae is None:
+ vae = OpenSoraVAE_V1_2(
+ from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
+ micro_frame_size=17,
+ micro_batch_size=config.tiling_size,
+ ).to(dtype)
+ if transformer is None:
+ transformer = STDiT3_XL_2(
+ from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
+ qk_norm=True,
+ enable_flash_attn=True,
+ enable_layernorm_kernel=True,
+ in_channels=vae.out_channels,
+ caption_channels=text_encoder.output_dim,
+ model_max_length=text_encoder.model_max_length,
+ ).to(device, dtype)
+ text_encoder.y_embedder = transformer.y_embedder
+ if scheduler is None:
+ scheduler = RFLOW(
+ use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
+ )
+
+ # pab
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+ # set eval and device
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
+
+ self.register_modules(text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler)
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt: str,
+ resolution="480p",
+ aspect_ratio="9:16",
+ num_frames: int = 51,
+ loop: int = 1,
+ llm_refine: bool = False,
+ negative_prompt: str = "",
+ ms: Optional[str] = "",
+ refs: Optional[str] = "",
+ aes: float = 6.5,
+ flow: Optional[float] = None,
+ camera_motion: Optional[float] = None,
+ condition_frame_length: int = 5,
+ align: int = 5,
+ condition_frame_edit: float = 0.0,
+ return_dict: bool = True,
+ verbose: bool = True,
+ ) -> Union[VideoSysPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ resolution (`str`, *optional*, defaults to `"480p"`):
+ The resolution of the generated video.
+ aspect_ratio (`str`, *optional*, defaults to `"9:16"`):
+ The aspect ratio of the generated video.
+ num_frames (`int`, *optional*, defaults to 51):
+ The number of frames to generate.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # == basic ==
+ fps = 24
+ image_size = get_image_size(resolution, aspect_ratio)
+ num_frames = get_num_frames(num_frames)
+
+ # == prepare batch prompts ==
+ batch_prompts = [prompt]
+ ms = [ms]
+ refs = [refs]
+
+ # == get json from prompts ==
+ batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
+
+ # == get reference for condition ==
+ refs = collect_references_batch(refs, self.vae, image_size)
+
+ # == multi-resolution info ==
+ model_args = prepare_multi_resolution_info(
+ "OpenSora", len(batch_prompts), image_size, num_frames, fps, self._device, self._dtype
+ )
+
+ # == process prompts step by step ==
+ # 0. split prompt
+ # each element in the list is [prompt_segment_list, loop_idx_list]
+ batched_prompt_segment_list = []
+ batched_loop_idx_list = []
+ for prompt in batch_prompts:
+ prompt_segment_list, loop_idx_list = split_prompt(prompt)
+ batched_prompt_segment_list.append(prompt_segment_list)
+ batched_loop_idx_list.append(loop_idx_list)
+
+ # 1. refine prompt by openai
+ # if llm_refine:
+ # only call openai API when
+ # 1. seq parallel is not enabled
+ # 2. seq parallel is enabled and the process is rank 0
+ # if not enable_sequence_parallelism or (enable_sequence_parallelism and coordinator.is_master()):
+ # for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ # batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
+
+ # # sync the prompt if using seq parallel
+ # if enable_sequence_parallelism:
+ # coordinator.block_all()
+ # prompt_segment_length = [
+ # len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list
+ # ]
+
+ # # flatten the prompt segment list
+ # batched_prompt_segment_list = [
+ # prompt_segment
+ # for prompt_segment_list in batched_prompt_segment_list
+ # for prompt_segment in prompt_segment_list
+ # ]
+
+ # # create a list of size equal to world size
+ # broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
+ # dist.broadcast_object_list(broadcast_obj_list, 0)
+
+ # # recover the prompt list
+ # batched_prompt_segment_list = []
+ # segment_start_idx = 0
+ # all_prompts = broadcast_obj_list[0]
+ # for num_segment in prompt_segment_length:
+ # batched_prompt_segment_list.append(
+ # all_prompts[segment_start_idx : segment_start_idx + num_segment]
+ # )
+ # segment_start_idx += num_segment
+
+ # 2. append score
+ for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ batched_prompt_segment_list[idx] = append_score_to_prompts(
+ prompt_segment_list,
+ aes=aes,
+ flow=flow,
+ camera_motion=camera_motion,
+ )
+
+ # 3. clean prompt with T5
+ for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list]
+
+ # 4. merge to obtain the final prompt
+ batch_prompts = []
+ for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
+ batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
+
+ # == Iter over loop generation ==
+ video_clips = []
+ for loop_i in range(loop):
+ # == get prompt for loop i ==
+ batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
+
+ # == add condition frames for loop ==
+ if loop_i > 0:
+ refs, ms = append_generated(
+ self.vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
+ )
+
+ # == sampling ==
+ input_size = (num_frames, *image_size)
+ latent_size = self.vae.get_latent_size(input_size)
+ z = torch.randn(
+ len(batch_prompts), self.vae.out_channels, *latent_size, device=self._device, dtype=self._dtype
+ )
+ masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
+ samples = self.scheduler.sample(
+ self.transformer,
+ self.text_encoder,
+ z=z,
+ prompts=batch_prompts_loop,
+ device=self._device,
+ additional_args=model_args,
+ progress=verbose,
+ mask=masks,
+ )
+ samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
+ video_clips.append(samples)
+
+ for i in range(1, loop):
+ video_clips[i] = video_clips[i][:, dframe_to_frame(condition_frame_length) :]
+ video = torch.cat(video_clips, dim=1)
+
+ low, high = -1, 1
+ video.clamp_(min=low, max=high)
+ video.sub_(low).div_(max(high - low, 1e-5))
+ video = video.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 4, 1).to("cpu", torch.uint8)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return VideoSysPipelineOutput(video=video)
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=24)
diff --git a/videosys/models/open_sora/rflow.py b/videosys/models/open_sora/rflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b8f5bfac237dcadded659d7c3ba6bcc2515e77
--- /dev/null
+++ b/videosys/models/open_sora/rflow.py
@@ -0,0 +1,270 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import torch
+import torch.distributed as dist
+from einops import rearrange
+from torch.distributions import LogisticNormal
+from tqdm import tqdm
+
+from videosys.core.pab_mgr import get_diffusion_skip, get_diffusion_skip_timestep, skip_diffusion_timestep
+from videosys.diffusion.gaussian_diffusion import _extract_into_tensor
+
+
+def mean_flat(tensor: torch.Tensor, mask=None):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ if mask is None:
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+ else:
+ assert tensor.dim() == 5
+ assert tensor.shape[2] == mask.shape[1]
+ tensor = rearrange(tensor, "b c t h w -> b t (c h w)")
+ denom = mask.sum(dim=1) * tensor.shape[-1]
+ loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom
+ return loss
+
+
+def timestep_transform(
+ t,
+ model_kwargs,
+ base_resolution=512 * 512,
+ base_num_frames=1,
+ scale=1.0,
+ num_timesteps=1,
+):
+ t = t / num_timesteps
+ resolution = model_kwargs["height"] * model_kwargs["width"]
+ ratio_space = (resolution / base_resolution).sqrt()
+ # NOTE: currently, we do not take fps into account
+ # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae
+ if model_kwargs["num_frames"][0] == 1:
+ num_frames = torch.ones_like(model_kwargs["num_frames"])
+ else:
+ num_frames = model_kwargs["num_frames"] // 17 * 5
+ ratio_time = (num_frames / base_num_frames).sqrt()
+
+ ratio = ratio_space * ratio_time * scale
+ new_t = ratio * t / (1 + (ratio - 1) * t)
+
+ new_t = new_t * num_timesteps
+ return new_t
+
+
+class RFlowScheduler:
+ def __init__(
+ self,
+ num_timesteps=1000,
+ num_sampling_steps=10,
+ use_discrete_timesteps=False,
+ sample_method="uniform",
+ loc=0.0,
+ scale=1.0,
+ use_timestep_transform=False,
+ transform_scale=1.0,
+ ):
+ self.num_timesteps = num_timesteps
+ self.num_sampling_steps = num_sampling_steps
+ self.use_discrete_timesteps = use_discrete_timesteps
+
+ # sample method
+ assert sample_method in ["uniform", "logit-normal"]
+ assert (
+ sample_method == "uniform" or not use_discrete_timesteps
+ ), "Only uniform sampling is supported for discrete timesteps"
+ self.sample_method = sample_method
+ if sample_method == "logit-normal":
+ self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+ self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+ # timestep transform
+ self.use_timestep_transform = use_timestep_transform
+ self.transform_scale = transform_scale
+
+ def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
+ """
+ Compute training losses for a single timestep.
+ Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses
+ Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0]
+ """
+ if t is None:
+ if self.use_discrete_timesteps:
+ t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device)
+ elif self.sample_method == "uniform":
+ t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps
+ elif self.sample_method == "logit-normal":
+ t = self.sample_t(x_start) * self.num_timesteps
+
+ if self.use_timestep_transform:
+ t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)
+
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = torch.randn_like(x_start)
+ assert noise.shape == x_start.shape
+
+ x_t = self.add_noise(x_start, noise, t)
+ if mask is not None:
+ t0 = torch.zeros_like(t)
+ x_t0 = self.add_noise(x_start, noise, t0)
+ x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)
+
+ terms = {}
+ model_output = model(x_t, t, **model_kwargs)
+ velocity_pred = model_output.chunk(2, dim=1)[0]
+ if weights is None:
+ loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask)
+ else:
+ weight = _extract_into_tensor(weights, t, x_start.shape)
+ loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask)
+ terms["loss"] = loss
+
+ return terms
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ """
+ compatible with diffusers add_noise()
+ """
+ timepoints = timesteps.float() / self.num_timesteps
+ timepoints = 1 - timepoints # [1,1/1000]
+
+ # timepoint (bsz) noise: (bsz, 4, frame, w ,h)
+ # expand timepoint to noise shape
+ timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
+ timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
+
+ return timepoints * original_samples + (1 - timepoints) * noise
+
+
+class RFLOW:
+ def __init__(
+ self,
+ num_sampling_steps=10,
+ num_timesteps=1000,
+ cfg_scale=4.0,
+ use_discrete_timesteps=False,
+ use_timestep_transform=False,
+ **kwargs,
+ ):
+ self.num_sampling_steps = num_sampling_steps
+ self.num_timesteps = num_timesteps
+ self.cfg_scale = cfg_scale
+ self.use_discrete_timesteps = use_discrete_timesteps
+ self.use_timestep_transform = use_timestep_transform
+
+ self.scheduler = RFlowScheduler(
+ num_timesteps=num_timesteps,
+ num_sampling_steps=num_sampling_steps,
+ use_discrete_timesteps=use_discrete_timesteps,
+ use_timestep_transform=use_timestep_transform,
+ **kwargs,
+ )
+
+ def sample(
+ self,
+ model,
+ text_encoder,
+ z,
+ prompts,
+ device,
+ additional_args=None,
+ mask=None,
+ guidance_scale=None,
+ progress=True,
+ verbose=False,
+ ):
+ # if no specific guidance scale is provided, use the default scale when initializing the scheduler
+ if guidance_scale is None:
+ guidance_scale = self.cfg_scale
+
+ n = len(prompts)
+ # text encoding
+ model_args = text_encoder.encode(prompts)
+ y_null = text_encoder.null(n)
+ model_args["y"] = torch.cat([model_args["y"], y_null], 0)
+ if additional_args is not None:
+ model_args.update(additional_args)
+
+ # prepare timesteps
+ timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
+ if self.use_discrete_timesteps:
+ timesteps = [int(round(t)) for t in timesteps]
+ timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
+ if self.use_timestep_transform:
+ timesteps = [timestep_transform(t, additional_args, num_timesteps=self.num_timesteps) for t in timesteps]
+
+ if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
+ orignal_timesteps = timesteps
+ diffusion_skip_timestep = get_diffusion_skip_timestep()
+ timesteps = skip_diffusion_timestep(timesteps, diffusion_skip_timestep)
+
+ if verbose and dist.get_rank() == 0:
+ print("============================")
+ print("skip diffusion steps!!!")
+ print("============================")
+ print(f"orignal sample timesteps: {orignal_timesteps}")
+ print(f"orignal diffusion steps: {len(orignal_timesteps)}")
+ print("============================")
+ print(f"skip diffusion steps: {get_diffusion_skip_timestep()}")
+ print(f"sample timesteps: {timesteps}")
+ print(f"num_inference_steps: {len(timesteps)}")
+ print("============================")
+
+ if mask is not None:
+ noise_added = torch.zeros_like(mask, dtype=torch.bool)
+ noise_added = noise_added | (mask == 1)
+
+ progress_wrap = tqdm if progress and dist.get_rank() == 0 else (lambda x: x)
+
+ dtype = model.x_embedder.proj.weight.dtype
+ all_timesteps = [int(t.to(dtype).item()) for t in timesteps]
+ for i, t in progress_wrap(list(enumerate(timesteps))):
+ # mask for adding noise
+ if mask is not None:
+ mask_t = mask * self.num_timesteps
+ x0 = z.clone()
+ x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t)
+
+ mask_t_upper = mask_t >= t.unsqueeze(1)
+ model_args["x_mask"] = mask_t_upper.repeat(2, 1)
+ mask_add_noise = mask_t_upper & ~noise_added
+
+ z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
+ noise_added = mask_t_upper
+
+ # classifier-free guidance
+ z_in = torch.cat([z, z], 0)
+ t = torch.cat([t, t], 0)
+
+ # pred = model(z_in, t, **model_args).chunk(2, dim=1)[0]
+ output = model(z_in, t, all_timesteps, **model_args)
+
+ pred = output.chunk(2, dim=1)[0]
+ pred_cond, pred_uncond = pred.chunk(2, dim=0)
+ v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+
+ # update z
+ dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
+ dt = dt / self.num_timesteps
+ z = z + v_pred * dt[:, None, None, None, None]
+
+ if mask is not None:
+ z = torch.where(mask_t_upper[:, None, :, None, None], z, x0)
+
+ return z
+
+ def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
+ return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t)
diff --git a/videosys/models/open_sora/stdit3.py b/videosys/models/open_sora/stdit3.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfea9dd5501c4dd341c277eb3023ca3854242b05
--- /dev/null
+++ b/videosys/models/open_sora/stdit3.py
@@ -0,0 +1,603 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+
+import os
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from timm.models.layers import DropPath
+from timm.models.vision_transformer import Mlp
+from transformers import PretrainedConfig, PreTrainedModel
+
+from videosys.core.comm import (
+ all_to_all_with_pad,
+ gather_sequence,
+ get_spatial_pad,
+ get_temporal_pad,
+ set_spatial_pad,
+ set_temporal_pad,
+ split_sequence,
+)
+from videosys.core.pab_mgr import (
+ enable_pab,
+ get_mlp_output,
+ if_broadcast_cross,
+ if_broadcast_mlp,
+ if_broadcast_spatial,
+ if_broadcast_temporal,
+ save_mlp_output,
+)
+from videosys.core.parallel_mgr import (
+ enable_sequence_parallel,
+ get_cfg_parallel_size,
+ get_data_parallel_group,
+ get_sequence_parallel_group,
+)
+from videosys.utils.utils import batch_func
+
+from .modules import (
+ Attention,
+ CaptionEmbedder,
+ MultiHeadCrossAttention,
+ PatchEmbed3D,
+ PositionEmbedding2D,
+ SizeEmbedder,
+ T2IFinalLayer,
+ TimestepEmbedder,
+ approx_gelu,
+ get_layernorm,
+ t2i_modulate,
+)
+from .utils import auto_grad_checkpoint, load_checkpoint
+
+
+class STDiT3Block(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ rope=None,
+ qk_norm=False,
+ temporal=False,
+ enable_flash_attn=False,
+ block_idx=None,
+ ):
+ super().__init__()
+ self.temporal = temporal
+ self.hidden_size = hidden_size
+ self.enable_flash_attn = enable_flash_attn
+
+ attn_cls = Attention
+ mha_cls = MultiHeadCrossAttention
+
+ self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False)
+ self.attn = attn_cls(
+ hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ qk_norm=qk_norm,
+ rope=rope,
+ enable_flash_attn=enable_flash_attn,
+ )
+ self.cross_attn = mha_cls(hidden_size, num_heads)
+ self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False)
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
+
+ # pab
+ self.block_idx = block_idx
+ self.attn_count = 0
+ self.last_attn = None
+ self.cross_count = 0
+ self.last_cross = None
+ self.mlp_count = 0
+
+ def t_mask_select(self, x_mask, x, masked_x, T, S):
+ # x: [B, (T, S), C]
+ # mased_x: [B, (T, S), C]
+ # x_mask: [B, T]
+ x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
+ masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
+ x = torch.where(x_mask[:, :, None, None], x, masked_x)
+ x = rearrange(x, "B T S C -> B (T S) C")
+ return x
+
+ def forward(
+ self,
+ x,
+ y,
+ t,
+ mask=None, # text mask
+ x_mask=None, # temporal mask
+ t0=None, # t with timestamp=0
+ T=None, # number of frames
+ S=None, # number of pixel patches
+ timestep=None,
+ all_timesteps=None,
+ ):
+ # prepare modulate parameters
+ B, N, C = x.shape
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+ if x_mask is not None:
+ shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
+ self.scale_shift_table[None] + t0.reshape(B, 6, -1)
+ ).chunk(6, dim=1)
+
+ if enable_pab():
+ if self.temporal:
+ broadcast_attn, self.attn_count = if_broadcast_temporal(int(timestep[0]), self.attn_count)
+ else:
+ broadcast_attn, self.attn_count = if_broadcast_spatial(
+ int(timestep[0]), self.attn_count, self.block_idx
+ )
+
+ if enable_pab() and broadcast_attn:
+ x_m_s = self.last_attn
+ else:
+ # modulate (attention)
+ x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
+ if x_mask is not None:
+ x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
+ x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
+
+ # attention
+ if self.temporal:
+ if enable_sequence_parallel():
+ x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
+ x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
+ x_m = self.attn(x_m)
+ x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
+ if enable_sequence_parallel():
+ x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
+ else:
+ x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
+ x_m = self.attn(x_m)
+ x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S)
+
+ # modulate (attention)
+ x_m_s = gate_msa * x_m
+ if x_mask is not None:
+ x_m_s_zero = gate_msa_zero * x_m
+ x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
+
+ if enable_pab():
+ self.last_attn = x_m_s
+
+ # residual
+ x = x + self.drop_path(x_m_s)
+
+ # cross attention
+ if enable_pab():
+ broadcast_cross, self.cross_count = if_broadcast_cross(int(timestep[0]), self.cross_count)
+
+ if enable_pab() and broadcast_cross:
+ x = x + self.last_cross
+ else:
+ x_cross = self.cross_attn(x, y, mask)
+ if enable_pab():
+ self.last_cross = x_cross
+ x = x + x_cross
+
+ if enable_pab():
+ broadcast_mlp, self.mlp_count, broadcast_next, skip_range = if_broadcast_mlp(
+ int(timestep[0]),
+ self.mlp_count,
+ self.block_idx,
+ all_timesteps,
+ is_temporal=self.temporal,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ x_m_s = get_mlp_output(
+ skip_range,
+ timestep=int(timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=self.temporal,
+ )
+ else:
+ # modulate (MLP)
+ x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
+ if x_mask is not None:
+ x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
+ x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)
+
+ # MLP
+ x_m = self.mlp(x_m)
+
+ # modulate (MLP)
+ x_m_s = gate_mlp * x_m
+ if x_mask is not None:
+ x_m_s_zero = gate_mlp_zero * x_m
+ x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=x_m_s,
+ is_temporal=self.temporal,
+ )
+
+ # residual
+ x = x + self.drop_path(x_m_s)
+
+ return x
+
+ def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
+ if to_spatial_shard:
+ scatter_dim, gather_dim = 2, 1
+ scatter_pad = get_spatial_pad()
+ gather_pad = get_temporal_pad()
+ else:
+ scatter_dim, gather_dim = 1, 2
+ scatter_pad = get_temporal_pad()
+ gather_pad = get_spatial_pad()
+
+ x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
+ x = all_to_all_with_pad(
+ x,
+ get_sequence_parallel_group(),
+ scatter_dim=scatter_dim,
+ gather_dim=gather_dim,
+ scatter_pad=scatter_pad,
+ gather_pad=gather_pad,
+ )
+ new_s, new_t = x.shape[2], x.shape[1]
+ x = rearrange(x, "b t s d -> b (t s) d")
+ return x, new_s, new_t
+
+
+class STDiT3Config(PretrainedConfig):
+ model_type = "STDiT3"
+
+ def __init__(
+ self,
+ input_size=(None, None, None),
+ input_sq_size=512,
+ in_channels=4,
+ patch_size=(1, 2, 2),
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ pred_sigma=True,
+ drop_path=0.0,
+ caption_channels=4096,
+ model_max_length=300,
+ qk_norm=True,
+ enable_flash_attn=False,
+ only_train_temporal=False,
+ freeze_y_embedder=False,
+ skip_y_embedder=False,
+ **kwargs,
+ ):
+ self.input_size = input_size
+ self.input_sq_size = input_sq_size
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.hidden_size = hidden_size
+ self.depth = depth
+ self.num_heads = num_heads
+ self.mlp_ratio = mlp_ratio
+ self.class_dropout_prob = class_dropout_prob
+ self.pred_sigma = pred_sigma
+ self.drop_path = drop_path
+ self.caption_channels = caption_channels
+ self.model_max_length = model_max_length
+ self.qk_norm = qk_norm
+ self.enable_flash_attn = enable_flash_attn
+ self.only_train_temporal = only_train_temporal
+ self.freeze_y_embedder = freeze_y_embedder
+ self.skip_y_embedder = skip_y_embedder
+ super().__init__(**kwargs)
+
+
+class STDiT3(PreTrainedModel):
+ config_class = STDiT3Config
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.pred_sigma = config.pred_sigma
+ self.in_channels = config.in_channels
+ self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
+
+ # model size related
+ self.depth = config.depth
+ self.mlp_ratio = config.mlp_ratio
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_heads
+
+ # computation related
+ self.drop_path = config.drop_path
+ self.enable_flash_attn = config.enable_flash_attn
+
+ # input size related
+ self.patch_size = config.patch_size
+ self.input_sq_size = config.input_sq_size
+ self.pos_embed = PositionEmbedding2D(config.hidden_size)
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
+
+ # embedding
+ self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
+ self.t_embedder = TimestepEmbedder(config.hidden_size)
+ self.fps_embedder = SizeEmbedder(self.hidden_size)
+ self.t_block = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True),
+ )
+ self.y_embedder = CaptionEmbedder(
+ in_channels=config.caption_channels,
+ hidden_size=config.hidden_size,
+ uncond_prob=config.class_dropout_prob,
+ act_layer=approx_gelu,
+ token_num=config.model_max_length,
+ )
+
+ # spatial blocks
+ drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
+ self.spatial_blocks = nn.ModuleList(
+ [
+ STDiT3Block(
+ hidden_size=config.hidden_size,
+ num_heads=config.num_heads,
+ mlp_ratio=config.mlp_ratio,
+ drop_path=drop_path[i],
+ qk_norm=config.qk_norm,
+ enable_flash_attn=config.enable_flash_attn,
+ block_idx=i,
+ )
+ for i in range(config.depth)
+ ]
+ )
+
+ # temporal blocks
+ drop_path = [x.item() for x in torch.linspace(0, self.drop_path, config.depth)]
+ self.temporal_blocks = nn.ModuleList(
+ [
+ STDiT3Block(
+ hidden_size=config.hidden_size,
+ num_heads=config.num_heads,
+ mlp_ratio=config.mlp_ratio,
+ drop_path=drop_path[i],
+ qk_norm=config.qk_norm,
+ enable_flash_attn=config.enable_flash_attn,
+ # temporal
+ temporal=True,
+ rope=self.rope.rotate_queries_or_keys,
+ block_idx=i,
+ )
+ for i in range(config.depth)
+ ]
+ )
+ # final layer
+ self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
+
+ self.initialize_weights()
+ if config.only_train_temporal:
+ for param in self.parameters():
+ param.requires_grad = False
+ for block in self.temporal_blocks:
+ for param in block.parameters():
+ param.requires_grad = True
+
+ if config.freeze_y_embedder:
+ for param in self.y_embedder.parameters():
+ param.requires_grad = False
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize fps_embedder
+ nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02)
+ nn.init.constant_(self.fps_embedder.mlp[0].bias, 0)
+ nn.init.constant_(self.fps_embedder.mlp[2].weight, 0)
+ nn.init.constant_(self.fps_embedder.mlp[2].bias, 0)
+
+ # Initialize timporal blocks
+ for block in self.temporal_blocks:
+ nn.init.constant_(block.attn.proj.weight, 0)
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
+ nn.init.constant_(block.mlp.fc2.weight, 0)
+
+ def get_dynamic_size(self, x):
+ _, _, T, H, W = x.size()
+ if T % self.patch_size[0] != 0:
+ T += self.patch_size[0] - T % self.patch_size[0]
+ if H % self.patch_size[1] != 0:
+ H += self.patch_size[1] - H % self.patch_size[1]
+ if W % self.patch_size[2] != 0:
+ W += self.patch_size[2] - W % self.patch_size[2]
+ T = T // self.patch_size[0]
+ H = H // self.patch_size[1]
+ W = W // self.patch_size[2]
+ return (T, H, W)
+
+ def encode_text(self, y, mask=None):
+ y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
+ if mask is not None:
+ if mask.shape[0] != y.shape[0]:
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+ mask = mask.squeeze(1).squeeze(1)
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
+ y_lens = mask.sum(dim=1).tolist()
+ else:
+ y_lens = [y.shape[2]] * y.shape[0]
+ y = y.squeeze(1).view(1, -1, self.hidden_size)
+ return y, y_lens
+
+ def forward(
+ self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
+ ):
+ # === Split batch ===
+ if get_cfg_parallel_size() > 1:
+ x, timestep, y, x_mask, mask = batch_func(
+ partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
+ )
+
+ dtype = self.x_embedder.proj.weight.dtype
+ B = x.size(0)
+ x = x.to(dtype)
+ timestep = timestep.to(dtype)
+ y = y.to(dtype)
+
+ # === get pos embed ===
+ _, _, Tx, Hx, Wx = x.size()
+ T, H, W = self.get_dynamic_size(x)
+ S = H * W
+ base_size = round(S**0.5)
+ resolution_sq = (height[0].item() * width[0].item()) ** 0.5
+ scale = resolution_sq / self.input_sq_size
+ pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
+
+ # === get timestep embed ===
+ t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
+ fps = self.fps_embedder(fps.unsqueeze(1), B)
+ t = t + fps
+ t_mlp = self.t_block(t)
+ t0 = t0_mlp = None
+ if x_mask is not None:
+ t0_timestep = torch.zeros_like(timestep)
+ t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
+ t0 = t0 + fps
+ t0_mlp = self.t_block(t0)
+
+ # === get y embed ===
+ if self.config.skip_y_embedder:
+ y_lens = mask
+ if isinstance(y_lens, torch.Tensor):
+ y_lens = y_lens.long().tolist()
+ else:
+ y, y_lens = self.encode_text(y, mask)
+
+ # === get x embed ===
+ x = self.x_embedder(x) # [B, N, C]
+ x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
+ x = x + pos_emb
+
+ # shard over the sequence dim if sp is enabled
+ if enable_sequence_parallel():
+ set_temporal_pad(T)
+ set_spatial_pad(S)
+ x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ T = x.shape[1]
+ x_mask_org = x_mask
+ x_mask = split_sequence(
+ x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ )
+
+ x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
+
+ # === blocks ===
+ for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
+ x = auto_grad_checkpoint(
+ spatial_block,
+ x,
+ y,
+ t_mlp,
+ y_lens,
+ x_mask,
+ t0_mlp,
+ T,
+ S,
+ timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ x = auto_grad_checkpoint(
+ temporal_block,
+ x,
+ y,
+ t_mlp,
+ y_lens,
+ x_mask,
+ t0_mlp,
+ T,
+ S,
+ timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ if enable_sequence_parallel():
+ x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
+ x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+ T, S = x.shape[1], x.shape[2]
+ x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
+ x_mask = x_mask_org
+
+ # === final layer ===
+ x = self.final_layer(x, t, x_mask, t0, T, S)
+ x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
+
+ # cast to float32 for better accuracy
+ x = x.to(torch.float32)
+
+ # === Gather Output ===
+ if get_cfg_parallel_size() > 1:
+ x = gather_sequence(x, get_data_parallel_group(), dim=0)
+
+ return x
+
+ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
+ """
+ Args:
+ x (torch.Tensor): of shape [B, N, C]
+
+ Return:
+ x (torch.Tensor): of shape [B, C_out, T, H, W]
+ """
+
+ # N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
+ T_p, H_p, W_p = self.patch_size
+ x = rearrange(
+ x,
+ "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
+ N_t=N_t,
+ N_h=N_h,
+ N_w=N_w,
+ T_p=T_p,
+ H_p=H_p,
+ W_p=W_p,
+ C_out=self.out_channels,
+ )
+ # unpad
+ x = x[:, :, :R_t, :R_h, :R_w]
+ return x
+
+
+def STDiT3_XL_2(from_pretrained=None, **kwargs):
+ if from_pretrained is not None and not os.path.isdir(from_pretrained):
+ model = STDiT3.from_pretrained(from_pretrained, **kwargs)
+ else:
+ config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
+ model = STDiT3(config)
+ if from_pretrained is not None:
+ load_checkpoint(model, from_pretrained)
+ return model
diff --git a/videosys/models/open_sora/text_encoder.py b/videosys/models/open_sora/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3d5b313d84090f2ad1dc53bf06861fc0818998
--- /dev/null
+++ b/videosys/models/open_sora/text_encoder.py
@@ -0,0 +1,330 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import html
+import os
+import re
+
+import ftfy
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+
+class T5Embedder:
+ available_models = ["DeepFloyd/t5-v1_1-xxl"]
+
+ def __init__(
+ self,
+ device,
+ from_pretrained=None,
+ *,
+ cache_dir=None,
+ hf_token=None,
+ use_text_preprocessing=True,
+ t5_model_kwargs=None,
+ torch_dtype=None,
+ use_offload_folder=None,
+ model_max_length=120,
+ local_files_only=False,
+ ):
+ self.device = torch.device(device)
+ self.torch_dtype = torch_dtype or torch.bfloat16
+ self.cache_dir = cache_dir
+
+ if t5_model_kwargs is None:
+ t5_model_kwargs = {
+ "low_cpu_mem_usage": True,
+ "torch_dtype": self.torch_dtype,
+ }
+
+ if use_offload_folder is not None:
+ t5_model_kwargs["offload_folder"] = use_offload_folder
+ t5_model_kwargs["device_map"] = {
+ "shared": self.device,
+ "encoder.embed_tokens": self.device,
+ "encoder.block.0": self.device,
+ "encoder.block.1": self.device,
+ "encoder.block.2": self.device,
+ "encoder.block.3": self.device,
+ "encoder.block.4": self.device,
+ "encoder.block.5": self.device,
+ "encoder.block.6": self.device,
+ "encoder.block.7": self.device,
+ "encoder.block.8": self.device,
+ "encoder.block.9": self.device,
+ "encoder.block.10": self.device,
+ "encoder.block.11": self.device,
+ "encoder.block.12": "disk",
+ "encoder.block.13": "disk",
+ "encoder.block.14": "disk",
+ "encoder.block.15": "disk",
+ "encoder.block.16": "disk",
+ "encoder.block.17": "disk",
+ "encoder.block.18": "disk",
+ "encoder.block.19": "disk",
+ "encoder.block.20": "disk",
+ "encoder.block.21": "disk",
+ "encoder.block.22": "disk",
+ "encoder.block.23": "disk",
+ "encoder.final_layer_norm": "disk",
+ "encoder.dropout": "disk",
+ }
+ else:
+ t5_model_kwargs["device_map"] = {
+ "shared": self.device,
+ "encoder": self.device,
+ }
+
+ self.use_text_preprocessing = use_text_preprocessing
+ self.hf_token = hf_token
+
+ assert from_pretrained in self.available_models
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ from_pretrained,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ )
+ self.model = T5EncoderModel.from_pretrained(
+ from_pretrained,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ **t5_model_kwargs,
+ ).eval()
+ self.model_max_length = model_max_length
+
+ def get_text_embeddings(self, texts):
+ text_tokens_and_mask = self.tokenizer(
+ texts,
+ max_length=self.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ input_ids = text_tokens_and_mask["input_ids"].to(self.device)
+ attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
+ with torch.no_grad():
+ text_encoder_embs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["last_hidden_state"].detach()
+ return text_encoder_embs, attention_mask
+
+
+class T5Encoder:
+ def __init__(
+ self,
+ from_pretrained=None,
+ model_max_length=120,
+ device="cuda",
+ dtype=torch.float,
+ cache_dir=None,
+ shardformer=False,
+ local_files_only=False,
+ ):
+ assert from_pretrained is not None, "Please specify the path to the T5 model"
+
+ self.t5 = T5Embedder(
+ device=device,
+ torch_dtype=dtype,
+ from_pretrained=from_pretrained,
+ cache_dir=cache_dir,
+ model_max_length=model_max_length,
+ local_files_only=local_files_only,
+ )
+ self.t5.model.to(dtype=dtype)
+ self.y_embedder = None
+
+ self.model_max_length = model_max_length
+ self.output_dim = self.t5.model.config.d_model
+ self.dtype = dtype
+
+ if shardformer:
+ self.shardformer_t5()
+
+ def eval(self):
+ self.t5.model.eval()
+
+ def to(self, device):
+ self.t5.model.to(device)
+
+ def shardformer_t5(self):
+ from colossalai.shardformer import ShardConfig, ShardFormer
+
+ from videosys.core.shardformer.t5.policy import T5EncoderPolicy
+ from videosys.utils.utils import requires_grad
+
+ shard_config = ShardConfig(
+ tensor_parallel_process_group=None,
+ pipeline_stage_manager=None,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=True,
+ enable_sequence_parallelism=False,
+ enable_sequence_overlap=False,
+ )
+ shard_former = ShardFormer(shard_config=shard_config)
+ optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
+ self.t5.model = optim_model.to(self.dtype)
+
+ # ensure the weights are frozen
+ requires_grad(self.t5.model, False)
+
+ def encode(self, text):
+ caption_embs, emb_masks = self.t5.get_text_embeddings(text)
+ caption_embs = caption_embs[:, None]
+ return dict(y=caption_embs, mask=emb_masks)
+
+ def null(self, n):
+ null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
+ return null_y
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+BAD_PUNCT_REGEX = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+) # noqa
+
+
+def clean_caption(caption):
+ import urllib.parse as ul
+
+ from bs4 import BeautifulSoup
+
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = basic_clean(caption)
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+
+def text_preprocessing(text, use_text_preprocessing: bool = True):
+ if use_text_preprocessing:
+ # The exact text cleaning as was in the training stage:
+ text = clean_caption(text)
+ text = clean_caption(text)
+ return text
+ else:
+ return text.lower().strip()
diff --git a/videosys/models/open_sora/utils.py b/videosys/models/open_sora/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f32611730933876070a704c923d9823bcaaebb0
--- /dev/null
+++ b/videosys/models/open_sora/utils.py
@@ -0,0 +1,179 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import os
+from collections.abc import Iterable
+
+import torch
+import torch.distributed as dist
+from colossalai.checkpoint_io import GeneralCheckpointIO
+from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torchvision.datasets.utils import download_url
+
+from videosys.utils.logging import logger
+
+hf_endpoint = os.environ.get("HF_ENDPOINT")
+if hf_endpoint is None:
+ hf_endpoint = "https://huggingface.co"
+
+pretrained_models = {
+ "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
+ "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
+ "Latte-XL-2-256x256-ucf101.pt": hf_endpoint + "/maxin-cn/Latte/resolve/main/ucf101.pt",
+ "PixArt-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
+ "PixArt-XL-2-SAM-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
+ "PixArt-XL-2-512x512.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
+ "PixArt-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
+ "OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
+ "OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
+ "OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
+ "PixArt-Sigma-XL-2-256x256.pth": hf_endpoint
+ + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
+ "PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint
+ + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
+ "PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint
+ + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
+ "PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
+}
+
+
+def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False):
+ ckpt_io = GeneralCheckpointIO()
+ ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
+
+
+def reparameter(ckpt, name=None, model=None):
+ model_name = name
+ name = os.path.basename(name)
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ logger.info("loading pretrained model: %s", model_name)
+ if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
+ ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
+ del ckpt["pos_embed"]
+ if name in ["Latte-XL-2-256x256-ucf101.pt"]:
+ ckpt = ckpt["ema"]
+ ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
+ del ckpt["pos_embed"]
+ del ckpt["temp_embed"]
+ if name in [
+ "PixArt-XL-2-256x256.pth",
+ "PixArt-XL-2-SAM-256x256.pth",
+ "PixArt-XL-2-512x512.pth",
+ "PixArt-XL-2-1024-MS.pth",
+ "PixArt-Sigma-XL-2-256x256.pth",
+ "PixArt-Sigma-XL-2-512-MS.pth",
+ "PixArt-Sigma-XL-2-1024-MS.pth",
+ "PixArt-Sigma-XL-2-2K-MS.pth",
+ ]:
+ ckpt = ckpt["state_dict"]
+ ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
+ if "pos_embed" in ckpt:
+ del ckpt["pos_embed"]
+
+ if name in [
+ "PixArt-1B-2.pth",
+ ]:
+ ckpt = ckpt["state_dict"]
+ if "pos_embed" in ckpt:
+ del ckpt["pos_embed"]
+
+ # no need pos_embed
+ if "pos_embed_temporal" in ckpt:
+ del ckpt["pos_embed_temporal"]
+ if "pos_embed" in ckpt:
+ del ckpt["pos_embed"]
+ # different text length
+ if "y_embedder.y_embedding" in ckpt:
+ if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]:
+ logger.info(
+ "Extend y_embedding from %s to %s",
+ ckpt["y_embedder.y_embedding"].shape[0],
+ model.y_embedder.y_embedding.shape[0],
+ )
+ additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0]
+ new_y_embedding = torch.zeros(additional_length, model.y_embedder.y_embedding.shape[1])
+ new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
+ ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
+ elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]:
+ logger.info(
+ "Shrink y_embedding from %s to %s",
+ ckpt["y_embedder.y_embedding"].shape[0],
+ model.y_embedder.y_embedding.shape[0],
+ )
+ ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]]
+ # stdit3 special case
+ if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name:
+ ckpt_keys = list(ckpt.keys())
+ for key in ckpt_keys:
+ if "blocks." in key:
+ ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key]
+ del ckpt[key]
+
+ return ckpt
+
+
+def find_model(model_name, model=None):
+ """
+ Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
+ """
+ if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
+ model_ckpt = download_model(model_name)
+ model_ckpt = reparameter(model_ckpt, model_name, model=model)
+ else: # Load a custom DiT checkpoint:
+ assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
+ model_ckpt = torch.load(model_name, map_location=lambda storage, loc: storage)
+ model_ckpt = reparameter(model_ckpt, model_name, model=model)
+ return model_ckpt
+
+
+def download_model(model_name=None, local_path=None, url=None):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ if model_name is not None:
+ assert model_name in pretrained_models
+ local_path = f"pretrained_models/{model_name}"
+ web_path = pretrained_models[model_name]
+ else:
+ assert local_path is not None
+ assert url is not None
+ web_path = url
+ if not os.path.isfile(local_path):
+ os.makedirs("pretrained_models", exist_ok=True)
+ dir_name = os.path.dirname(local_path)
+ file_name = os.path.basename(local_path)
+ download_url(web_path, dir_name, file_name)
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
+
+def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False):
+ if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
+ state_dict = find_model(ckpt_path, model=model)
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
+ logger.info("Missing keys: %s", missing_keys)
+ logger.info("Unexpected keys: %s", unexpected_keys)
+ elif os.path.isdir(ckpt_path):
+ load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
+ logger.info("Model checkpoint loaded from %s", ckpt_path)
+ if save_as_pt:
+ save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
+ torch.save(model.state_dict(), save_path)
+ logger.info("Model checkpoint saved to %s", save_path)
+ else:
+ raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
+
+
+def auto_grad_checkpoint(module, *args, **kwargs):
+ if getattr(module, "grad_checkpointing", False):
+ if not isinstance(module, Iterable):
+ return checkpoint(module, *args, use_reentrant=False, **kwargs)
+ gc_step = module[0].grad_checkpointing_step
+ return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
+ return module(*args, **kwargs)
diff --git a/videosys/models/open_sora/vae.py b/videosys/models/open_sora/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3f92292c4b266f76554a225f55c67c379c3cd1
--- /dev/null
+++ b/videosys/models/open_sora/vae.py
@@ -0,0 +1,769 @@
+# Adapted from OpenSora
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# OpenSora: https://github.com/hpcaitech/Open-Sora
+# --------------------------------------------------------
+
+import os
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
+from einops import rearrange
+from transformers import PretrainedConfig, PreTrainedModel
+
+from .utils import load_checkpoint
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(
+ self,
+ parameters,
+ deterministic=False,
+ ):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype)
+
+ def sample(self):
+ # torch.randn: standard normal distribution
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None: # SCH: assumes other is a standard normal distribution
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3, 4],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3, 4]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def pad_at_dim(t, pad, dim=-1):
+ dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = (0, 0) * dims_from_right
+ return F.pad(t, (*zeros, *pad), mode="constant")
+
+
+def exists(v):
+ return v is not None
+
+
+class CausalConv3d(nn.Module):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ pad_mode="constant",
+ strides=None, # allow custom stride
+ **kwargs,
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ dilation = kwargs.pop("dilation", 1)
+ stride = strides[0] if strides is not None else kwargs.pop("stride", 1)
+
+ self.pad_mode = pad_mode
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+ stride = strides if strides is not None else (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv(x)
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels, # SCH: added
+ filters,
+ conv_fn,
+ activation_fn=nn.SiLU,
+ use_conv_shortcut=False,
+ num_groups=32,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.filters = filters
+ self.activate = activation_fn()
+ self.use_conv_shortcut = use_conv_shortcut
+
+ # SCH: MAGVIT uses GroupNorm by default
+ self.norm1 = nn.GroupNorm(num_groups, in_channels)
+ self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
+ self.norm2 = nn.GroupNorm(num_groups, self.filters)
+ self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
+ if in_channels != filters:
+ if self.use_conv_shortcut:
+ self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
+ else:
+ self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
+
+ def forward(self, x):
+ residual = x
+ x = self.norm1(x)
+ x = self.activate(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.activate(x)
+ x = self.conv2(x)
+ if self.in_channels != self.filters: # SCH: ResBlock X->Y
+ residual = self.conv3(residual)
+ return x + residual
+
+
+def get_activation_fn(activation):
+ if activation == "relu":
+ activation_fn = nn.ReLU
+ elif activation == "swish":
+ activation_fn = nn.SiLU
+ else:
+ raise NotImplementedError
+ return activation_fn
+
+
+class Encoder(nn.Module):
+ """Encoder Blocks."""
+
+ def __init__(
+ self,
+ in_out_channels=4,
+ latent_embed_dim=512, # num channels for latent vector
+ filters=128,
+ num_res_blocks=4,
+ channel_multipliers=(1, 2, 2, 4),
+ temporal_downsample=(False, True, True),
+ num_groups=32, # for nn.GroupNorm
+ activation_fn="swish",
+ ):
+ super().__init__()
+ self.filters = filters
+ self.num_res_blocks = num_res_blocks
+ self.num_blocks = len(channel_multipliers)
+ self.channel_multipliers = channel_multipliers
+ self.temporal_downsample = temporal_downsample
+ self.num_groups = num_groups
+ self.embedding_dim = latent_embed_dim
+
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.activate = self.activation_fn()
+ self.conv_fn = CausalConv3d
+ self.block_args = dict(
+ conv_fn=self.conv_fn,
+ activation_fn=self.activation_fn,
+ use_conv_shortcut=False,
+ num_groups=self.num_groups,
+ )
+
+ # first layer conv
+ self.conv_in = self.conv_fn(
+ in_out_channels,
+ filters,
+ kernel_size=(3, 3, 3),
+ bias=False,
+ )
+
+ # ResBlocks and conv downsample
+ self.block_res_blocks = nn.ModuleList([])
+ self.conv_blocks = nn.ModuleList([])
+
+ filters = self.filters
+ prev_filters = filters # record for in_channels
+ for i in range(self.num_blocks):
+ filters = self.filters * self.channel_multipliers[i]
+ block_items = nn.ModuleList([])
+ for _ in range(self.num_res_blocks):
+ block_items.append(ResBlock(prev_filters, filters, **self.block_args))
+ prev_filters = filters # update in_channels
+ self.block_res_blocks.append(block_items)
+
+ if i < self.num_blocks - 1:
+ if self.temporal_downsample[i]:
+ t_stride = 2 if self.temporal_downsample[i] else 1
+ s_stride = 1
+ self.conv_blocks.append(
+ self.conv_fn(
+ prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)
+ )
+ )
+ prev_filters = filters # update in_channels
+ else:
+ # if no t downsample, don't add since this does nothing for pipeline models
+ self.conv_blocks.append(nn.Identity(prev_filters)) # Identity
+ prev_filters = filters # update in_channels
+
+ # last layer res block
+ self.res_blocks = nn.ModuleList([])
+ for _ in range(self.num_res_blocks):
+ self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
+ prev_filters = filters # update in_channels
+
+ # MAGVIT uses Group Normalization
+ self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
+
+ self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")
+
+ def forward(self, x):
+ x = self.conv_in(x)
+
+ for i in range(self.num_blocks):
+ for j in range(self.num_res_blocks):
+ x = self.block_res_blocks[i][j](x)
+ if i < self.num_blocks - 1:
+ x = self.conv_blocks[i](x)
+ for i in range(self.num_res_blocks):
+ x = self.res_blocks[i](x)
+
+ x = self.norm1(x)
+ x = self.activate(x)
+ x = self.conv2(x)
+ return x
+
+
+class Decoder(nn.Module):
+ """Decoder Blocks."""
+
+ def __init__(
+ self,
+ in_out_channels=4,
+ latent_embed_dim=512,
+ filters=128,
+ num_res_blocks=4,
+ channel_multipliers=(1, 2, 2, 4),
+ temporal_downsample=(False, True, True),
+ num_groups=32, # for nn.GroupNorm
+ activation_fn="swish",
+ ):
+ super().__init__()
+ self.filters = filters
+ self.num_res_blocks = num_res_blocks
+ self.num_blocks = len(channel_multipliers)
+ self.channel_multipliers = channel_multipliers
+ self.temporal_downsample = temporal_downsample
+ self.num_groups = num_groups
+ self.embedding_dim = latent_embed_dim
+ self.s_stride = 1
+
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.activate = self.activation_fn()
+ self.conv_fn = CausalConv3d
+ self.block_args = dict(
+ conv_fn=self.conv_fn,
+ activation_fn=self.activation_fn,
+ use_conv_shortcut=False,
+ num_groups=self.num_groups,
+ )
+
+ filters = self.filters * self.channel_multipliers[-1]
+ prev_filters = filters
+
+ # last conv
+ self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)
+
+ # last layer res block
+ self.res_blocks = nn.ModuleList([])
+ for _ in range(self.num_res_blocks):
+ self.res_blocks.append(ResBlock(filters, filters, **self.block_args))
+
+ # ResBlocks and conv upsample
+ self.block_res_blocks = nn.ModuleList([])
+ self.num_blocks = len(self.channel_multipliers)
+ self.conv_blocks = nn.ModuleList([])
+ # reverse to keep track of the in_channels, but append also in a reverse direction
+ for i in reversed(range(self.num_blocks)):
+ filters = self.filters * self.channel_multipliers[i]
+ # resblock handling
+ block_items = nn.ModuleList([])
+ for _ in range(self.num_res_blocks):
+ block_items.append(ResBlock(prev_filters, filters, **self.block_args))
+ prev_filters = filters # SCH: update in_channels
+ self.block_res_blocks.insert(0, block_items) # SCH: append in front
+
+ # conv blocks with upsampling
+ if i > 0:
+ if self.temporal_downsample[i - 1]:
+ t_stride = 2 if self.temporal_downsample[i - 1] else 1
+ # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
+ self.conv_blocks.insert(
+ 0,
+ self.conv_fn(
+ prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
+ ),
+ )
+ else:
+ self.conv_blocks.insert(
+ 0,
+ nn.Identity(prev_filters),
+ )
+
+ self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
+
+ self.conv_out = self.conv_fn(filters, in_out_channels, 3)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ for i in range(self.num_res_blocks):
+ x = self.res_blocks[i](x)
+ for i in reversed(range(self.num_blocks)):
+ for j in range(self.num_res_blocks):
+ x = self.block_res_blocks[i][j](x)
+ if i > 0:
+ t_stride = 2 if self.temporal_downsample[i - 1] else 1
+ x = self.conv_blocks[i - 1](x)
+ x = rearrange(
+ x,
+ "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
+ ts=t_stride,
+ hs=self.s_stride,
+ ws=self.s_stride,
+ )
+
+ x = self.norm1(x)
+ x = self.activate(x)
+ x = self.conv_out(x)
+ return x
+
+
+class VAE_Temporal(nn.Module):
+ def __init__(
+ self,
+ in_out_channels=4,
+ latent_embed_dim=4,
+ embed_dim=4,
+ filters=128,
+ num_res_blocks=4,
+ channel_multipliers=(1, 2, 2, 4),
+ temporal_downsample=(True, True, False),
+ num_groups=32, # for nn.GroupNorm
+ activation_fn="swish",
+ ):
+ super().__init__()
+
+ self.time_downsample_factor = 2 ** sum(temporal_downsample)
+ # self.time_padding = self.time_downsample_factor - 1
+ self.patch_size = (self.time_downsample_factor, 1, 1)
+ self.out_channels = in_out_channels
+
+ # NOTE: following MAGVIT, conv in bias=False in encoder first conv
+ self.encoder = Encoder(
+ in_out_channels=in_out_channels,
+ latent_embed_dim=latent_embed_dim * 2,
+ filters=filters,
+ num_res_blocks=num_res_blocks,
+ channel_multipliers=channel_multipliers,
+ temporal_downsample=temporal_downsample,
+ num_groups=num_groups, # for nn.GroupNorm
+ activation_fn=activation_fn,
+ )
+ self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)
+
+ self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
+ self.decoder = Decoder(
+ in_out_channels=in_out_channels,
+ latent_embed_dim=latent_embed_dim,
+ filters=filters,
+ num_res_blocks=num_res_blocks,
+ channel_multipliers=channel_multipliers,
+ temporal_downsample=temporal_downsample,
+ num_groups=num_groups, # for nn.GroupNorm
+ activation_fn=activation_fn,
+ )
+
+ def get_latent_size(self, input_size):
+ latent_size = []
+ for i in range(3):
+ if input_size[i] is None:
+ lsize = None
+ elif i == 0:
+ time_padding = (
+ 0
+ if (input_size[i] % self.time_downsample_factor == 0)
+ else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
+ )
+ lsize = (input_size[i] + time_padding) // self.patch_size[i]
+ else:
+ lsize = input_size[i] // self.patch_size[i]
+ latent_size.append(lsize)
+ return latent_size
+
+ def encode(self, x):
+ time_padding = (
+ 0
+ if (x.shape[2] % self.time_downsample_factor == 0)
+ else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
+ )
+ x = pad_at_dim(x, (time_padding, 0), dim=2)
+ encoded_feature = self.encoder(x)
+ moments = self.quant_conv(encoded_feature).to(x.dtype)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, num_frames=None):
+ time_padding = (
+ 0
+ if (num_frames % self.time_downsample_factor == 0)
+ else self.time_downsample_factor - num_frames % self.time_downsample_factor
+ )
+ z = self.post_quant_conv(z)
+ x = self.decoder(z)
+ x = x[:, :, time_padding:]
+ return x
+
+ def forward(self, x, sample_posterior=True):
+ posterior = self.encode(x)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ recon_video = self.decode(z, num_frames=x.shape[2])
+ return recon_video, posterior, z
+
+
+def VAE_Temporal_SD(from_pretrained=None, **kwargs):
+ model = VAE_Temporal(
+ in_out_channels=4,
+ latent_embed_dim=4,
+ embed_dim=4,
+ filters=128,
+ num_res_blocks=4,
+ channel_multipliers=(1, 2, 2, 4),
+ temporal_downsample=(False, True, True),
+ **kwargs,
+ )
+ if from_pretrained is not None:
+ load_checkpoint(model, from_pretrained)
+ return model
+
+
+class VideoAutoencoderKL(nn.Module):
+ def __init__(
+ self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
+ ):
+ super().__init__()
+ self.module = AutoencoderKL.from_pretrained(
+ from_pretrained,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ )
+ self.out_channels = self.module.config.latent_channels
+ self.patch_size = (1, 8, 8)
+ self.micro_batch_size = micro_batch_size
+
+ def encode(self, x):
+ # x: (B, C, T, H, W)
+ B = x.shape[0]
+ x = rearrange(x, "B C T H W -> (B T) C H W")
+
+ if self.micro_batch_size is None:
+ x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
+ else:
+ # NOTE: cannot be used for training
+ bs = self.micro_batch_size
+ x_out = []
+ for i in range(0, x.shape[0], bs):
+ x_bs = x[i : i + bs]
+ x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
+ x_out.append(x_bs)
+ x = torch.cat(x_out, dim=0)
+ x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
+ return x
+
+ def decode(self, x, **kwargs):
+ # x: (B, C, T, H, W)
+ B = x.shape[0]
+ x = rearrange(x, "B C T H W -> (B T) C H W")
+ if self.micro_batch_size is None:
+ x = self.module.decode(x / 0.18215).sample
+ else:
+ # NOTE: cannot be used for training
+ bs = self.micro_batch_size
+ x_out = []
+ for i in range(0, x.shape[0], bs):
+ x_bs = x[i : i + bs]
+ x_bs = self.module.decode(x_bs / 0.18215).sample
+ x_out.append(x_bs)
+ x = torch.cat(x_out, dim=0)
+ x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
+ return x
+
+ def get_latent_size(self, input_size):
+ latent_size = []
+ for i in range(3):
+ # assert (
+ # input_size[i] is None or input_size[i] % self.patch_size[i] == 0
+ # ), "Input size must be divisible by patch size"
+ latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
+ return latent_size
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class VideoAutoencoderKLTemporalDecoder(nn.Module):
+ def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
+ super().__init__()
+ self.module = AutoencoderKLTemporalDecoder.from_pretrained(
+ from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
+ )
+ self.out_channels = self.module.config.latent_channels
+ self.patch_size = (1, 8, 8)
+
+ def encode(self, x):
+ raise NotImplementedError
+
+ def decode(self, x, **kwargs):
+ B, _, T = x.shape[:3]
+ x = rearrange(x, "B C T H W -> (B T) C H W")
+ x = self.module.decode(x / 0.18215, num_frames=T).sample
+ x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
+ return x
+
+ def get_latent_size(self, input_size):
+ latent_size = []
+ for i in range(3):
+ # assert (
+ # input_size[i] is None or input_size[i] % self.patch_size[i] == 0
+ # ), "Input size must be divisible by patch size"
+ latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
+ return latent_size
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class VideoAutoencoderPipelineConfig(PretrainedConfig):
+ model_type = "VideoAutoencoderPipeline"
+
+ def __init__(
+ self,
+ vae_2d=None,
+ vae_temporal=None,
+ from_pretrained=None,
+ freeze_vae_2d=False,
+ cal_loss=False,
+ micro_frame_size=None,
+ shift=0.0,
+ scale=1.0,
+ **kwargs,
+ ):
+ self.vae_2d = vae_2d
+ self.vae_temporal = vae_temporal
+ self.from_pretrained = from_pretrained
+ self.freeze_vae_2d = freeze_vae_2d
+ self.cal_loss = cal_loss
+ self.micro_frame_size = micro_frame_size
+ self.shift = shift
+ self.scale = scale
+ super().__init__(**kwargs)
+
+
+class VideoAutoencoderPipeline(PreTrainedModel):
+ config_class = VideoAutoencoderPipelineConfig
+
+ def __init__(self, config: VideoAutoencoderPipelineConfig):
+ super().__init__(config=config)
+ self.spatial_vae = VideoAutoencoderKL(
+ from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+ local_files_only=False,
+ micro_batch_size=4,
+ subfolder="vae",
+ )
+ self.temporal_vae = VAE_Temporal_SD(from_pretrained=None)
+ self.cal_loss = config.cal_loss
+ self.micro_frame_size = config.micro_frame_size
+ self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
+
+ if config.freeze_vae_2d:
+ for param in self.spatial_vae.parameters():
+ param.requires_grad = False
+
+ self.out_channels = self.temporal_vae.out_channels
+
+ # normalization parameters
+ scale = torch.tensor(config.scale)
+ shift = torch.tensor(config.shift)
+ if len(scale.shape) > 0:
+ scale = scale[None, :, None, None, None]
+ if len(shift.shape) > 0:
+ shift = shift[None, :, None, None, None]
+ self.register_buffer("scale", scale)
+ self.register_buffer("shift", shift)
+
+ def encode(self, x):
+ x_z = self.spatial_vae.encode(x)
+
+ if self.micro_frame_size is None:
+ posterior = self.temporal_vae.encode(x_z)
+ z = posterior.sample()
+ else:
+ z_list = []
+ for i in range(0, x_z.shape[2], self.micro_frame_size):
+ x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
+ posterior = self.temporal_vae.encode(x_z_bs)
+ z_list.append(posterior.sample())
+ z = torch.cat(z_list, dim=2)
+
+ if self.cal_loss:
+ return z, posterior, x_z
+ else:
+ return (z - self.shift) / self.scale
+
+ def decode(self, z, num_frames=None):
+ if not self.cal_loss:
+ z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
+
+ if self.micro_frame_size is None:
+ x_z = self.temporal_vae.decode(z, num_frames=num_frames)
+ x = self.spatial_vae.decode(x_z)
+ else:
+ x_z_list = []
+ for i in range(0, z.size(2), self.micro_z_frame_size):
+ z_bs = z[:, :, i : i + self.micro_z_frame_size]
+ x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
+ x_z_list.append(x_z_bs)
+ num_frames -= self.micro_frame_size
+ x_z = torch.cat(x_z_list, dim=2)
+ x = self.spatial_vae.decode(x_z)
+
+ if self.cal_loss:
+ return x, x_z
+ else:
+ return x
+
+ def forward(self, x):
+ assert self.cal_loss, "This method is only available when cal_loss is True"
+ z, posterior, x_z = self.encode(x)
+ x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
+ return x_rec, x_z_rec, z, posterior, x_z
+
+ def get_latent_size(self, input_size):
+ if self.micro_frame_size is None or input_size[0] is None:
+ return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
+ else:
+ sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
+ sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
+ sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
+ remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
+ if remain_temporal_size[0] > 0:
+ remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
+ sub_latent_size[0] += remain_size[0]
+ return sub_latent_size
+
+ def get_temporal_last_layer(self):
+ return self.temporal_vae.decoder.conv_out.conv.weight
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+def OpenSoraVAE_V1_2(
+ micro_batch_size=4,
+ micro_frame_size=17,
+ from_pretrained=None,
+ local_files_only=False,
+ freeze_vae_2d=False,
+ cal_loss=False,
+):
+ vae_2d = dict(
+ type="VideoAutoencoderKL",
+ from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+ subfolder="vae",
+ micro_batch_size=micro_batch_size,
+ local_files_only=local_files_only,
+ )
+ vae_temporal = dict(
+ type="VAE_Temporal_SD",
+ from_pretrained=None,
+ )
+ shift = (-0.10, 0.34, 0.27, 0.98)
+ scale = (3.85, 2.32, 2.33, 3.06)
+ kwargs = dict(
+ vae_2d=vae_2d,
+ vae_temporal=vae_temporal,
+ freeze_vae_2d=freeze_vae_2d,
+ cal_loss=cal_loss,
+ micro_frame_size=micro_frame_size,
+ shift=shift,
+ scale=scale,
+ )
+
+ if from_pretrained is not None and not os.path.isdir(from_pretrained):
+ model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
+ else:
+ config = VideoAutoencoderPipelineConfig(**kwargs)
+ model = VideoAutoencoderPipeline(config)
+
+ if from_pretrained:
+ load_checkpoint(model, from_pretrained)
+ return model
diff --git a/videosys/models/open_sora_plan/__init__.py b/videosys/models/open_sora_plan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c19791336c70c4e58bae7a5bb02d0e74f36daf1f
--- /dev/null
+++ b/videosys/models/open_sora_plan/__init__.py
@@ -0,0 +1,7 @@
+from .pipeline import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+
+__all__ = [
+ "OpenSoraPlanPipeline",
+ "OpenSoraPlanConfig",
+ "OpenSoraPlanPABConfig",
+]
diff --git a/videosys/models/open_sora_plan/ae.py b/videosys/models/open_sora_plan/ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd023d44fbd91c65e1d0e4092f58d6cb8ae87e5a
--- /dev/null
+++ b/videosys/models/open_sora_plan/ae.py
@@ -0,0 +1,857 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import glob
+import importlib
+import os
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers import ConfigMixin, ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from einops import rearrange
+from torch import nn
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+def tensor_to_video(x):
+ x = x.detach().cpu()
+ x = torch.clamp(x, -1, 1)
+ x = (x + 1) / 2
+ x = x.permute(1, 0, 2, 3).float().numpy() # c t h w ->
+ x = (255 * x).astype(np.uint8)
+ return x
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def resolve_str_to_obj(str_val, append=True):
+ if append:
+ str_val = "videosys.models.open_sora_plan.modules." + str_val
+ if "opensora.models.ae.videobase." in str_val:
+ str_val = str_val.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
+ module_name, class_name = str_val.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name)
+
+
+class VideoBaseAE_PL(ModelMixin, ConfigMixin):
+ config_name = "config.json"
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def encode(self, x: torch.Tensor, *args, **kwargs):
+ pass
+
+ def decode(self, encoding: torch.Tensor, *args, **kwargs):
+ pass
+
+ @property
+ def num_training_steps(self) -> int:
+ """Total training steps inferred from datamodule and devices."""
+ if self.trainer.max_steps:
+ return self.trainer.max_steps
+
+ limit_batches = self.trainer.limit_train_batches
+ batches = len(self.train_dataloader())
+ batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
+
+ num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
+ if self.trainer.tpu_cores:
+ num_devices = max(num_devices, self.trainer.tpu_cores)
+
+ effective_accum = self.trainer.accumulate_grad_batches * num_devices
+ return (batches // effective_accum) * self.trainer.max_epochs
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
+ if ckpt_files:
+ # Adapt to PyTorch Lightning
+ last_ckpt_file = ckpt_files[-1]
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ model = cls.from_config(config_file)
+ print("init from {}".format(last_ckpt_file))
+ model.init_from_ckpt(last_ckpt_file)
+ return model
+ else:
+ print(f"Loading model from {pretrained_model_name_or_path}")
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in: str = "Conv2d",
+ conv_out: str = "CasualConv3d",
+ attention: str = "AttnBlock",
+ resnet_blocks: Tuple[str] = (
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock2D",
+ "ResnetBlock3D",
+ ),
+ spatial_downsample: Tuple[str] = (
+ "Downsample",
+ "Downsample",
+ "Downsample",
+ "",
+ ),
+ temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
+ mid_resnet: str = "ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ double_z: bool = True,
+ ) -> None:
+ super().__init__()
+ assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks)
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)
+
+ # ---- Downsample ----
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(hidden_size_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = hidden_size * in_ch_mult[i_level]
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if spatial_downsample[i_level]:
+ down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
+ curr_res = curr_res // 2
+ if temporal_downsample[i_level]:
+ down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
+ self.down.append(down)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if hasattr(self.down[i_level], "downsample"):
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ if hasattr(self.down[i_level], "time_downsample"):
+ hs_down = self.down[i_level].time_downsample(hs[-1])
+ hs.append(hs_down)
+
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ hidden_size: int,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = (16,),
+ conv_in: str = "Conv2d",
+ conv_out: str = "CasualConv3d",
+ attention: str = "AttnBlock",
+ resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ spatial_upsample: Tuple[str] = (
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
+ mid_resnet: str = "ResnetBlock3D",
+ dropout: float = 0.0,
+ resolution: int = 256,
+ num_res_blocks: int = 2,
+ ):
+ super().__init__()
+ # ---- Config ----
+ self.num_resolutions = len(hidden_size_mult)
+ self.resolution = resolution
+ self.num_res_blocks = num_res_blocks
+
+ # ---- In ----
+ block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)
+
+ # ---- Mid ----
+ self.mid = nn.Module()
+ self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
+ self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
+ in_channels=block_in,
+ out_channels=block_in,
+ dropout=dropout,
+ )
+
+ # ---- Upsample ----
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = hidden_size * hidden_size_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ resolve_str_to_obj(resnet_blocks[i_level])(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(resolve_str_to_obj(attention)(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if spatial_upsample[i_level]:
+ up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
+ curr_res = curr_res * 2
+ if temporal_upsample[i_level]:
+ up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
+ self.up.insert(0, up)
+
+ # ---- Out ----
+ self.norm_out = Normalize(block_in)
+ self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)
+
+ def forward(self, z):
+ h = self.conv_in(z)
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if hasattr(self.up[i_level], "upsample"):
+ h = self.up[i_level].upsample(h)
+ if hasattr(self.up[i_level], "time_upsample"):
+ h = self.up[i_level].time_upsample(h)
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class CausalVAEModel(VideoBaseAE_PL):
+ @register_to_config
+ def __init__(
+ self,
+ lr: float = 1e-5,
+ hidden_size: int = 128,
+ z_channels: int = 4,
+ hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
+ attn_resolutions: Tuple[int] = [],
+ dropout: float = 0.0,
+ resolution: int = 256,
+ double_z: bool = True,
+ embed_dim: int = 4,
+ num_res_blocks: int = 2,
+ loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
+ loss_params: dict = {
+ "kl_weight": 0.000001,
+ "logvar_init": 0.0,
+ "disc_start": 2001,
+ "disc_weight": 0.5,
+ },
+ q_conv: str = "CausalConv3d",
+ encoder_conv_in: str = "CausalConv3d",
+ encoder_conv_out: str = "CausalConv3d",
+ encoder_attention: str = "AttnBlock3D",
+ encoder_resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ encoder_spatial_downsample: Tuple[str] = (
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "SpatialDownsample2x",
+ "",
+ ),
+ encoder_temporal_downsample: Tuple[str] = (
+ "",
+ "TimeDownsample2x",
+ "TimeDownsample2x",
+ "",
+ ),
+ encoder_mid_resnet: str = "ResnetBlock3D",
+ decoder_conv_in: str = "CausalConv3d",
+ decoder_conv_out: str = "CausalConv3d",
+ decoder_attention: str = "AttnBlock3D",
+ decoder_resnet_blocks: Tuple[str] = (
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ "ResnetBlock3D",
+ ),
+ decoder_spatial_upsample: Tuple[str] = (
+ "",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ "SpatialUpsample2x",
+ ),
+ decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
+ decoder_mid_resnet: str = "ResnetBlock3D",
+ ) -> None:
+ super().__init__()
+ self.tile_sample_min_size = 256
+ self.tile_sample_min_size_t = 65
+ self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
+ t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
+ self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
+ self.tile_overlap_factor = 0.25
+ self.use_tiling = False
+
+ self.learning_rate = lr
+ self.lr_g_factor = 1.0
+
+ self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)
+
+ self.encoder = Encoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=encoder_conv_in,
+ conv_out=encoder_conv_out,
+ attention=encoder_attention,
+ resnet_blocks=encoder_resnet_blocks,
+ spatial_downsample=encoder_spatial_downsample,
+ temporal_downsample=encoder_temporal_downsample,
+ mid_resnet=encoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ double_z=double_z,
+ )
+
+ self.decoder = Decoder(
+ z_channels=z_channels,
+ hidden_size=hidden_size,
+ hidden_size_mult=hidden_size_mult,
+ attn_resolutions=attn_resolutions,
+ conv_in=decoder_conv_in,
+ conv_out=decoder_conv_out,
+ attention=decoder_attention,
+ resnet_blocks=decoder_resnet_blocks,
+ spatial_upsample=decoder_spatial_upsample,
+ temporal_upsample=decoder_temporal_upsample,
+ mid_resnet=decoder_mid_resnet,
+ dropout=dropout,
+ resolution=resolution,
+ num_res_blocks=num_res_blocks,
+ )
+
+ quant_conv_cls = resolve_str_to_obj(q_conv)
+ self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
+ self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
+ if hasattr(self.loss, "discriminator"):
+ self.automatic_optimization = False
+
+ def encode(self, x):
+ if self.use_tiling and (
+ x.shape[-1] > self.tile_sample_min_size
+ or x.shape[-2] > self.tile_sample_min_size
+ or x.shape[-3] > self.tile_sample_min_size_t
+ ):
+ return self.tiled_encode(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ if self.use_tiling and (
+ z.shape[-1] > self.tile_latent_min_size
+ or z.shape[-2] > self.tile_latent_min_size
+ or z.shape[-3] > self.tile_latent_min_size_t
+ ):
+ return self.tiled_decode(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx):
+ if hasattr(self.loss, "discriminator"):
+ return self._training_step_gan(batch, batch_idx=batch_idx)
+ else:
+ return self._training_step(batch, batch_idx=batch_idx)
+
+ def _training_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ split="train",
+ )
+ self.log(
+ "aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def _training_step_gan(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ reconstructions, posterior = self(inputs)
+ opt1, opt2 = self.optimizers()
+
+ # ---- AE Loss ----
+ aeloss, log_dict_ae = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log(
+ "aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ opt1.zero_grad()
+ self.manual_backward(aeloss)
+ self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
+ opt1.step()
+ # ---- GAN Loss ----
+ discloss, log_dict_disc = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log(
+ "discloss",
+ discloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ opt2.zero_grad()
+ self.manual_backward(discloss)
+ self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
+ opt2.step()
+ self.log_dict(
+ {**log_dict_ae, **log_dict_disc},
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ def configure_optimizers(self):
+ from itertools import chain
+
+ lr = self.learning_rate
+ modules_to_train = [
+ self.encoder.named_parameters(),
+ self.decoder.named_parameters(),
+ self.post_quant_conv.named_parameters(),
+ self.quant_conv.named_parameters(),
+ ]
+ params_with_time = []
+ params_without_time = []
+ for name, param in chain(*modules_to_train):
+ if "time" in name:
+ params_with_time.append(param)
+ else:
+ params_without_time.append(param)
+ optimizers = []
+ opt_ae = torch.optim.Adam(
+ [
+ {"params": params_with_time, "lr": lr},
+ {"params": params_without_time, "lr": lr},
+ ],
+ lr=lr,
+ betas=(0.5, 0.9),
+ )
+ optimizers.append(opt_ae)
+
+ if hasattr(self.loss, "discriminator"):
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
+ optimizers.append(opt_disc)
+
+ return optimizers, []
+
+ def get_last_layer(self):
+ if hasattr(self.decoder.conv_out, "conv"):
+ return self.decoder.conv_out.conv.weight
+ else:
+ return self.decoder.conv_out.weight
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x):
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ moments = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
+ else:
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)
+ moments.append(moment)
+ moments = torch.cat(moments, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def tiled_decode(self, x):
+ t = x.shape[2]
+ t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
+ t_chunk_start_end = [[0, t]]
+ else:
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
+ if t_chunk_start_end[-1][-1] > t:
+ t_chunk_start_end[-1][-1] = t
+ elif t_chunk_start_end[-1][-1] < t:
+ last_start_end = [t_chunk_idx[-1], t]
+ t_chunk_start_end.append(last_start_end)
+ dec_ = []
+ for idx, (start, end) in enumerate(t_chunk_start_end):
+ chunk_x = x[:, :, start:end]
+ if idx != 0:
+ dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
+ else:
+ dec = self.tiled_decode2d(chunk_x)
+ dec_.append(dec)
+ dec_ = torch.cat(dec_, dim=2)
+ return dec_
+
+ def tiled_encode2d(self, x, return_moments=False):
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[3], overlap_size):
+ row = []
+ for j in range(0, x.shape[4], overlap_size):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_size,
+ j : j + self.tile_sample_min_size,
+ ]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ moments = torch.cat(result_rows, dim=3)
+ posterior = DiagonalGaussianDistribution(moments)
+ if return_moments:
+ return moments
+ return posterior
+
+ def tiled_decode2d(self, z):
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[3], overlap_size):
+ row = []
+ for j in range(0, z.shape[4], overlap_size):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + self.tile_latent_min_size,
+ j : j + self.tile_latent_min_size,
+ ]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+ return dec
+
+ def enable_tiling(self, use_tiling: bool = True):
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ self.enable_tiling(False)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False):
+ sd = torch.load(path, map_location="cpu")
+ print("init from " + path)
+ if "state_dict" in sd:
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, "video")
+ latents = self.encode(inputs).sample()
+ video_recon = self.decode(latents)
+ for idx in range(len(video_recon)):
+ self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10])
+
+
+class CausalVAEModelWrapper(nn.Module):
+ def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
+ super(CausalVAEModelWrapper, self).__init__()
+ # if os.path.exists(ckpt):
+ # self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
+ self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)
+
+ def encode(self, x): # b c t h w
+ # x = self.vae.encode(x).sample()
+ x = self.vae.encode(x).sample().mul_(0.18215)
+ return x
+
+ def decode(self, x):
+ # x = self.vae.decode(x)
+ x = self.vae.decode(x / 0.18215)
+ x = rearrange(x, "b c t h w -> b t c h w").contiguous()
+ return x
+
+ def dtype(self):
+ return self.vae.dtype
+
+ #
+ # def device(self):
+ # return self.vae.device
+
+
+videobase_ae_stride = {
+ "CausalVAEModel_4x8x8": [4, 8, 8],
+}
+
+videobase_ae_channel = {
+ "CausalVAEModel_4x8x8": 4,
+}
+
+videobase_ae = {
+ "CausalVAEModel_4x8x8": CausalVAEModelWrapper,
+}
+
+
+ae_stride_config = {}
+ae_stride_config.update(videobase_ae_stride)
+
+ae_channel_config = {}
+ae_channel_config.update(videobase_ae_channel)
+
+
+def getae_wrapper(ae):
+ """deprecation"""
+ ae = videobase_ae.get(ae, None)
+ assert ae is not None
+ return ae
diff --git a/videosys/models/open_sora_plan/latte.py b/videosys/models/open_sora_plan/latte.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa0fe146b72a50039797b8d33edaae224fe5cd2
--- /dev/null
+++ b/videosys/models/open_sora_plan/latte.py
@@ -0,0 +1,2835 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+
+import json
+import os
+from dataclasses import dataclass
+from functools import partial
+from importlib import import_module
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ AttnProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+ LoRAAttnAddedKVProcessor,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ SlicedAttnProcessor,
+ SpatialNorm,
+ XFormersAttnAddedKVProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange, repeat
+from torch import nn
+
+from videosys.core.comm import (
+ all_to_all_with_pad,
+ gather_sequence,
+ get_spatial_pad,
+ get_temporal_pad,
+ set_spatial_pad,
+ set_temporal_pad,
+ split_sequence,
+)
+from videosys.core.pab_mgr import (
+ enable_pab,
+ get_mlp_output,
+ if_broadcast_cross,
+ if_broadcast_mlp,
+ if_broadcast_spatial,
+ if_broadcast_temporal,
+ save_mlp_output,
+)
+from videosys.core.parallel_mgr import (
+ enable_sequence_parallel,
+ get_cfg_parallel_group,
+ get_cfg_parallel_size,
+ get_sequence_parallel_group,
+)
+from videosys.utils.logging import logger
+from videosys.utils.utils import batch_func
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+SPATIAL_LIST = []
+TEMPROAL_LIST = []
+CROSS_LIST = []
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed(embed_dim, length, interpolation_scale=1.0, base_size=16):
+ pos = torch.arange(0, length).unsqueeze(1) / interpolation_scale
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
+ return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class RoPE2D(torch.nn.Module):
+ def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.scaling_factor = scaling_factor
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+ * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
+ D = tokens.size(3) // 2
+ assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
+ x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
+
+
+class LinearScalingRoPE2D(RoPE2D):
+ """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148"""
+
+ def forward(self, tokens, positions):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ dtype = positions.dtype
+ positions = positions.float() / self.scaling_factor
+ positions = positions.to(dtype)
+ tokens = super().forward(tokens, positions)
+ return tokens
+
+
+class RoPE1D(torch.nn.Module):
+ def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.scaling_factor = scaling_factor
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens (t position of each token)
+ output:
+ * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ D = tokens.size(3)
+ assert positions.ndim == 2 # Batch, Seq
+ cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+ tokens = self.apply_rope1d(tokens, positions, cos, sin)
+ return tokens
+
+
+class LinearScalingRoPE1D(RoPE1D):
+ """Code from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L148"""
+
+ def forward(self, tokens, positions):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ dtype = positions.dtype
+ positions = positions.float() / self.scaling_factor
+ positions = positions.to(dtype)
+ tokens = super().forward(tokens, positions)
+ return tokens
+
+
+class PositionGetter2D(object):
+ """return positions of patches"""
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if not (h, w) in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2)
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
+ return pos
+
+
+class PositionGetter1D(object):
+ """return positions of patches"""
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, l, device):
+ if not (l) in self.cache_positions:
+ x = torch.arange(l, device=device)
+ self.cache_positions[l] = x # (l, )
+ pos = self.cache_positions[l].view(1, l).expand(b, -1).clone()
+ return pos
+
+
+class CombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.use_additional_conditions = True
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
+ if size.ndim == 1:
+ size = size[:, None]
+
+ if size.shape[0] != batch_size:
+ size = size.repeat(batch_size // size.shape[0], 1)
+ if size.shape[0] != batch_size:
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
+
+ current_batch_size, dims = size.shape[0], size.shape[1]
+ size = size.reshape(-1)
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
+
+ size_emb = embedder(size_freq)
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
+ return size_emb
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
+ aspect_ratio = self.apply_condition(
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
+ )
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class CaptionProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, num_tokens=120):
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ self.act_1 = nn.GELU(approximate="tanh")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
+
+ def forward(self, caption, force_drop_ids=None):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ # raise ValueError
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+ return (latent + pos_embed).to(latent.dtype)
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.use_rope = use_rope
+ self.rope_scaling = rope_scaling
+ self.compress_kv_factor = compress_kv_factor
+
+ # we make use of this private variable to know whether this class is loaded
+ # with an deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ if USE_PEFT_BACKEND:
+ linear_cls = nn.Linear
+ else:
+ linear_cls = LoRACompatibleLinear
+
+ assert not (
+ self.use_rope and (self.compress_kv_factor is not None)
+ ), "Can not both enable compressing kv and using rope"
+ if self.compress_kv_factor is not None:
+ self._init_compress()
+
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0(
+ self.inner_dim,
+ attention_mode,
+ use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ r"""
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ is_lora = hasattr(self, "processor")
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ LoRAAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_lora:
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
+ processor = LoRAXFormersAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ # throw warning
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_lora:
+ attn_processor_class = (
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+ )
+ processor = attn_processor_class(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ attn_processor_class = (
+ CustomDiffusionAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else CustomDiffusionAttnProcessor
+ )
+ processor = attn_processor_class(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ r"""
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ _remove_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to remove LoRA layers from the model.
+ """
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ deprecate(
+ "set_processor to offload LoRA",
+ "0.26.0",
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+ )
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
+ # We need to remove all LoRA layers
+ # Don't forget to remove ALL `_remove_lora` from the codebase
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False):
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
+ # with PEFT is completed.
+ is_lora_activated = {
+ name: module.lora_layer is not None
+ for name, module in self.named_modules()
+ if hasattr(module, "lora_layer")
+ }
+
+ # 1. if no layer has a LoRA activated we can return the processor as usual
+ if not any(is_lora_activated.values()):
+ return self.processor
+
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
+ is_lora_activated.pop("add_k_proj", None)
+ is_lora_activated.pop("add_v_proj", None)
+ # 2. else it is not posssible that only some layers have LoRA activated
+ if not all(is_lora_activated.values()):
+ raise ValueError(
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
+ )
+
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
+ non_lora_processor_cls_name = self.processor.__class__.__name__
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
+
+ hidden_size = self.inner_dim
+
+ # now create a LoRA attention processor from the LoRA layers
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
+ kwargs = {
+ "cross_attention_dim": self.cross_attention_dim,
+ "rank": self.to_q.lora_layer.rank,
+ "network_alpha": self.to_q.lora_layer.network_alpha,
+ "q_rank": self.to_q.lora_layer.rank,
+ "q_hidden_size": self.to_q.lora_layer.out_features,
+ "k_rank": self.to_k.lora_layer.rank,
+ "k_hidden_size": self.to_k.lora_layer.out_features,
+ "v_rank": self.to_v.lora_layer.rank,
+ "v_hidden_size": self.to_v.lora_layer.out_features,
+ "out_rank": self.to_out[0].lora_layer.rank,
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
+ }
+
+ if hasattr(self.processor, "attention_op"):
+ kwargs["attention_op"] = self.processor.attention_op
+
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
+ lora_processor = lora_processor_cls(
+ hidden_size,
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
+ rank=self.to_q.lora_layer.rank,
+ network_alpha=self.to_q.lora_layer.network_alpha,
+ )
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+
+ # only save if used
+ if self.add_k_proj.lora_layer is not None:
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
+ else:
+ lora_processor.add_k_proj_lora = None
+ lora_processor.add_v_proj_lora = None
+ else:
+ raise ValueError(f"{lora_processor_cls} does not exist.")
+
+ return lora_processor
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
+ the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ def _init_compress(self):
+ if len(self.compress_kv_factor) == 2:
+ self.sr = nn.Conv2d(
+ self.inner_dim,
+ self.inner_dim,
+ groups=self.inner_dim,
+ kernel_size=self.compress_kv_factor,
+ stride=self.compress_kv_factor,
+ )
+ self.sr.weight.data.fill_(1 / self.compress_kv_factor[0] ** 2)
+ elif len(self.compress_kv_factor) == 1:
+ self.kernel_size = self.compress_kv_factor[0]
+ self.sr = nn.Conv1d(
+ self.inner_dim,
+ self.inner_dim,
+ groups=self.inner_dim,
+ kernel_size=self.compress_kv_factor[0],
+ stride=self.compress_kv_factor[0],
+ )
+ self.sr.weight.data.fill_(1 / self.compress_kv_factor[0])
+ self.sr.bias.data.zero_()
+ self.norm = nn.LayerNorm(self.inner_dim)
+
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self, dim=1152, attention_mode="xformers", use_rope=False, rope_scaling=None, compress_kv_factor=None):
+ self.dim = dim
+ self.attention_mode = attention_mode
+ self.use_rope = use_rope
+ self.rope_scaling = rope_scaling
+ self.compress_kv_factor = compress_kv_factor
+ if self.use_rope:
+ self._init_rope()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def _init_rope(self):
+ if self.rope_scaling is None:
+ self.rope2d = RoPE2D()
+ self.rope1d = RoPE1D()
+ else:
+ scaling_type = self.rope_scaling["type"]
+ scaling_factor_2d = self.rope_scaling["factor_2d"]
+ scaling_factor_1d = self.rope_scaling["factor_1d"]
+ if scaling_type == "linear":
+ self.rope2d = LinearScalingRoPE2D(scaling_factor=scaling_factor_2d)
+ self.rope1d = LinearScalingRoPE1D(scaling_factor=scaling_factor_1d)
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ last_shape: Tuple[int] = None,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if self.compress_kv_factor is not None:
+ batch_size = hidden_states.shape[0]
+ if len(last_shape) == 2:
+ encoder_hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, self.dim, *last_shape)
+ encoder_hidden_states = (
+ attn.sr(encoder_hidden_states).reshape(batch_size, self.dim, -1).permute(0, 2, 1)
+ )
+ elif len(last_shape) == 1:
+ encoder_hidden_states = hidden_states.permute(0, 2, 1)
+ if last_shape[0] % 2 == 1:
+ first_frame_pad = encoder_hidden_states[:, :, :1].repeat((1, 1, attn.kernel_size - 1))
+ encoder_hidden_states = torch.concatenate((first_frame_pad, encoder_hidden_states), dim=2)
+ encoder_hidden_states = attn.sr(encoder_hidden_states).permute(0, 2, 1)
+ else:
+ raise NotImplementedError(f"NotImplementedError with last_shape {last_shape}")
+
+ encoder_hidden_states = attn.norm(encoder_hidden_states)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if self.use_rope:
+ # require the shape of (batch_size x nheads x ntokens x dim)
+ if position_q.ndim == 3:
+ query = self.rope2d(query, position_q)
+ elif position_q.ndim == 2:
+ query = self.rope1d(query, position_q)
+ else:
+ raise NotImplementedError
+ if position_k.ndim == 3:
+ key = self.rope2d(key, position_k)
+ elif position_k.ndim == 2:
+ key = self.rope1d(key, position_k)
+ else:
+ raise NotImplementedError
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ if self.attention_mode == "flash":
+ assert attention_mask is None or torch.all(
+ attention_mask.bool()
+ ), "flash-attn do not support attention_mask"
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ elif self.attention_mode == "xformers":
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ elif self.attention_mode == "math":
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ else:
+ raise NotImplementedError(f"Found attention_mode: {self.attention_mode}")
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+ for module in self.net:
+ if isinstance(module, compatible_cls):
+ hidden_states = module(hidden_states, scale)
+ else:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock_(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+
+ # # 2. Cross-Attn
+ # if cross_attention_dim is not None or double_self_attention:
+ # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # # the second cross attention block.
+ # self.norm2 = (
+ # AdaLayerNorm(dim, num_embeds_ada_norm)
+ # if self.use_ada_layer_norm
+ # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ # )
+ # self.attn2 = Attention(
+ # query_dim=dim,
+ # cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # upcast_attention=upcast_attention,
+ # ) # is self-attn if encoder_hidden_states is none
+ # else:
+ # self.norm2 = None
+ # self.attn2 = None
+
+ # 3. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.last_out = None
+ self.count = 0
+ self.block_idx = block_idx
+ self.temp_mlp_count = 0
+
+ def set_last_out(self, last_out: torch.Tensor):
+ self.last_out = last_out
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ frame: int = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ broadcast_temporal, self.count = if_broadcast_temporal(int(org_timestep[0]), self.count)
+ if broadcast_temporal:
+ attn_output = self.last_out
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ if enable_sequence_parallel():
+ norm_hidden_states = self.dynamic_switch(norm_hidden_states, to_spatial_shard=True)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ position_q=position_q,
+ position_k=position_k,
+ last_shape=frame,
+ **cross_attention_kwargs,
+ )
+
+ if enable_sequence_parallel():
+ attn_output = self.dynamic_switch(attn_output, to_spatial_shard=False)
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.set_last_out(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # # 3. Cross-Attention
+ # if self.attn2 is not None:
+ # if self.use_ada_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states, timestep)
+ # elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ # norm_hidden_states = self.norm2(hidden_states)
+ # elif self.use_ada_layer_norm_single:
+ # # For PixArt norm2 isn't applied here:
+ # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ # norm_hidden_states = hidden_states
+ # else:
+ # raise ValueError("Incorrect norm")
+
+ # if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+ # norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # attn_output = self.attn2(
+ # norm_hidden_states,
+ # encoder_hidden_states=encoder_hidden_states,
+ # attention_mask=encoder_attention_mask,
+ # **cross_attention_kwargs,
+ # )
+ # hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # if not self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm3(hidden_states)
+
+ if enable_pab():
+ broadcast_mlp, self.temp_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.temp_mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=True,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=True,
+ )
+ else:
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ # norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=True,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+ def dynamic_switch(self, x, to_spatial_shard: bool):
+ if to_spatial_shard:
+ scatter_dim, gather_dim = 0, 1
+ scatter_pad = get_spatial_pad()
+ gather_pad = get_temporal_pad()
+ else:
+ scatter_dim, gather_dim = 1, 0
+ scatter_pad = get_temporal_pad()
+ gather_pad = get_spatial_pad()
+ x = all_to_all_with_pad(
+ x,
+ get_sequence_parallel_group(),
+ scatter_dim=scatter_dim,
+ gather_dim=gather_dim,
+ scatter_pad=scatter_pad,
+ gather_pad=gather_pad,
+ )
+ return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ attention_mode: str = "xformers",
+ use_rope: bool = False,
+ rope_scaling: Optional[Dict] = None,
+ compress_kv_factor: Optional[Tuple] = None,
+ block_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=compress_kv_factor,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ attention_mode=attention_mode, # only xformers support attention_mask
+ use_rope=False, # do not position in cross attention
+ compress_kv_factor=None,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ # pab
+ self.cross_last = None
+ self.cross_count = 0
+ self.spatial_last = None
+ self.spatial_count = 0
+ self.block_idx = block_idx
+ self.spatila_mlp_count = 0
+
+ def set_cross_last(self, last_out: torch.Tensor):
+ self.cross_last = last_out
+
+ def set_spatial_last(self, last_out: torch.Tensor):
+ self.spatial_last = last_out
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ position_q: Optional[torch.LongTensor] = None,
+ position_k: Optional[torch.LongTensor] = None,
+ hw: Tuple[int, int] = None,
+ org_timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ broadcast_spatial, self.spatial_count = if_broadcast_spatial(
+ int(org_timestep[0]), self.spatial_count, self.block_idx
+ )
+ if broadcast_spatial:
+ attn_output = self.spatial_last
+ assert self.use_ada_layer_norm_single
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ position_q=position_q,
+ position_k=position_k,
+ last_shape=hw,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ if enable_pab():
+ self.set_spatial_last(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ broadcast_cross, self.cross_count = if_broadcast_cross(int(org_timestep[0]), self.cross_count)
+ if broadcast_cross:
+ hidden_states = hidden_states + self.cross_last
+ else:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_q=None, # cross attn do not need relative position
+ position_k=None,
+ last_shape=None,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ if enable_pab():
+ self.set_cross_last(attn_output)
+
+ if enable_pab():
+ broadcast_mlp, self.spatila_mlp_count, broadcast_next, broadcast_range = if_broadcast_mlp(
+ int(org_timestep[0]),
+ self.spatila_mlp_count,
+ self.block_idx,
+ all_timesteps.tolist(),
+ is_temporal=False,
+ )
+
+ if enable_pab() and broadcast_mlp:
+ ff_output = get_mlp_output(
+ broadcast_range,
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ is_temporal=False,
+ )
+ else:
+ # 4. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ if enable_pab() and broadcast_next:
+ save_mlp_output(
+ timestep=int(org_timestep[0]),
+ block_idx=self.block_idx,
+ ff_output=ff_output,
+ is_temporal=False,
+ )
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = CombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ batch_size: int = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None
+ )
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class LatteT2V(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ patch_size_t: int = 1,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ video_length: int = 16,
+ attention_mode: str = "flash",
+ use_rope: bool = False,
+ model_max_length: int = 300,
+ rope_scaling_type: str = "linear",
+ compress_kv_factor: int = 1,
+ interpolation_scale_1d: float = None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.video_length = video_length
+ self.use_rope = use_rope
+ self.model_max_length = model_max_length
+ self.compress_kv_factor = compress_kv_factor
+ self.num_layers = num_layers
+ self.config.hidden_size = model_max_length
+
+ assert not (self.compress_kv_factor != 1 and use_rope), "Can not both enable compressing kv and using rope"
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ # self.is_input_patches = in_channels is not None and patch_size is not None
+ self.is_input_patches = True
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ # 2. Define input layers
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size[0]
+ self.width = sample_size[1]
+
+ self.patch_size = patch_size
+ interpolation_scale_2d = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale_2d = max(interpolation_scale_2d, 1)
+ self.pos_embed = PatchEmbed(
+ height=sample_size[0],
+ width=sample_size[1],
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale_2d,
+ )
+
+ # define temporal positional embedding
+ if interpolation_scale_1d is None:
+ if self.config.video_length % 2 == 1:
+ interpolation_scale_1d = (
+ self.config.video_length - 1
+ ) // 16 # => 16 (= 16 Latte) has interpolation scale 1
+ else:
+ interpolation_scale_1d = self.config.video_length // 16 # => 16 (= 16 Latte) has interpolation scale 1
+ # interpolation_scale_1d = self.config.video_length // 5 #
+ interpolation_scale_1d = max(interpolation_scale_1d, 1)
+ temp_pos_embed = get_1d_sincos_pos_embed(
+ inner_dim, video_length, interpolation_scale=interpolation_scale_1d
+ ) # 1152 hidden size
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+
+ rope_scaling = None
+ if self.use_rope:
+ self.position_getter_2d = PositionGetter2D()
+ self.position_getter_1d = PositionGetter1D()
+ rope_scaling = dict(
+ type=rope_scaling_type, factor_2d=interpolation_scale_2d, factor_1d=interpolation_scale_1d
+ )
+
+ # 3. Define transformers blocks, spatial attention
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=(compress_kv_factor, compress_kv_factor)
+ if d >= num_layers // 2 and compress_kv_factor != 1
+ else None, # follow pixart-sigma, apply in second-half layers
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # Define temporal transformers blocks
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock_( # one attention
+ inner_dim,
+ num_attention_heads, # num_attention_heads
+ attention_head_dim, # attention_head_dim 72
+ dropout=dropout,
+ cross_attention_dim=None,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=False,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ attention_mode=attention_mode,
+ use_rope=use_rope,
+ rope_scaling=rope_scaling,
+ compress_kv_factor=(compress_kv_factor,)
+ if d >= num_layers // 2 and compress_kv_factor != 1
+ else None, # follow pixart-sigma, apply in second-half layers
+ block_idx=d,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = linear_cls(inner_dim, in_channels)
+ else:
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches and norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def make_position(self, b, t, use_image_num, h, w, device):
+ pos_hw = self.position_getter_2d(b * (t + use_image_num), h, w, device) # fake_b = b*(t+use_image_num)
+ pos_t = self.position_getter_1d(b * h * w, t, device) # fake_b = b*h*w
+ return pos_hw, pos_t
+
+ def make_attn_mask(self, attention_mask, frame, dtype):
+ attention_mask = rearrange(attention_mask, "b t h w -> (b t) 1 (h w)")
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(dtype)) * -10000.0
+ attention_mask = attention_mask.to(self.dtype)
+ return attention_mask
+
+ def vae_to_diff_mask(self, attention_mask, use_image_num):
+ dtype = attention_mask.dtype
+ # b, t+use_image_num, h, w, assume t as channel
+ # this version do not use 3d patch embedding
+ attention_mask = F.max_pool2d(
+ attention_mask, kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size)
+ )
+ attention_mask = attention_mask.bool().to(dtype)
+ return attention_mask
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Optional[torch.LongTensor] = None,
+ all_timesteps=None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ enable_temporal_attentions: bool = True,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # 0. Split batch
+ if get_cfg_parallel_size() > 1:
+ (
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ ) = batch_func(
+ partial(split_sequence, process_group=get_cfg_parallel_group(), dim=0),
+ hidden_states,
+ timestep,
+ encoder_hidden_states,
+ class_labels,
+ attention_mask,
+ encoder_attention_mask,
+ )
+ input_batch_size, c, frame, h, w = hidden_states.shape
+ frame = frame - use_image_num # 20-4=16
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ org_timestep = timestep
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
+ )
+ attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
+ dtype = attention_mask.dtype
+ attention_mask_compress = F.max_pool2d(
+ attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
+ )
+ attention_mask_compress = attention_mask_compress.to(dtype)
+
+ attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
+ attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
+
+ # 1 + 4, 1 -> video condition, 4 -> image condition
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+ encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
+ encoder_attention_mask = encoder_attention_mask.to(self.dtype)
+ elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
+ encoder_attention_mask_video = repeat(
+ encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
+ ).contiguous()
+ encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
+ encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
+ encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
+ encoder_attention_mask = encoder_attention_mask.to(self.dtype)
+
+ # Retrieve lora scale.
+ cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 1. Input
+ if self.is_input_patches: # here
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hw = (height, width)
+ num_patches = height * width
+
+ hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ # batch_size = hidden_states.shape[0]
+ batch_size = input_batch_size
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152
+
+ if use_image_num != 0 and self.training:
+ encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
+ encoder_hidden_states_video = repeat(
+ encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
+ ).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
+ else:
+ encoder_hidden_states_spatial = repeat(
+ encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
+ ).contiguous()
+
+ # prepare timesteps for spatial and temporal block
+ timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
+
+ pos_hw, pos_t = None, None
+ if self.use_rope:
+ pos_hw, pos_t = self.make_position(
+ input_batch_size, frame, use_image_num, height, width, hidden_states.device
+ )
+
+ if enable_sequence_parallel():
+ set_temporal_pad(frame + use_image_num)
+ set_spatial_pad(num_patches)
+ hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
+ encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
+ timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
+ attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
+ attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
+ temp_pos_embed = split_sequence(
+ self.temp_pos_embed, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
+ )
+ else:
+ temp_pos_embed = self.temp_pos_embed
+
+ for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ spatial_block,
+ hidden_states,
+ attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ pos_hw,
+ pos_hw,
+ hw,
+ use_reentrant=False,
+ )
+
+ if enable_temporal_attentions:
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0: # image-video joitn training
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states_video = hidden_states_video + temp_pos_embed
+
+ hidden_states_video = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ use_reentrant=False,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states = hidden_states + temp_pos_embed
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ temp_block,
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ use_reentrant=False,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+ else:
+ hidden_states = spatial_block(
+ hidden_states,
+ attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
+ encoder_hidden_states_spatial,
+ encoder_attention_mask,
+ timestep_spatial,
+ cross_attention_kwargs,
+ class_labels,
+ pos_hw,
+ pos_hw,
+ hw,
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ if enable_temporal_attentions:
+ # b c f h w, f = 16 + 4
+ hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
+
+ if use_image_num != 0 and self.training:
+ hidden_states_video = hidden_states[:, :frame, ...]
+ hidden_states_image = hidden_states[:, frame:, ...]
+
+ # if i == 0 and not self.use_rope:
+ # hidden_states_video = hidden_states_video + temp_pos_embed
+
+ hidden_states_video = temp_block(
+ hidden_states_video,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ org_timestep,
+ )
+
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ else:
+ # if i == 0 and not self.use_rope:
+ if i == 0:
+ hidden_states = hidden_states + temp_pos_embed
+ hidden_states = temp_block(
+ hidden_states,
+ None, # attention_mask
+ None, # encoder_hidden_states
+ None, # encoder_attention_mask
+ timestep_temp,
+ cross_attention_kwargs,
+ class_labels,
+ pos_t,
+ pos_t,
+ (frame,),
+ org_timestep,
+ all_timesteps=all_timesteps,
+ )
+
+ hidden_states = rearrange(
+ hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
+ ).contiguous()
+
+ if enable_sequence_parallel():
+ hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
+
+ # 3. Gather batch for data parallelism
+ if get_cfg_parallel_size() > 1:
+ output = gather_sequence(output, get_cfg_parallel_group(), dim=0)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, "config.json")
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ model = cls.from_config(config, **kwargs)
+ return model
+
+ def split_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+ def gather_from_second_dim(self, x, batch_size):
+ x = x.view(batch_size, -1, *x.shape[1:])
+ x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
+ x = x.reshape(-1, *x.shape[2:])
+ return x
+
+
+# depth = num_layers * 2
+def LatteT2V_XL_122(**kwargs):
+ return LatteT2V(
+ num_layers=28,
+ attention_head_dim=72,
+ num_attention_heads=16,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1152,
+ **kwargs,
+ )
+
+
+def LatteT2V_D64_XL_122(**kwargs):
+ return LatteT2V(
+ num_layers=28,
+ attention_head_dim=64,
+ num_attention_heads=18,
+ patch_size_t=1,
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=4096,
+ cross_attention_dim=1152,
+ **kwargs,
+ )
+
+
+Latte_models = {
+ "LatteT2V-XL/122": LatteT2V_XL_122,
+ "LatteT2V-D64-XL/122": LatteT2V_D64_XL_122,
+}
diff --git a/videosys/models/open_sora_plan/losses.py b/videosys/models/open_sora_plan/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..106f4eb7ab063d41719721f97657ff1ee9d90103
--- /dev/null
+++ b/videosys/models/open_sora_plan/losses.py
@@ -0,0 +1,677 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import functools
+import hashlib
+import os
+from collections import namedtuple
+
+import requests
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from torchvision import models
+from tqdm import tqdm
+
+from videosys.models.open_sora_plan.modules.normalize import ActNorm
+
+URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
+
+CKPT_MAP = {"vgg_lpips": "vgg.pth"}
+
+MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
+ self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """A single linear layer which does a 1x1 conv"""
+
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = (
+ [
+ nn.Dropout(),
+ ]
+ if (use_dropout)
+ else []
+ )
+ layers += [
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+ ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2, 3], keepdim=keepdim)
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+def weights_init_conv(m):
+ if hasattr(m, "conv"):
+ m = m.conv
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+
+class NLayerDiscriminator3D(nn.Module):
+ """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
+
+ def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
+ """
+ Construct a 3D PatchGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input volumes
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ use_actnorm (bool) -- flag to use actnorm instead of batchnorm
+ """
+ super(NLayerDiscriminator3D, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm3d
+ else:
+ raise NotImplementedError("Not implemented.")
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func != nn.BatchNorm3d
+ else:
+ use_bias = norm_layer != nn.BatchNorm3d
+
+ kw = 3
+ padw = 1
+ sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv3d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=(kw, kw, kw),
+ stride=(2 if n == 1 else 1, 2, 2),
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv3d(
+ ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
+ )
+ return d_loss
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.0):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+
+def l1(x, y):
+ return torch.abs(x - y)
+
+
+def l2(x, y):
+ return torch.pow((x - y), 2)
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start,
+ logvar_init=0.0,
+ kl_weight=1.0,
+ pixelloss_weight=1.0,
+ perceptual_weight=1.0,
+ # --- Discriminator Loss ---
+ disc_num_layers=3,
+ disc_in_channels=3,
+ disc_factor=1.0,
+ disc_weight=1.0,
+ use_actnorm=False,
+ disc_conditional=False,
+ disc_loss="hinge",
+ ):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ posteriors,
+ optimizer_idx,
+ global_step,
+ split="train",
+ weights=None,
+ last_layer=None,
+ cond=None,
+ ):
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
+ rec_loss = torch.abs(inputs - reconstructions)
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs, reconstructions)
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # GAN Part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+ log = {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
+
+
+class LPIPSWithDiscriminator3D(nn.Module):
+ def __init__(
+ self,
+ disc_start,
+ logvar_init=0.0,
+ kl_weight=1.0,
+ pixelloss_weight=1.0,
+ perceptual_weight=1.0,
+ # --- Discriminator Loss ---
+ disc_num_layers=3,
+ disc_in_channels=3,
+ disc_factor=1.0,
+ disc_weight=1.0,
+ use_actnorm=False,
+ disc_conditional=False,
+ disc_loss="hinge",
+ ):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator3D(
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ posteriors,
+ optimizer_idx,
+ global_step,
+ split="train",
+ weights=None,
+ last_layer=None,
+ cond=None,
+ ):
+ t = inputs.shape[2]
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
+ rec_loss = torch.abs(inputs - reconstructions)
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs, reconstructions)
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous()
+ reconstructions = rearrange(reconstructions, "(b t) c h w -> b c t h w", t=t).contiguous()
+ # GAN Part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions)
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions, cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError as e:
+ assert not self.training, print(e)
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+ log = {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
+
+
+class SimpleLPIPS(nn.Module):
+ def __init__(
+ self,
+ logvar_init=0.0,
+ kl_weight=1.0,
+ pixelloss_weight=1.0,
+ perceptual_weight=1.0,
+ disc_loss="hinge",
+ ):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ posteriors,
+ split="train",
+ weights=None,
+ ):
+ inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
+ reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
+ rec_loss = torch.abs(inputs - reconstructions)
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs, reconstructions)
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ loss = weighted_nll_loss + self.kl_weight * kl_loss
+ log = {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ }
+ if self.perceptual_weight > 0:
+ log.update({"{}/p_loss".format(split): p_loss.detach().mean()})
+ return loss, log
diff --git a/videosys/models/open_sora_plan/modules/__init__.py b/videosys/models/open_sora_plan/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..519c44694083780c478bfd10079f1b2accd80652
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/__init__.py
@@ -0,0 +1,17 @@
+from .attention import AttnBlock, AttnBlock3D, AttnBlock3DFix, LinAttnBlock, LinearAttention, TemporalAttnBlock
+from .block import Block
+from .conv import CausalConv3d, Conv2d
+from .normalize import GroupNorm, Normalize
+from .resnet_block import ResnetBlock2D, ResnetBlock3D
+from .updownsample import (
+ Downsample,
+ SpatialDownsample2x,
+ SpatialUpsample2x,
+ TimeDownsample2x,
+ TimeDownsampleRes2x,
+ TimeDownsampleResAdv2x,
+ TimeUpsample2x,
+ TimeUpsampleRes2x,
+ TimeUpsampleResAdv2x,
+ Upsample,
+)
diff --git a/videosys/models/open_sora_plan/modules/attention.py b/videosys/models/open_sora_plan/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..97ea9364d493cf52750afe2f399a41072f834fc0
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/attention.py
@@ -0,0 +1,227 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from .block import Block
+from .conv import CausalConv3d
+from .normalize import Normalize
+from .ops import video_to_image
+
+
+class LinearAttention(Block):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock3D(Block):
+ """Compatible with old versions, there are issues, use with caution."""
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b * t, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, t, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock3DFix(nn.Module):
+ """
+ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
+ b, c, t, h, w = q.shape
+ q = q.permute(0, 2, 1, 3, 4)
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1)
+
+ # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
+ k = k.permute(0, 2, 1, 3, 4)
+ k = k.reshape(b * t, c, h * w)
+
+ # w: (b*t hw hw)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ # v: (b c t h w) -> (b t c h w) -> (bt c hw)
+ # w_: (bt hw hw) -> (bt hw hw)
+ v = v.permute(0, 2, 1, 3, 4)
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+ # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
+ h_ = h_.reshape(b, t, c, h, w)
+ h_ = h_.permute(0, 2, 1, 3, 4)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class TemporalAttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = rearrange(q, "b c t h w -> (b h w) t c")
+ k = rearrange(k, "b c t h w -> (b h w) c t")
+ v = rearrange(v, "b c t h w -> (b h w) c t")
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ print(attn_type)
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla3D":
+ return AttnBlock3D(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
diff --git a/videosys/models/open_sora_plan/modules/block.py b/videosys/models/open_sora_plan/modules/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..423e1b6f62f4121515a09806e00b38ba68f56516
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/block.py
@@ -0,0 +1,15 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import torch.nn as nn
+
+
+class Block(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
diff --git a/videosys/models/open_sora_plan/modules/conv.py b/videosys/models/open_sora_plan/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..787f4c263f1caf25a0f868ddf73e6e7e99a59ee1
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/conv.py
@@ -0,0 +1,102 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .ops import cast_tuple, video_to_image
+
+
+class Conv2d(nn.Conv2d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]] = 3,
+ stride: Union[int, Tuple[int]] = 1,
+ padding: Union[str, int, Tuple[int]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ @video_to_image
+ def forward(self, x):
+ return super().forward(x)
+
+
+class CausalConv3d(nn.Module):
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.time_kernel_size = self.kernel_size[0]
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ stride = kwargs.pop("stride", 1)
+ padding = kwargs.pop("padding", 0)
+ padding = list(cast_tuple(padding, 3))
+ padding[0] = 0
+ stride = cast_tuple(stride, 3)
+ self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
+ self._init_weights(init_method)
+
+ def _init_weights(self, init_method):
+ torch.tensor(self.kernel_size)
+ if init_method == "avg":
+ assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
+ assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
+ weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
+
+ eyes = torch.concat(
+ [
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ ],
+ dim=-1,
+ )
+ weight[:, :, :, 0, 0] = eyes
+
+ self.conv.weight = nn.Parameter(
+ weight,
+ requires_grad=True,
+ )
+ elif init_method == "zero":
+ self.conv.weight = nn.Parameter(
+ torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
+ requires_grad=True,
+ )
+ if self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ # 1 + 16 16 as video, 1 as image
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
+ return self.conv(x)
diff --git a/videosys/models/open_sora_plan/modules/normalize.py b/videosys/models/open_sora_plan/modules/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ee61c7b25501d19ad7ba3091e73df9750f5a68a
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/normalize.py
@@ -0,0 +1,98 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+
+from .block import Block
+
+
+class GroupNorm(Block):
+ def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
+
+ def forward(self, x):
+ return self.norm(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+ std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
diff --git a/videosys/models/open_sora_plan/modules/ops.py b/videosys/models/open_sora_plan/modules/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd636ae92511f108def36706b79f52a45c939fd
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/ops.py
@@ -0,0 +1,54 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import torch
+from einops import rearrange
+
+
+def video_to_image(func):
+ def wrapper(self, x, *args, **kwargs):
+ if x.dim() == 5:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = func(self, x, *args, **kwargs)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+ return wrapper
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
+ n_dims = len(x.shape)
+ if src_dim < 0:
+ src_dim = n_dims + src_dim
+ if dest_dim < 0:
+ dest_dim = n_dims + dest_dim
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
+ dims = list(range(n_dims))
+ del dims[src_dim]
+ permutation = []
+ ctr = 0
+ for i in range(n_dims):
+ if i == dest_dim:
+ permutation.append(src_dim)
+ else:
+ permutation.append(dims[ctr])
+ ctr += 1
+ x = x.permute(permutation)
+ if make_contiguous:
+ x = x.contiguous()
+ return x
diff --git a/videosys/models/open_sora_plan/modules/quant.py b/videosys/models/open_sora_plan/modules/quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b9dcf26b95ad81aa5474feec8397b3c01916bb
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/quant.py
@@ -0,0 +1,111 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ops import shift_dim
+
+
+class Codebook(nn.Module):
+ def __init__(self, n_codes, embedding_dim):
+ super().__init__()
+ self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
+ self.register_buffer("N", torch.zeros(n_codes))
+ self.register_buffer("z_avg", self.embeddings.data.clone())
+
+ self.n_codes = n_codes
+ self.embedding_dim = embedding_dim
+ self._need_init = True
+
+ def _tile(self, x):
+ d, ew = x.shape
+ if d < self.n_codes:
+ n_repeats = (self.n_codes + d - 1) // d
+ std = 0.01 / np.sqrt(ew)
+ x = x.repeat(n_repeats, 1)
+ x = x + torch.randn_like(x) * std
+ return x
+
+ def _init_embeddings(self, z):
+ # z: [b, c, t, h, w]
+ self._need_init = False
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ y = self._tile(flat_inputs)
+
+ y.shape[0]
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+ self.embeddings.data.copy_(_k_rand)
+ self.z_avg.data.copy_(_k_rand)
+ self.N.data.copy_(torch.ones(self.n_codes))
+
+ def forward(self, z):
+ # z: [b, c, t, h, w]
+ if self._need_init and self.training:
+ self._init_embeddings(z)
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ distances = (
+ (flat_inputs**2).sum(dim=1, keepdim=True)
+ - 2 * flat_inputs @ self.embeddings.t()
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+ )
+
+ encoding_indices = torch.argmin(distances, dim=1)
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
+ encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
+
+ embeddings = F.embedding(encoding_indices, self.embeddings)
+ embeddings = shift_dim(embeddings, -1, 1)
+
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+ # EMA codebook update
+ if self.training:
+ n_total = encode_onehot.sum(dim=0)
+ encode_sum = flat_inputs.t() @ encode_onehot
+ if dist.is_initialized():
+ dist.all_reduce(n_total)
+ dist.all_reduce(encode_sum)
+
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
+
+ n = self.N.sum()
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
+ self.embeddings.data.copy_(encode_normalized)
+
+ y = self._tile(flat_inputs)
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+
+ usage = (self.N.view(self.n_codes, 1) >= 1).float()
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+ embeddings_st = (embeddings - z).detach() + z
+
+ avg_probs = torch.mean(encode_onehot, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ return dict(
+ embeddings=embeddings_st,
+ encodings=encoding_indices,
+ commitment_loss=commitment_loss,
+ perplexity=perplexity,
+ )
+
+ def dictionary_lookup(self, encodings):
+ embeddings = F.embedding(encodings, self.embeddings)
+ return embeddings
diff --git a/videosys/models/open_sora_plan/modules/resnet_block.py b/videosys/models/open_sora_plan/modules/resnet_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..987c690e525f28114d8321d0cb7c043a4b2a7e8b
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/resnet_block.py
@@ -0,0 +1,87 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import torch
+
+from .block import Block
+from .conv import CausalConv3d
+from .normalize import Normalize
+from .ops import nonlinearity, video_to_image
+
+
+class ResnetBlock2D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ x = x + h
+ return x
+
+
+class ResnetBlock3D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ else:
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
diff --git a/videosys/models/open_sora_plan/modules/updownsample.py b/videosys/models/open_sora_plan/modules/updownsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..db27de1d95206d80336472a1acc4a99165ebbb98
--- /dev/null
+++ b/videosys/models/open_sora_plan/modules/updownsample.py
@@ -0,0 +1,215 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .attention import TemporalAttnBlock
+from .block import Block
+from .conv import CausalConv3d
+from .normalize import Normalize
+from .ops import cast_tuple, video_to_image
+from .resnet_block import ResnetBlock3D
+
+
+class Upsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ @video_to_image
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class SpatialDownsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (2, 2),
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 2)
+ stride = cast_tuple(stride, 2)
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1, 0, 0)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class SpatialUpsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (1, 1),
+ ):
+ super().__init__()
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
+
+ def forward(self, x):
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> b (c t) h w")
+ x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
+ x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
+ x = self.conv(x)
+ return x
+
+
+class TimeDownsample2x(Block):
+ def __init__(self, chan_in, chan_out, kernel_size: int = 3):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return self.conv(x)
+
+
+class TimeUpsample2x(Block):
+ def __init__(self, chan_in, chan_out):
+ super().__init__()
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return x
+
+
+class TimeDownsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
+
+
+class TimeUpsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return alpha * x + (1 - alpha) * self.conv(x)
+
+
+class TimeDownsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.attn = TemporalAttnBlock(in_channels)
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
+
+
+class TimeUpsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.attn = TemporalAttnBlock(in_channels)
+ self.norm = Normalize(in_channels=in_channels)
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
diff --git a/videosys/models/open_sora_plan/pipeline.py b/videosys/models/open_sora_plan/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..004486d671306fe769713195ab438c15c62d72c0
--- /dev/null
+++ b/videosys/models/open_sora_plan/pipeline.py
@@ -0,0 +1,890 @@
+# Adapted from Open-Sora-Plan
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
+# --------------------------------------------------------
+
+import html
+import inspect
+import math
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import ftfy
+import torch
+import torch.distributed as dist
+import tqdm
+from bs4 import BeautifulSoup
+from diffusers.models import AutoencoderKL, Transformer2DModel
+from diffusers.schedulers import PNDMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import T5EncoderModel, T5Tokenizer
+
+from videosys.core.pab_mgr import (
+ PABConfig,
+ get_diffusion_skip,
+ get_diffusion_skip_timestep,
+ set_pab_manager,
+ skip_diffusion_timestep,
+ update_steps,
+)
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.utils.logging import logger
+from videosys.utils.utils import save_video
+
+from .ae import ae_stride_config, getae_wrapper
+from .latte import LatteT2V
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PixArtAlphaPipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+class OpenSoraPlanPABConfig(PABConfig):
+ def __init__(
+ self,
+ steps: int = 150,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [100, 850],
+ spatial_gap: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [100, 850],
+ temporal_gap: int = 4,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [100, 850],
+ cross_gap: int = 6,
+ diffusion_skip: bool = False,
+ diffusion_timestep_respacing: list = None,
+ diffusion_skip_timestep: list = None,
+ mlp_skip: bool = True,
+ mlp_spatial_skip_config: dict = {
+ 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ },
+ mlp_temporal_skip_config: dict = {
+ 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 666: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 642: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 618: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 594: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 570: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 546: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 522: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 498: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 474: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ 426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
+ },
+ ):
+ super().__init__(
+ steps=steps,
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_gap=spatial_gap,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_gap=temporal_gap,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_gap=cross_gap,
+ diffusion_skip=diffusion_skip,
+ diffusion_timestep_respacing=diffusion_timestep_respacing,
+ diffusion_skip_timestep=diffusion_skip_timestep,
+ mlp_skip=mlp_skip,
+ mlp_spatial_skip_config=mlp_spatial_skip_config,
+ mlp_temporal_skip_config=mlp_temporal_skip_config,
+ )
+
+
+class OpenSoraPlanConfig:
+ def __init__(
+ self,
+ world_size: int = 1,
+ model_path: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
+ num_frames: int = 65,
+ ae: str = "CausalVAEModel_4x8x8",
+ text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
+ # ======= vae =======
+ enable_tiling: bool = True,
+ tile_overlap_factor: float = 0.25,
+ # ======= pab ========
+ enable_pab: bool = False,
+ pab_config: PABConfig = OpenSoraPlanPABConfig(),
+ ):
+ # ======= engine ========
+ self.world_size = world_size
+
+ # ======= pipeline ========
+ self.pipeline_cls = OpenSoraPlanPipeline
+ self.ae = ae
+ self.text_encoder = text_encoder
+
+ # ======= model ========
+ self.model_path = model_path
+ assert num_frames in [65, 221], "num_frames must be one of [65, 221]"
+ self.num_frames = num_frames
+ self.version = f"{num_frames}x512x512"
+
+ # ======= vae ========
+ self.enable_tiling = enable_tiling
+ self.tile_overlap_factor = tile_overlap_factor
+
+ # ======= pab ========
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class OpenSoraPlanPipeline(VideoSysPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ config: OpenSoraPlanConfig,
+ tokenizer: Optional[T5Tokenizer] = None,
+ text_encoder: Optional[T5EncoderModel] = None,
+ vae: Optional[AutoencoderKL] = None,
+ transformer: Optional[Transformer2DModel] = None,
+ scheduler: Optional[PNDMScheduler] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.float16,
+ ):
+ super().__init__()
+ self._config = config
+
+ # init
+ if tokenizer is None:
+ tokenizer = T5Tokenizer.from_pretrained(config.text_encoder)
+ if text_encoder is None:
+ text_encoder = T5EncoderModel.from_pretrained(config.text_encoder, torch_dtype=torch.float16)
+ if vae is None:
+ vae = getae_wrapper(config.ae)(config.model_path, subfolder="vae").to(dtype=dtype)
+ if transformer is None:
+ transformer = LatteT2V.from_pretrained(config.model_path, subfolder=config.version, torch_dtype=dtype)
+ if scheduler is None:
+ scheduler = PNDMScheduler()
+
+ # setting
+ if config.enable_tiling:
+ vae.vae.enable_tiling()
+ vae.vae.tile_overlap_factor = config.tile_overlap_factor
+ vae.vae_scale_factor = ae_stride_config[config.ae]
+ transformer.force_images = False
+
+ # set eval and device
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
+
+ # pab
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
+ else:
+ masked_feature = emb * mask[:, None, :, None] # 1 120 4096
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ mask_feature: bool = True,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (bool, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ mask_feature: (bool, defaults to `True`):
+ If `True`, the function will mask the text embeddings.
+ """
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+ if device is None:
+ device = self.text_encoder.device or self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = 300
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because the model can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds_attention_mask = attention_mask
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
+ prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ # print(prompt_embeds.shape) # 1 120 4096
+ # print(negative_prompt_embeds.shape) # 1 120 4096
+
+ # Perform additional masking.
+ if mask_feature and not embeds_initially_provided:
+ prompt_embeds = prompt_embeds.unsqueeze(1)
+ masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
+ masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
+ masked_negative_prompt_embeds = (
+ negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
+ )
+
+ # import torch.nn.functional as F
+
+ # padding = (0, 0, 0, 113) # (左, 右, 下, 上)
+ # masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0)
+ # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0)
+
+ # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...])
+
+ return masked_prompt_embeds, masked_negative_prompt_embeds
+ # return masked_prompt_embeds_, masked_negative_prompt_embeds_
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
+ # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
+ # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
+ # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1)
+ if int(num_frames) % 2 == 1
+ else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]),
+ math.ceil(int(height) / self.vae.vae_scale_factor[1]),
+ math.ceil(int(width) / self.vae.vae_scale_factor[2]),
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 150,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ mask_feature: bool = True,
+ enable_temporal_attentions: bool = True,
+ verbose: bool = True,
+ ) -> Union[VideoSysPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ height = 512
+ width = 512
+ num_frames = self._config.num_frames
+ update_steps(num_inference_steps)
+ self.check_inputs(prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self.text_encoder.device or self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ mask_feature=mask_feature,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ # if self.transformer.config.sample_size == 128:
+ # resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ # aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ # resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ # aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+ # added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
+ diffusion_skip_timestep = get_diffusion_skip_timestep()
+
+ # warmup_timesteps = timesteps[:num_warmup_steps]
+ # after_warmup_timesteps = skip_diffusion_timestep(timesteps[num_warmup_steps:], diffusion_skip_timestep)
+ # timesteps = torch.cat((warmup_timesteps, after_warmup_timesteps))
+
+ timesteps = skip_diffusion_timestep(timesteps, diffusion_skip_timestep)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
+ for i, t in progress_wrap(list(enumerate(timesteps))):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ if prompt_embeds.ndim == 3:
+ prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ all_timesteps=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ enable_temporal_attentions=enable_temporal_attentions,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latents":
+ video = self.decode_latents(latents)
+ video = video[:, :num_frames, :height, :width]
+ else:
+ video = latents
+ return VideoSysPipelineOutput(video=video)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return VideoSysPipelineOutput(video=video)
+
+ def decode_latents(self, latents):
+ video = self.vae.decode(latents) # b t c h w
+ # b t c h w -> b t h w c
+ video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous()
+ return video
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=24)
diff --git a/videosys/modules/__init__.py b/videosys/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/modules/attn.py b/videosys/modules/attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..424c2b78a340696c26a530941ea013caa352889a
--- /dev/null
+++ b/videosys/modules/attn.py
@@ -0,0 +1,217 @@
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from videosys.modules.layers import LlamaRMSNorm
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = LlamaRMSNorm,
+ enable_flashattn: bool = False,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.enable_flashattn = enable_flashattn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = False
+ if rope is not None:
+ self.rope = True
+ self.rotary_emb = rope
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+
+ qkv = self.qkv(x)
+ qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv.unbind(0)
+ if self.rope:
+ q = self.rotary_emb(q)
+ k = self.rotary_emb(k)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.enable_flashattn:
+ from flash_attn import flash_attn_func
+
+ x = flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ softmax_scale=self.scale,
+ )
+ else:
+ q, k, v = map(lambda t: t.permute(0, 2, 1, 3), (q, k, v))
+ x = F.scaled_dot_product_attention(
+ q, k, v, scale=self.scale, dropout_p=self.attn_drop.p if self.training else 0.0
+ )
+
+ x_output_shape = (B, N, C)
+ if not self.enable_flashattn:
+ x = x.transpose(1, 2)
+ x = x.reshape(x_output_shape)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MultiHeadCrossAttention(nn.Module):
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flashattn=False):
+ super(MultiHeadCrossAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+ self.d_model = d_model
+ self.num_heads = num_heads
+ self.head_dim = d_model // num_heads
+ self.enable_flashattn = enable_flashattn
+
+ self.q_linear = nn.Linear(d_model, d_model)
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(d_model, d_model)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.last_out = None
+ self.count = 0
+
+ def forward(self, x, cond, mask=None, timestep=None):
+ # query/value: img tokens; key: condition; mask: if padding tokens
+ B, N, C = x.shape
+
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
+ k, v = kv.unbind(2)
+ x = self.flash_attn_impl(q, k, v, mask, B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def flash_attn_impl(self, q, k, v, mask, B, N, C):
+ from flash_attn import flash_attn_varlen_func
+
+ q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
+ k_seqinfo = _SeqLenInfo.from_seqlens(mask)
+
+ x = flash_attn_varlen_func(
+ q.view(-1, self.num_heads, self.head_dim),
+ k.view(-1, self.num_heads, self.head_dim),
+ v.view(-1, self.num_heads, self.head_dim),
+ cu_seqlens_q=q_seqinfo.seqstart.cuda(),
+ cu_seqlens_k=k_seqinfo.seqstart.cuda(),
+ max_seqlen_q=q_seqinfo.max_seqlen,
+ max_seqlen_k=k_seqinfo.max_seqlen,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ x = x.view(B, N, C)
+ return x
+
+ def torch_impl(self, q, k, v, mask, B, N, C):
+ q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attn_mask = torch.zeros(B, N, k.shape[2], dtype=torch.float32, device=q.device)
+ for i, m in enumerate(mask):
+ attn_mask[i, :, m:] = -1e8
+
+ scale = 1 / q.shape[-1] ** 0.5
+ q = q * scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.to(torch.float32)
+ if mask is not None:
+ attn = attn + attn_mask.unsqueeze(1)
+ attn = attn.softmax(-1)
+ attn = attn.to(v.dtype)
+ out = attn @ v
+
+ x = out.transpose(1, 2).contiguous().view(B, N, C)
+ return x
+
+
+@dataclass
+class _SeqLenInfo:
+ """
+ copied from xformers
+
+ (Internal) Represents the division of a dimension into blocks.
+ For example, to represents a dimension of length 7 divided into
+ three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
+ The members will be:
+ max_seqlen: 3
+ min_seqlen: 2
+ seqstart_py: [0, 2, 5, 7]
+ seqstart: torch.IntTensor([0, 2, 5, 7])
+ """
+
+ seqstart: torch.Tensor
+ max_seqlen: int
+ min_seqlen: int
+ seqstart_py: List[int]
+
+ def to(self, device: torch.device) -> None:
+ self.seqstart = self.seqstart.to(device, non_blocking=True)
+
+ def intervals(self) -> Iterable[Tuple[int, int]]:
+ yield from zip(self.seqstart_py, self.seqstart_py[1:])
+
+ @classmethod
+ def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
+ """
+ Input tensors are assumed to be in shape [B, M, *]
+ """
+ assert not isinstance(seqlens, torch.Tensor)
+ seqstart_py = [0]
+ max_seqlen = -1
+ min_seqlen = -1
+ for seqlen in seqlens:
+ min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
+ max_seqlen = max(max_seqlen, seqlen)
+ seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
+ seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
+ return cls(
+ max_seqlen=max_seqlen,
+ min_seqlen=min_seqlen,
+ seqstart=seqstart,
+ seqstart_py=seqstart_py,
+ )
+
+ def split(self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None) -> List[torch.Tensor]:
+ if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
+ raise ValueError(
+ f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
+ f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
+ f" seqstart: {self.seqstart_py}"
+ )
+ if batch_sizes is None:
+ batch_sizes = [1] * (len(self.seqstart_py) - 1)
+ split_chunks = []
+ it = 0
+ for batch_size in batch_sizes:
+ split_chunks.append(self.seqstart_py[it + batch_size] - self.seqstart_py[it])
+ it += batch_size
+ return [
+ tensor.reshape([bs, -1, *tensor.shape[2:]]) for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
+ ]
diff --git a/videosys/modules/embed.py b/videosys/modules/embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a166fadaa695c314cd5279754b2b1389136d547
--- /dev/null
+++ b/videosys/modules/embed.py
@@ -0,0 +1,145 @@
+# Modified from Meta DiT
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DiT: https://github.com/facebookresearch/DiT/tree/main
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ device=t.device
+ )
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
diff --git a/videosys/modules/layers.py b/videosys/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b717ba12d1685e1ed9e843bb6907db044a229824
--- /dev/null
+++ b/videosys/modules/layers.py
@@ -0,0 +1,80 @@
+# Modified from Meta DiT
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DiT: https://github.com/facebookresearch/DiT/tree/main
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+
+def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
+ if use_kernel:
+ try:
+ from apex.normalization import FusedLayerNorm
+
+ return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
+ except ImportError:
+ raise RuntimeError("FusedLayerNorm not available. Please install apex.")
+ else:
+ return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
+
+
+def modulate(norm_func, x, shift, scale, use_kernel=False):
+ # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
+ dtype = x.dtype
+ x = norm_func(x.to(torch.float32)).to(dtype)
+ if use_kernel:
+ try:
+ from videosys.kernels.fused_modulate import fused_modulate
+
+ x = fused_modulate(x, scale, shift)
+ except ImportError:
+ raise RuntimeError("FusedModulate kernel not available. Please install triton.")
+ else:
+ x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
+ x = x.to(dtype)
+
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final, x, shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
diff --git a/videosys/utils/ckpt_utils.py b/videosys/utils/ckpt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8b33c7909dc9eb83a6013f5042c07d5ef3dc71b
--- /dev/null
+++ b/videosys/utils/ckpt_utils.py
@@ -0,0 +1,115 @@
+import functools
+import json
+import operator
+import os
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from videosys.core.comm import model_sharding
+
+
+def load_json(file_path: str):
+ with open(file_path, "r") as f:
+ return json.load(f)
+
+
+def save_json(data, file_path: str):
+ with open(file_path, "w") as f:
+ json.dump(data, f, indent=4)
+
+
+def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
+ return tensor[: functools.reduce(operator.mul, original_shape)]
+
+
+def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
+ global_rank = dist.get_rank()
+ global_size = dist.get_world_size()
+ for name, param in model.named_parameters():
+ all_params = [torch.empty_like(param.data) for _ in range(global_size)]
+ dist.all_gather(all_params, param.data, group=dist.group.WORLD)
+ if global_rank == 0:
+ all_params = torch.cat(all_params)
+ param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
+ dist.barrier()
+
+
+def record_model_param_shape(model: torch.nn.Module) -> dict:
+ param_shape = {}
+ for name, param in model.named_parameters():
+ param_shape[name] = param.shape
+ return param_shape
+
+
+def save(
+ booster: Booster,
+ model: nn.Module,
+ ema: nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ global_step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+ save_dir: str,
+ shape_dict: dict,
+ shard_ema: bool = False,
+):
+ torch.cuda.empty_cache()
+ global_rank = dist.get_rank()
+ save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
+ os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
+ booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
+
+ # Gather the sharded ema model before saving
+ if shard_ema:
+ model_gathering(ema, shape_dict)
+
+ # ema is not boosted, so we don't need to use booster.save_model
+ if global_rank == 0:
+ torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
+ # Shard ema model when using zero2 plugin
+ if shard_ema:
+ model_sharding(ema)
+ if optimizer is not None:
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
+ if lr_scheduler is not None:
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "global_step": global_step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+ dist.barrier()
+
+
+def load(
+ booster: Booster,
+ model: nn.Module,
+ ema: nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ load_dir: str,
+) -> Tuple[int, int, int]:
+ booster.load_model(model, os.path.join(load_dir, "model"))
+ # ema is not boosted, so we don't use booster.load_model
+ ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
+ if optimizer is not None:
+ booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
+ if lr_scheduler is not None:
+ booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
+ running_states = load_json(os.path.join(load_dir, "running_states.json"))
+ dist.barrier()
+ torch.cuda.empty_cache()
+ return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
diff --git a/videosys/utils/debug_utils.py b/videosys/utils/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c6d8f45786f647aa87c064aedc01831ff8d498d
--- /dev/null
+++ b/videosys/utils/debug_utils.py
@@ -0,0 +1,7 @@
+import torch.distributed as dist
+
+
+# Print debug information on selected rank
+def print_rank(var_name, var_value, rank=0):
+ if dist.get_rank() == rank:
+ print(f"[Rank {rank}] {var_name}: {var_value}")
diff --git a/videosys/utils/download.py b/videosys/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..75cf8dd8bf91e66d568ca2055a8138d6fb8977bb
--- /dev/null
+++ b/videosys/utils/download.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functions for downloading pre-trained DiT models
+"""
+import json
+import os
+
+import torch
+from torchvision.datasets.utils import download_url
+
+pretrained_models = {"DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"}
+
+
+def find_model(model_name):
+ """
+ Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
+ """
+ if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
+ return download_model(model_name)
+ else: # Load a custom DiT checkpoint:
+ if not os.path.isfile(model_name):
+ # if the model_name is a directory, then we assume we should load it in the Hugging Face manner
+ # i.e. the model weights are sharded into multiple files and there is an index.json file
+ # walk through the files in the directory and find the index.json file
+ index_file = [os.path.join(model_name, f) for f in os.listdir(model_name) if "index.json" in f]
+ assert len(index_file) == 1, f"Could not find index.json in {model_name}"
+
+ # process index json
+ with open(index_file[0], "r") as f:
+ index_data = json.load(f)
+
+ bin_to_weight_mapping = dict()
+ for k, v in index_data["weight_map"].items():
+ if v in bin_to_weight_mapping:
+ bin_to_weight_mapping[v].append(k)
+ else:
+ bin_to_weight_mapping[v] = [k]
+
+ # make state dict
+ state_dict = dict()
+ for bin_name, weight_list in bin_to_weight_mapping.items():
+ bin_path = os.path.join(model_name, bin_name)
+ bin_state_dict = torch.load(bin_path, map_location=lambda storage, loc: storage)
+ for weight in weight_list:
+ state_dict[weight] = bin_state_dict[weight]
+ return state_dict
+ else:
+ # if it is a file, we just load it directly in the typical PyTorch manner
+ assert os.path.exists(model_name), f"Could not find DiT checkpoint at {model_name}"
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ checkpoint = checkpoint["ema"]
+ return checkpoint
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ assert model_name in pretrained_models
+ local_path = f"pretrained_models/{model_name}"
+ if not os.path.isfile(local_path):
+ os.makedirs("pretrained_models", exist_ok=True)
+ web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+ download_url(web_path, "pretrained_models")
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
+
+if __name__ == "__main__":
+ # Download all DiT checkpoints
+ for model in pretrained_models:
+ download_model(model)
+ print("Done.")
diff --git a/videosys/utils/logging.py b/videosys/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..896a4d60ed07c73e571e7d5a30f2574103de379b
--- /dev/null
+++ b/videosys/utils/logging.py
@@ -0,0 +1,32 @@
+import logging
+
+import torch.distributed as dist
+from rich.logging import RichHandler
+
+
+def create_logger():
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ logger = logging.getLogger(__name__)
+ return logger
+
+
+def init_dist_logger():
+ """
+ Update the logger to write to a log file.
+ """
+ global logger
+ if dist.get_rank() == 0:
+ logger = logging.getLogger(__name__)
+ handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
+ formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ logger.setLevel(logging.INFO)
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+
+
+logger = create_logger()
diff --git a/videosys/utils/train_utils.py b/videosys/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8906872698c59eb4cf76e00692ea2d6418224190
--- /dev/null
+++ b/videosys/utils/train_utils.py
@@ -0,0 +1,65 @@
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f"{numel / B:.2f} B"
+ elif numel >= M:
+ return f"{numel / M:.2f} M"
+ elif numel >= K:
+ return f"{numel / K:.2f} K"
+ else:
+ return f"{numel}"
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+@torch.no_grad()
+def update_ema(
+ ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
+) -> None:
+ """
+ Step the EMA model towards the current model.
+ """
+ ema_params = OrderedDict(ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+
+ for name, param in model_params.items():
+ if name == "pos_embed":
+ continue
+ if param.requires_grad == False:
+ continue
+ if not sharded:
+ param_data = param.data
+ ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
+ else:
+ if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer):
+ param_id = id(param)
+ master_param = optimizer._param_store.working_to_master_param[param_id]
+ param_data = master_param.data
+ else:
+ param_data = param.data
+ ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
+
+
+def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
diff --git a/videosys/utils/utils.py b/videosys/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44bcf1bb3954f9ec00d2d14d4a08b713f36ecf3
--- /dev/null
+++ b/videosys/utils/utils.py
@@ -0,0 +1,82 @@
+import os
+import random
+
+import imageio
+import numpy as np
+import torch
+import torch.distributed as dist
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+
+def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+
+def set_seed(seed):
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+
+def str_to_dtype(x: str):
+ if x == "fp32":
+ return torch.float32
+ elif x == "fp16":
+ return torch.float16
+ elif x == "bf16":
+ return torch.bfloat16
+ else:
+ raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
+
+
+def batch_func(func, *args):
+ """
+ Apply a function to each element of a batch.
+ """
+ batch = []
+ for arg in args:
+ if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
+ batch.append(func(arg))
+ else:
+ batch.append(arg)
+
+ return batch
+
+
+def merge_args(args1, args2):
+ """
+ Merge two argparse Namespace objects.
+ """
+ if args2 is None:
+ return args1
+
+ for k in args2._content.keys():
+ if k in args1.__dict__:
+ v = getattr(args2, k)
+ if isinstance(v, ListConfig) or isinstance(v, DictConfig):
+ v = OmegaConf.to_object(v)
+ setattr(args1, k, v)
+ else:
+ raise RuntimeError(f"Unknown argument {k}")
+
+ return args1
+
+
+def all_exists(paths):
+ return all(os.path.exists(path) for path in paths)
+
+
+def save_video(video, output_path, fps):
+ """
+ Save a video to disk.
+ """
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ if dist.get_rank() == 0:
+ imageio.mimwrite(output_path, video, fps=fps)
+ dist.barrier()