maxin-cn commited on
Commit
be791d6
·
verified ·
1 Parent(s): cc8bc4c

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. .gitignore +1 -0
  3. LICENSE +201 -0
  4. README.md +133 -12
  5. __pycache__/utils.cpython-312.pyc +0 -0
  6. animated_images/aircraft.jpg +3 -0
  7. animated_images/car.jpg +3 -0
  8. animated_images/fireworks.jpg +0 -0
  9. animated_images/flowers.jpg +0 -0
  10. animated_images/forest.jpg +3 -0
  11. animated_images/shark_unwater.jpg +0 -0
  12. configs/sample.yaml +38 -0
  13. datasets/__pycache__/video_transforms.cpython-312.pyc +0 -0
  14. datasets/video_transforms.py +748 -0
  15. demo.py +311 -0
  16. environment.yml +21 -0
  17. example/aircrafts_flying/0.jpg +0 -0
  18. example/aircrafts_flying/aircrafts_flying.mp4 +0 -0
  19. example/car_moving/0.jpg +0 -0
  20. example/car_moving/car_moving.mp4 +0 -0
  21. example/fireworks/0.jpg +0 -0
  22. example/fireworks/fireworks.mp4 +0 -0
  23. example/flowers_swaying/0.jpg +0 -0
  24. example/flowers_swaying/flowers_swaying.mp4 +0 -0
  25. example/girl_walking_on_the_beach/0.jpg +0 -0
  26. example/girl_walking_on_the_beach/girl_walking_on_the_beach.mp4 +0 -0
  27. example/house_rotating/0.jpg +0 -0
  28. example/house_rotating/house_rotating.mp4 +0 -0
  29. example/people_runing/0.jpg +0 -0
  30. example/people_runing/people_runing.mp4 +0 -0
  31. example/shark_swimming/0.jpg +0 -0
  32. example/shark_swimming/shark_swimming.mp4 +0 -0
  33. example/windmill_turning/0.jpg +0 -0
  34. example/windmill_turning/windmill_turning.mp4 +0 -0
  35. gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4 +0 -0
  36. gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4 +0 -0
  37. gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/.nfs6a1237621cfe7a8800009149 +0 -0
  38. gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4 +0 -0
  39. gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4 +0 -0
  40. gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/.nfs88c2a0e49709591000009148 +0 -0
  41. gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4 +0 -0
  42. gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4 +0 -0
  43. gradio_cached_examples/39/indices.csv +6 -0
  44. gradio_cached_examples/39/log.csv +7 -0
  45. models/__init__.py +33 -0
  46. models/__pycache__/__init__.cpython-312.pyc +0 -0
  47. models/__pycache__/attention.cpython-312.pyc +0 -0
  48. models/__pycache__/resnet.cpython-312.pyc +0 -0
  49. models/__pycache__/rotary_embedding_torch_mx.cpython-312.pyc +0 -0
  50. models/__pycache__/temporal_attention.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ animated_images/aircraft.jpg filter=lfs diff=lfs merge=lfs -text
37
+ animated_images/car.jpg filter=lfs diff=lfs merge=lfs -text
38
+ animated_images/forest.jpg filter=lfs diff=lfs merge=lfs -text
39
+ visuals/animations/dragon_glowing_eyes/dragon_glowing_eyes.gif filter=lfs diff=lfs merge=lfs -text
40
+ visuals/animations/girl_dancing_under_the_stars/girl_dancing_under_the_stars.gif filter=lfs diff=lfs merge=lfs -text
41
+ visuals/animations/people_walking/people_walking.gif filter=lfs diff=lfs merge=lfs -text
42
+ visuals/animations/sea_swell/sea_swell.gif filter=lfs diff=lfs merge=lfs -text
43
+ visuals/video_editing/edit/editing_a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif filter=lfs diff=lfs merge=lfs -text
44
+ visuals/video_editing/origin/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [XIN MA] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,133 @@
1
- ---
2
- title: Cinemo
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Cinemo
3
+ app_file: demo.py
4
+ sdk: gradio
5
+ sdk_version: 4.37.2
6
+ ---
7
+ ## Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models<br><sub>Official PyTorch Implementation</sub>
8
+
9
+
10
+ [![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b.svg)](https://arxiv.org/abs/2407.15642)
11
+ [![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://maxin-cn.github.io/cinemo_project/)
12
+
13
+
14
+ This repo contains pre-trained weights, and sampling code for our paper exploring image animation with motion diffusion models (Cinemo). You can find more visualizations on our [project page](https://maxin-cn.github.io/cinemo_project/).
15
+
16
+ In this project, we propose a novel method called Cinemo, which can perform motion-controllable image animation with strong consistency and smoothness. To improve motion smoothness, Cinemo learns the distribution of motion residuals, rather than directly generating subsequent frames. Additionally, a structural similarity index-based method is proposed to control the motion intensity. Furthermore, we propose a noise refinement technique based on discrete cosine transformation to ensure temporal consistency. These three methods help Cinemo generate highly consistent, smooth, and motion-controlled image animation results. Compared to previous methods, Cinemo offers simpler and more precise user control and better generative performance.
17
+
18
+ <div align="center">
19
+ <img src="visuals/pipeline.svg">
20
+ </div>
21
+
22
+ ## News
23
+
24
+ - (🔥 New) Jul. 23, 2024. 💥 Our paper is released on [arxiv](https://arxiv.org/abs/2407.15642).
25
+
26
+ - (🔥 New) Jun. 2, 2024. 💥 The inference code is released. The checkpoint can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main).
27
+
28
+
29
+ ## Setup
30
+
31
+ First, download and set up the repo:
32
+
33
+ ```bash
34
+ git clone https://github.com/maxin-cn/Cinemo
35
+ cd Cinemo
36
+ ```
37
+
38
+ We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
39
+ to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
40
+
41
+ ```bash
42
+ conda env create -f environment.yml
43
+ conda activate cinemo
44
+ ```
45
+
46
+
47
+ ## Animation
48
+
49
+ You can sample from our **pre-trained Cinemo models** with [`animation.py`](pipelines/animation.py). Weights for our pre-trained Cinemo model can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main). The script has various arguments for adjusting sampling steps, changing the classifier-free guidance scale, etc:
50
+
51
+ ```bash
52
+ bash pipelines/animation.sh
53
+ ```
54
+
55
+ All related checkpoints will be downloaded automatically, and you will then get the following results:
56
+
57
+ <table style="width:100%; text-align:center;">
58
+ <tr>
59
+ <td align="center">Input image</td>
60
+ <td align="center">Output video</td>
61
+ <td align="center">Input image</td>
62
+ <td align="center">Output video</td>
63
+ </tr>
64
+ <tr>
65
+ <td align="center"><img src="visuals/animations/people_walking/0.jpg" width="100%"></td>
66
+ <td align="center"><img src="visuals/animations/people_walking/people_walking.gif" width="100%"></td>
67
+ <td align="center"><img src="visuals/animations/sea_swell/0.jpg" width="100%"></td>
68
+ <td align="center"><img src="visuals/animations/sea_swell/sea_swell.gif" width="100%"></td>
69
+ </tr>
70
+ <tr>
71
+ <td align="center" colspan="2">"People Walking"</td>
72
+ <td align="center" colspan="2">"Sea Swell"</td>
73
+ </tr>
74
+ <tr>
75
+ <td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/0.jpg" width="100%"></td>
76
+ <td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/girl_dancing_under_the_stars.gif" width="100%"></td>
77
+ <td align="center"><img src="visuals/animations/dragon_glowing_eyes/0.jpg" width="100%"></td>
78
+ <td align="center"><img src="visuals/animations/dragon_glowing_eyes/dragon_glowing_eyes.gif" width="100%"></td>
79
+ </tr>
80
+ <tr>
81
+ <td align="center" colspan="2">"Girl Dancing under the Stars"</td>
82
+ <td align="center" colspan="2">"Dragon Glowing Eyes"</td>
83
+ </tr>
84
+
85
+ </table>
86
+
87
+
88
+ ## Other Applications
89
+
90
+ You can also utilize Cinemo for other applications, such as motion transfer and video editing:
91
+
92
+ ```bash
93
+ bash pipelines/video_editing.sh
94
+ ```
95
+
96
+ All related checkpoints will be downloaded automatically, and you will get the following results:
97
+
98
+ <table style="width:100%; text-align:center;">
99
+ <tr>
100
+ <td align="center">Input video</td>
101
+ <td align="center">First frame</td>
102
+ <td align="center">Edited first frame</td>
103
+ <td align="center">Output video</td>
104
+ </tr>
105
+ <tr>
106
+ <td align="center"><img src="visuals/video_editing/origin/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
107
+ <td align="center"><img src="visuals/video_editing/origin/0.jpg" width="100%"></td>
108
+ <td align="center"><img src="visuals/video_editing/edit/0.jpg" width="100%"></td>
109
+ <td align="center"><img src="visuals/video_editing/edit/editing_a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
110
+ </tr>
111
+
112
+ </table>
113
+
114
+
115
+
116
+ ## Citation
117
+ If you find this work useful for your research, please consider citing it.
118
+ ```bibtex
119
+ @article{ma2024cinemo,
120
+ title={Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models},
121
+ author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
122
+ journal={arXiv preprint arXiv:2407.15642},
123
+ year={2024}
124
+ }
125
+ ```
126
+
127
+
128
+ ## Acknowledgments
129
+ Cinemo has been greatly inspired by the following amazing works and teams: [LaVie](https://github.com/Vchitect/LaVie) and [SEINE](https://github.com/Vchitect/SEINE). We thank all the contributors for open-sourcing their work.
130
+
131
+
132
+ ## License
133
+ The code and model weights are licensed under [LICENSE](LICENSE).
__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.62 kB). View file
 
animated_images/aircraft.jpg ADDED

Git LFS Details

  • SHA256: 8c74eb22424fdcf0e40a6e22eeb497babe23d8fa9700f70ad9b9288072801bc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
animated_images/car.jpg ADDED

Git LFS Details

  • SHA256: 331981fa29ba5d5314a3c7f42499b30cd98f8e7b01ed147626934a1d808a103b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
animated_images/fireworks.jpg ADDED
animated_images/flowers.jpg ADDED
animated_images/forest.jpg ADDED

Git LFS Details

  • SHA256: 2519a6b8cd3d901388b52d9835c223dee93610d3aa4db15659003a004e2ac2eb
  • Pointer size: 132 Bytes
  • Size of remote file: 7 MB
animated_images/shark_unwater.jpg ADDED
configs/sample.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ckpt
2
+ ckpt: /mnt/lustre/maxin/work/animation/animation-v6/results_qt_003-UNet-pixabaylaion-Xfor-Gc-320_512_0303000.pt
3
+ save_img_path: "./sample_videos/"
4
+
5
+ # pretrained_model_path: "/mnt/hwfile/gcc/maxin/work/pretrained/stable-diffusion-v1-4/"
6
+ # pretrained_model_path: "maxin-cn/Cinemo"
7
+ pretrained_model_path: "./pretrained/Cinemo"
8
+
9
+ # model config:
10
+ model: UNet
11
+ video_length: 15
12
+ image_size: [320, 512]
13
+ # beta schedule
14
+ beta_start: 0.0001
15
+ beta_end: 0.02
16
+ beta_schedule: "linear"
17
+
18
+ # model speedup
19
+ use_compile: False
20
+ use_fp16: True
21
+
22
+ # sample config:
23
+ seed:
24
+ run_time: 0
25
+ use_dct: True
26
+ guidance_scale: 7.5 #
27
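+ motion_bucket_id: 95 # The larger the value, the stronger the motion intensity. NOTE: the range previously documented here, [0-19], does not contain the configured value 95 -- confirm the valid range.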
28
+ sample_method: 'DDIM'
29
+ num_sampling_steps: 50
30
+ enable_vae_temporal_decoder: True
31
+ image_prompts: [
32
+ ['aircraft.jpg', 'aircrafts flying'],
33
+ ['car.jpg' ,"car moving"],
34
+ ['fireworks.jpg', 'fireworks'],
35
+ ['flowers.jpg', 'flowers swaying'],
36
+ ['forest.jpg', 'people walking'],
37
+ ['shark_unwater.jpg', 'shark falling into the sea'],
38
+ ]
datasets/__pycache__/video_transforms.cpython-312.pyc ADDED
Binary file (32.3 kB). View file
 
datasets/video_transforms.py ADDED
@@ -0,0 +1,748 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+ from PIL import Image
6
+ from torchvision.utils import _log_api_usage_once
7
+
8
def _is_tensor_video_clip(clip):
    """Check that ``clip`` is a 4-dimensional torch tensor.

    Args:
        clip: Object to validate.

    Returns:
        bool: ``True`` when both checks pass (the function never returns
        ``False``; failures are reported by raising).

    Raises:
        TypeError: If ``clip`` is not a torch tensor.
        ValueError: If ``clip`` does not have exactly four dimensions.
    """
    is_tensor = torch.is_tensor(clip)
    if not is_tensor:
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    rank = clip.ndimension()
    if rank != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True
16
+
17
+
18
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is repeatedly halved with a box filter while its short side is
    at least twice ``image_size``, then rescaled with bicubic resampling so
    that the short side equals ``image_size``, and finally cropped to a
    centered ``image_size`` x ``image_size`` square.

    Args:
        pil_image: PIL image to crop.
        image_size (int): Side length of the square output.

    Returns:
        PIL image built from the centered square region.
    """
    # The module's top-level imports do not include numpy, so the original
    # ``np.array`` call raised NameError; import it locally instead.
    import numpy as np

    # Cheap successive halving first, so the bicubic resize below works on a
    # smaller image.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
37
+
38
+
39
def crop(clip, i, j, h, w):
    """Cut a spatial window out of a video clip.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): First row of the window.
        j (int): First column of the window.
        h (int): Number of rows to keep.
        w (int): Number of columns to keep.

    Returns:
        torch.tensor: ``clip`` sliced on its last two dimensions.

    Raises:
        ValueError: If ``clip`` is not 4-dimensional.
    """
    num_dims = len(clip.size())
    if num_dims != 4:
        raise ValueError("clip should be a 4D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]
47
+
48
+
49
def resize(clip, target_size, interpolation_mode):
    """Resize a clip to an explicit (height, width).

    Args:
        clip (torch.tensor): Clip to resize; passed straight to
            ``torch.nn.functional.interpolate``.
        target_size: Two-element (height, width) of the output.
        interpolation_mode (str): Interpolation mode forwarded to
            ``interpolate`` (called with ``align_corners=False``).

    Returns:
        torch.tensor: The resized clip.

    Raises:
        ValueError: If ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    interpolate = torch.nn.functional.interpolate
    return interpolate(
        clip,
        size=target_size,
        mode=interpolation_mode,
        align_corners=False,
    )
53
+
54
def resize_scale(clip, target_size, interpolation_mode):
    """Rescale a clip so its shorter spatial side matches ``target_size[0]``.

    The same scale factor, ``target_size[0] / min(H, W)``, is applied to both
    spatial dimensions; ``target_size[1]`` is only checked for presence.

    Args:
        clip (torch.tensor): Clip whose last two dimensions are (H, W).
        target_size: Two-element sequence; only the first entry sets the scale.
        interpolation_mode (str): Interpolation mode forwarded to
            ``torch.nn.functional.interpolate`` (``align_corners=False``).

    Returns:
        torch.tensor: The rescaled clip.

    Raises:
        ValueError: If ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    height, width = clip.size(-2), clip.size(-1)
    shorter_side = min(height, width)
    factor = target_size[0] / shorter_side
    return torch.nn.functional.interpolate(
        clip,
        scale_factor=factor,
        mode=interpolation_mode,
        align_corners=False,
    )
60
+
61
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
    """Rescale a clip by an explicit scale factor.

    Args:
        clip (torch.tensor): Clip to rescale.
        scale_factor: Factor forwarded to ``torch.nn.functional.interpolate``.
        interpolation_mode (str): Interpolation mode forwarded to
            ``interpolate`` (called with ``align_corners=False``).

    Returns:
        torch.tensor: The rescaled clip.
    """
    interpolate = torch.nn.functional.interpolate
    rescaled = interpolate(
        clip,
        scale_factor=scale_factor,
        mode=interpolation_mode,
        align_corners=False,
    )
    return rescaled
63
+
64
def resize_scale_with_height(clip, target_size, interpolation_mode):
    """Rescale a clip uniformly so that its height becomes ``target_size``.

    Args:
        clip (torch.tensor): Clip whose last two dimensions are (H, W).
        target_size (number): Desired height; the scale is ``target_size / H``.
        interpolation_mode (str): Interpolation mode forwarded to
            ``torch.nn.functional.interpolate`` (``align_corners=False``).

    Returns:
        torch.tensor: The rescaled clip.
    """
    height = clip.size(-2)
    factor = target_size / height
    return torch.nn.functional.interpolate(
        clip,
        scale_factor=factor,
        mode=interpolation_mode,
        align_corners=False,
    )
68
+
69
def resize_scale_with_weight(clip, target_size, interpolation_mode):
    """Rescale a clip uniformly so that its width becomes ``target_size``.

    Despite the spelling of the function name, the scale is derived from the
    width (last dimension) of ``clip``.

    Args:
        clip (torch.tensor): Clip whose last two dimensions are (H, W).
        target_size (number): Desired width; the scale is ``target_size / W``.
        interpolation_mode (str): Interpolation mode forwarded to
            ``torch.nn.functional.interpolate`` (``align_corners=False``).

    Returns:
        torch.tensor: The rescaled clip.
    """
    # Query both sizes, as the original did, even though only width is used.
    _, width = clip.size(-2), clip.size(-1)
    factor = target_size / width
    return torch.nn.functional.interpolate(
        clip,
        scale_factor=factor,
        mode=interpolation_mode,
        align_corners=False,
    )
73
+
74
+
75
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Crop a spatial window from the video clip, then resize it.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): Row of the window's upper left corner.
        j (int): Column of the window's upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): Height and width of the resized clip.
        interpolation_mode (str): Mode passed on to ``resize``.
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    clip_is_valid = _is_tensor_video_clip(clip)
    if not clip_is_valid:
        raise ValueError("clip should be a 4D torch.tensor")
    window = crop(clip, i, j, h, w)
    return resize(window, size, interpolation_mode)
93
+
94
+
95
def center_crop(clip, crop_size):
    """Crop a centered window of size ``crop_size`` from a clip.

    Args:
        clip (torch.tensor): Video clip; the last two dimensions are (H, W).
        crop_size: Two-element (height, width) of the window.

    Returns:
        tuple: ``(cropped_clip, i, j)`` where ``(i, j)`` is the upper left
        corner of the window.

    Raises:
        ValueError: If the clip is smaller than ``crop_size`` in either
            spatial dimension.
    """
    clip_is_valid = _is_tensor_video_clip(clip)
    if not clip_is_valid:
        raise ValueError("clip should be a 4D torch.tensor")
    h = clip.size(-2)
    w = clip.size(-1)
    th, tw = crop_size
    fits = h >= th and w >= tw
    if not fits:
        raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))

    # Split the leftover margin evenly between the two sides.
    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    cropped = crop(clip, top, left, th, tw)
    return cropped, top, left
108
+
109
+
110
def center_crop_using_short_edge(clip):
    """Crop a centered square whose side equals the clip's shorter edge.

    Args:
        clip (torch.tensor): Video clip; the last two dimensions are (H, W).

    Returns:
        torch.tensor: The square crop.
    """
    clip_is_valid = _is_tensor_video_clip(clip)
    if not clip_is_valid:
        raise ValueError("clip should be a 4D torch.tensor")
    h = clip.size(-2)
    w = clip.size(-1)
    side = h if h < w else w
    if h < w:
        # Landscape: keep every row, center the window horizontally.
        top = 0
        left = int(round((w - side) / 2.0))
    else:
        # Portrait or square: keep every column, center vertically.
        top = int(round((h - side) / 2.0))
        left = 0
    return crop(clip, top, left, side, side)
123
+
124
+
125
def random_shift_crop(clip):
    '''
    Take a square crop whose side is the short edge, placed at a random
    offset along the long edge.

    Returns a tuple ``(cropped_clip, i, j)`` where ``(i, j)`` is the upper
    left corner of the crop.
    '''
    clip_is_valid = _is_tensor_video_clip(clip)
    if not clip_is_valid:
        raise ValueError("clip should be a 4D torch.tensor")
    h = clip.size(-2)
    w = clip.size(-1)

    side = h if h <= w else w

    # Row offset is drawn before the column offset; along the short edge the
    # only possible offset is 0.
    top = torch.randint(0, h - side + 1, size=(1,)).item()
    left = torch.randint(0, w - side + 1, size=(1,)).item()
    cropped = crop(clip, top, left, side, side)
    return cropped, top, left
145
+
146
def random_crop(clip, crop_size):
    """Crop a randomly placed window of size ``crop_size`` from a clip.

    Args:
        clip (torch.tensor): Video clip; the last two dimensions are (H, W).
        crop_size: Sequence whose last two entries are the window height and
            width.

    Returns:
        tuple: ``(cropped_clip, i, j)`` where ``(i, j)`` is the upper left
        corner of the window.

    Raises:
        ValueError: If the clip is smaller than the window in either spatial
            dimension.
    """
    clip_is_valid = _is_tensor_video_clip(clip)
    if not clip_is_valid:
        raise ValueError("clip should be a 4D torch.tensor")
    h = clip.size(-2)
    w = clip.size(-1)
    th = crop_size[-2]
    tw = crop_size[-1]

    fits = h >= th and w >= tw
    if not fits:
        raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))

    # Row offset is drawn before the column offset.
    top = torch.randint(0, h - th + 1, size=(1,)).item()
    left = torch.randint(0, w - tw + 1, size=(1,)).item()
    return crop(clip, top, left, th, tw), top, left
159
+
160
+
161
def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float and divide value by 255.0.
    The dimension order is left untouched (no permute is applied).
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    Raises:
        ValueError: if the clip fails the video-clip check.
        TypeError: if the clip dtype is not uint8.
    """
    # Act on the validation result like every other helper in this module does;
    # previously the return value was silently discarded.
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float() / 255.0
175
+
176
+
177
def normalize(clip, mean, std, inplace=False):
    """
    Subtract a per-channel mean and divide by a per-channel standard deviation.
    Args:
        clip (torch.tensor): Video clip to be normalized.
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
        inplace (bool): modify ``clip`` itself instead of working on a clone.
    Returns:
        normalized clip (torch.tensor): same size as the input.

    NOTE(review): mean/std are reshaped to (N, 1, 1, 1), i.e. they broadcast
    along dimension 0 of the clip. That matches a channel-first (C, T, H, W)
    layout, not the (T, C, H, W) layout documented by the other helpers in
    this file -- confirm which layout callers actually pass.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    target.sub_(mean_t[:, None, None, None])
    target.div_(std_t[:, None, None, None])
    return target
195
+
196
+
197
def hflip(clip):
    """
    Mirror a video clip left-to-right by reversing its last dimension.
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    Raises:
        ValueError: if the clip fails the video-clip check.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return torch.flip(clip, dims=(-1,))
207
+
208
+
209
class RandomCropVideo:
    """Crop a fixed-size window from a random spatial position of a video clip."""

    def __init__(self, size):
        # A bare number means a square crop; any other value is stored as given.
        square = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if square else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, out_h, out_w = self.get_params(clip)
        return crop(clip, top, left, out_h, out_w)

    def get_params(self, clip):
        """Pick the crop window; returns (top, left, height, width)."""
        src_h, src_w = clip.shape[-2:]
        out_h, out_w = self.size

        if src_h < out_h or src_w < out_w:
            raise ValueError(f"Required crop size {(out_h, out_w)} is larger than input image size {(src_h, src_w)}")

        # Exact fit: nothing to randomise, so no random numbers are drawn.
        if (src_w, src_h) == (out_w, out_h):
            return 0, 0, src_h, src_w

        top = torch.randint(0, src_h - out_h + 1, size=(1,)).item()
        left = torch.randint(0, src_w - out_w + 1, size=(1,)).item()
        return top, left, out_h, out_w

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
244
+
245
class CenterCropResizeVideo:
    '''
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped (square, short-edge sized) and then
                resized video clip, as returned by ``resize`` for target size ``self.size``.
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_center_crop_resize

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
281
+
282
+
283
class SDXL:
    '''
    Rescale a clip so that it just covers the target size, then take a random
    crop of the target size. Also reports the size and crop offsets.
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            tuple: (cropped clip, ori_h, ori_w, top offset, left offset).
        """
        ori_h, ori_w = clip.size(-2), clip.size(-1)
        # Aim one pixel above the target so the crop never fails on a rounding shortfall.
        tar_h, tar_w = self.size[0] + 1, self.size[1] + 1

        covers_target = ori_h >= tar_h and ori_w >= tar_w
        # The larger ratio is the one that makes BOTH edges reach the target.
        scale_factor = max(tar_h / ori_h, tar_w / ori_w)
        clip = resize_with_scale_factor(clip=clip, scale_factor=scale_factor, interpolation_mode=self.interpolation_mode)

        # NOTE(review): kept from the original -- when the input already covered the
        # target, the reported (ori_h, ori_w) is the size AFTER rescaling; otherwise it
        # is the size before rescaling. Confirm this asymmetry is intended.
        if covers_target:
            ori_h, ori_w = clip.size(-2), clip.size(-1)

        clip_tar_crop, i, j = random_crop(clip, self.size)
        return clip_tar_crop, ori_h, ori_w, i, j

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
341
+
342
+
343
class SDXLCenterCrop:
    '''
    Rescale a clip so that it just covers the target size, then center crop
    to the target size. Also reports the original size and crop offsets.
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            tuple: (cropped clip, original height, original width, top offset, left offset).
        """
        ori_h, ori_w = clip.size(-2), clip.size(-1)
        # Aim one pixel above the target so the center crop never fails on a rounding shortfall.
        tar_h, tar_w = self.size[0] + 1, self.size[1] + 1

        # The larger ratio is the one that makes BOTH edges reach the target.
        scale_factor = max(tar_h / ori_h, tar_w / ori_w)
        clip = resize_with_scale_factor(clip=clip, scale_factor=scale_factor, interpolation_mode=self.interpolation_mode)

        clip_tar_crop, i, j = center_crop(clip, self.size)
        return clip_tar_crop, ori_h, ori_w, i, j

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
388
+
389
+
390
class InternVideo320512:
    '''
    Enlarge clips that are smaller than 320 x 512, center crop a (220, 352)
    window and resize that window to the requested size.
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: the centered (220, 352) window of the clip, resized to ``self.size``.
        """
        # Resize targets are one pixel above the threshold to absorb rounding.
        if clip.size(-2) < 320:
            clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
        # Read the width from the current clip: it may have been resized just above,
        # so the width measured before that resize can be stale.
        if clip.size(-1) < 512:
            clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)

        # center_crop returns (clip, top, left); only the clip goes on to resize.
        # Judging by the original variable name, the smaller window is meant to
        # cut subtitles away -- presumably; confirm with the dataset.
        clip_center_crop_no_subtitles, _, _ = center_crop(clip, (220, 352))
        clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_center_resize

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
432
+
433
class CenterCropVideo:
    '''
    First scale to the specified size in equal proportion to the short edge,
    then center cropping

    NOTE(review): a second class named CenterCropVideo is defined further down
    in this module; at import time that later definition replaces this one, so
    this class is unreachable by name. The __main__ self-test refers to
    ``UCFCenterCropVideo`` -- possibly the intended name for this class; confirm.
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        # center_crop returns (clip, top, left); a transform must hand back the
        # clip alone so it can be chained with other transforms.
        clip_center_crop, _, _ = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
467
+
468
class KineticsRandomCropResizeVideo:
    '''
    Slide along the long edge, with the short edge as crop size. And resize to the desired size.
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly shifted square crop, resized to ``self.size``.
        """
        # random_shift_crop returns (clip, top, left); only the clip is resized.
        # Passing the whole tuple on to resize was a bug.
        clip_random_crop, _, _ = random_shift_crop(clip)
        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
        return clip_resize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
490
+
491
class ResizeVideo:
    '''
    Resize the video clip directly to the specified size (no cropping).
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
        Returns:
            torch.tensor: the clip as returned by ``resize`` for target size ``self.size``.
        """
        clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_resize

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
524
+
525
class CenterCropVideo:
    '''
    Center crop the video clip to the specified size (no resizing).
    '''
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # A tuple is taken as (height, width); anything else becomes a square size.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        # Stored for signature parity with the other transforms; not used by __call__.
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        # center_crop returns (clip, top, left); return the clip alone, as documented,
        # so the transform can be chained with other transforms.
        clip_center_crop, _, _ = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
554
+
555
+
556
class NormalizeVideo:
    """
    Transform wrapper around ``normalize``: mean subtraction followed by
    division by the standard deviation.
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        Returns:
            torch.tensor: whatever ``normalize`` returns for this clip.
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
579
+
580
+
581
class ToTensorVideo:
    """
    Transform wrapper around ``to_tensor``: turns a uint8 clip into a float
    clip with values divided by 255.0.
    """

    # No constructor arguments and no state: the default object constructor is enough.

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        converted = to_tensor(clip)
        return converted

    def __repr__(self) -> str:
        return type(self).__name__
601
+
602
+
603
class RandomHorizontalFlipVideo:
    """
    Mirror the whole video clip horizontally with probability ``p``.
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W)
        """
        # One draw per clip, so all frames of a clip are flipped together.
        should_flip = random.random() < self.p
        return hflip(clip) if should_flip else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
626
+
627
class Compose:
    """Composes several transforms together. This transform does not support torchscript.

    Unlike a plain transform pipeline, calling this object returns a 5-tuple
    ``(img, ori_h, ori_w, crops_coords_top, crops_coords_left)``. The four
    extra values come from the last ``SDXL`` / ``SDXLCenterCrop`` transform in
    the list; they are ``None`` when the list contains no such transform.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        # Defaults for pipelines without an SDXL-style transform; without them the
        # return statement raised UnboundLocalError.
        ori_h = ori_w = crops_coords_top = crops_coords_left = None
        for t in self.transforms:
            if isinstance(t, (SDXLCenterCrop, SDXL)):
                # These two transforms return the crop metadata alongside the clip.
                img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img)
            else:
                img = t(img)
        return img, ori_h, ori_w, crops_coords_top, crops_coords_left

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f" {t}"
        format_string += "\n)"
        return format_string
675
+
676
+ # ------------------------------------------------------------
677
+ # --------------------- Sampling ---------------------------
678
+ # ------------------------------------------------------------
679
class TemporalRandomCrop(object):
    """Pick a random contiguous window of frame indices.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Return (begin_index, end_index) of the window; end is exclusive-style
        and never exceeds ``total_frames``."""
        # Largest admissible start index, clamped at 0 for short videos.
        last_start = total_frames - self.size - 1
        if last_start < 0:
            last_start = 0
        start = random.randint(0, last_start)

        stop = start + self.size
        if stop > total_frames:
            stop = total_frames
        return start, stop
694
+
695
+
696
if __name__ == '__main__':
    # Manual smoke test: read a video, sample a window of frames, run them
    # through a transform pipeline and write the result back to disk.
    from torchvision import transforms
    import torchvision.io as io
    import numpy as np
    from torchvision.utils import save_image
    import os

    vframes, aframes, info = io.read_video(
        filename='./v_Archery_g01_c03.avi',
        pts_unit='sec',
        output_format='TCHW'
    )

    # NOTE(review): UCFCenterCropVideo is not defined in the visible part of this
    # module -- confirm it exists, otherwise this line raises NameError.
    trans = transforms.Compose([
        ToTensorVideo(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    assert end_frame_ind - start_frame_ind >= target_video_len
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    # Undo the (0.5, 0.5) normalisation to get back to the 0..255 range.
    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)

    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)

    # Create the frame directory up front; it was assumed to exist before.
    frames_dir = './test000'
    os.makedirs(frames_dir, exist_ok=True)
    for i in range(target_video_len):
        save_image(select_vframes_trans[i], os.path.join(frames_dir, '%04d.png' % i), normalize=True, value_range=(-1, 1))
demo.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import argparse
5
+ import torchvision
6
+
7
+
8
+ from pipelines.pipeline_videogen import VideoGenPipeline
9
+ from diffusers.schedulers import DDIMScheduler
10
+ from diffusers.models import AutoencoderKL
11
+ from diffusers.models import AutoencoderKLTemporalDecoder
12
+ from transformers import CLIPTokenizer, CLIPTextModel
13
+ from omegaconf import OmegaConf
14
+
15
+ import os, sys
16
+ sys.path.append(os.path.split(sys.path[0])[0])
17
+ from models import get_models
18
+ import imageio
19
+ from PIL import Image
20
+ import numpy as np
21
+ from datasets import video_transforms
22
+ from torchvision import transforms
23
+ from einops import rearrange, repeat
24
+ from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
25
+ from copy import deepcopy
26
+ import spaces
27
+ import requests
28
+ from datetime import datetime
29
+ import random
30
+
31
# The command line only selects the YAML config; `args` is then rebound to the
# loaded config object, which is what the rest of the module reads.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 # torch.float16

unet = get_models(args).to(device, dtype=dtype)

# `vae_for_base_content` encodes the input image; `vae` is a copy of it cast to
# the working dtype and handed to the generation pipeline.
if args.enable_vae_temporal_decoder:
    _base_vae_dtype = torch.float64 if args.use_dct else torch.float16
    vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=_base_vae_dtype).to(device)
else:
    vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)

tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge

# set eval mode
for _module in (unet, vae, text_encoder):
    _module.eval()

# Per-launch output directory, named after the start time.
basedir = os.getcwd()
savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
savedir_sample = os.path.join(savedir, "sample")
os.makedirs(savedir, exist_ok=True)
63
+
64
def _resize_and_center_crop(pil_image, target_height, target_width):
    """Scale ``pil_image`` (aspect ratio kept) until it covers the target size, then center crop it."""
    target_height, target_width = int(target_height), int(target_width)
    original_width, original_height = pil_image.size

    # The larger ratio is the one that makes BOTH edges reach the target.
    ratio = max(target_height / original_height, target_width / original_width)
    # round() + max(): a float product such as 511.99999 must not be truncated to
    # one pixel short of the target, which would pad the crop with a black line.
    new_width = max(target_width, round(original_width * ratio))
    new_height = max(target_height, round(original_height * ratio))
    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)

    # Integer offsets so the crop is exactly target_width x target_height.
    left = (new_width - target_width) // 2
    top = (new_height - target_height) // 2
    return pil_image.crop((left, top, left + target_width, top + target_height))


def update_and_resize_image(input_image_path, height_slider, width_slider):
    """
    Load an image from a local path or an http(s) URL and fit it to the requested size.
    Args:
        input_image_path (str): local file path or http(s) URL of the image.
        height_slider: target height in pixels.
        width_slider: target width in pixels.
    Returns:
        gr.Image: component update holding the (possibly resized and cropped) RGB image.
    Raises:
        requests.HTTPError: if a URL is given and the server answers with an error status.
    """
    if input_image_path.startswith(("http://", "https://")):
        # A timeout keeps a dead server from blocking the UI callback forever.
        response = requests.get(input_image_path, stream=True, timeout=30)
        response.raise_for_status()
        pil_image = Image.open(response.raw).convert('RGB')
    else:
        pil_image = Image.open(input_image_path).convert('RGB')

    original_width, original_height = pil_image.size

    # Already the right size: show it untouched.
    if original_height == height_slider and original_width == width_slider:
        return gr.Image(value=np.array(pil_image))

    pil_image = _resize_and_center_crop(pil_image, height_slider, width_slider)
    return gr.Image(value=np.array(pil_image))


def update_textbox_and_save_image(input_image, height_slider, width_slider):
    """
    Fit an uploaded image to the requested size and save it under ``savedir``.
    Args:
        input_image (np.ndarray): uploaded image as delivered by the gr.Image component.
        height_slider: target height in pixels.
        width_slider: target width in pixels.
    Returns:
        tuple: (gr.Textbox update with the saved file path, gr.Image update with the fitted image).
    """
    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
    pil_image = _resize_and_center_crop(pil_image, height_slider, width_slider)

    img_path = os.path.join(savedir, "input_image.png")
    pil_image.save(img_path)

    return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
125
+
126
def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
    """
    Encode a single image into a scaled VAE latent with an extra axis at position 2.
    Args:
        image: image data convertible to a uint8 array (assumed H x W x C -- TODO confirm).
        vae: autoencoder exposing ``encode(...).latent_dist`` and ``config.scaling_factor``.
        transform_video: callable applied to the (1, C, H, W) uint8 tensor before encoding.
        device: device the tensor is moved to for encoding.
        dtype: dtype the tensor is cast to for encoding.
    Returns:
        The sampled latent multiplied by the VAE scaling factor, with ``unsqueeze(2)`` applied.
    """
    pixels = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True))
    # Add a leading axis, then move the last axis to position 1.
    frame = pixels.unsqueeze(0).permute(0, 3, 1, 2)
    frame = transform_video(frame)
    posterior = vae.encode(frame.to(dtype=dtype, device=device)).latent_dist
    latent = posterior.sample().mul_(vae.config.scaling_factor)
    return latent.unsqueeze(2)
132
+
133
+
134
@spaces.GPU
def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
    """
    Animate ``input_image`` according to ``prompt`` and write the result as an mp4.
    Args:
        input_image: the image to animate, as delivered by the gr.Image component.
        prompt (str): text prompt.
        negative_prompt (str): negative text prompt.
        diffusion_step: number of inference steps.
        height, width: output resolution passed to the pipeline.
        scfg_scale: guidance scale.
        use_dctinit (bool): build the initial latents from the image via DCT frequency mixing.
        dct_coefficients: ``percentage`` argument of the DCT low-pass filter.
        noise_level: diffusion timestep used to noise the image latents for DCT init.
        motion_bucket_id: motion slider value; the pipeline receives ``100 - motion_bucket_id``.
        seed: random seed.
    Returns:
        str: path of the written video file.
    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image before generating.")

    video_length = 15

    # Sliders may deliver floats; torch.manual_seed needs an int.
    torch.manual_seed(int(seed))

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              beta_schedule=args.beta_schedule)

    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    transform_video = transforms.Compose([
        video_transforms.ToTensorVideo(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    # The image is encoded in float64 when DCT is enabled in the config, float16 otherwise.
    base_dtype = torch.float64 if args.use_dct else torch.float16
    base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=base_dtype).to(device)

    latents = None
    if use_dctinit:
        # filter params
        print("Using DCT!")
        base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=video_length).contiguous()

        # define filter
        freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)

        # Noise takes the shape of the repeated latents instead of a hard-coded
        # (1, 4, 15, 40, 64); it is still drawn on the CPU and then moved, so the
        # values for a given seed are unchanged for that shape.
        noise = torch.randn(base_content_repeat.shape).to(device)

        # add noise to base_content
        diffuse_timesteps = torch.full((1,), int(noise_level)).long()

        # 3d content
        base_content_noise = scheduler.add_noise(
            original_samples=base_content_repeat.to(device),
            noise=noise,
            timesteps=diffuse_timesteps.to(device))

        # 3d content
        latents = exchanged_mixed_dct_freq(noise=noise,
                                           base_content=base_content_noise,
                                           LPF_3d=freq_filter).to(dtype=torch.float16)

    base_content = base_content.to(dtype=torch.float16)

    videos = videogen_pipeline(prompt,
                               negative_prompt=negative_prompt,
                               latents=latents,
                               base_content=base_content,
                               video_length=video_length,
                               height=height,
                               width=width,
                               num_inference_steps=diffusion_step,
                               guidance_scale=scfg_scale,
                               motion_bucket_id=100-motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video

    # os.path.join works whether or not save_img_path ends with a separator;
    # plain string concatenation silently wrote next to the directory when it did not.
    save_path = os.path.join(args.save_img_path, 'temp' + '.mp4')
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path
205
+
206
+
207
# Create the output directory for generated videos; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(args.save_img_path, exist_ok=True)
209
+
210
+
211
+ with gr.Blocks() as demo:
212
+
213
+ gr.Markdown("<font color=red size=6.5><center>Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models</center></font>")
214
+ gr.Markdown(
215
+ """<div style="display: flex;align-items: center;justify-content: center">
216
+ [<a href="https://arxiv.org/abs/2407.15642">Arxiv Report</a>] | [<a href="https://https://maxin-cn.github.io/cinemo_project/">Project Page</a>] | [<a href="https://github.com/maxin-cn/Cinemo">Github</a>]</div>
217
+ """
218
+ )
219
+
220
+
221
+ with gr.Column(variant="panel"):
222
+ gr.Markdown(
223
+ """
224
+ - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
225
+ - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
226
+ - After setting the input image path, press the "Preview" button to visualize the resized input image.
227
+ """
228
+ )
229
+
230
+ with gr.Row():
231
+ prompt_textbox = gr.Textbox(label="Prompt", lines=1)
232
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
233
+
234
+ with gr.Row(equal_height=False):
235
+ with gr.Column():
236
+ with gr.Row():
237
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
238
+
239
+ with gr.Row():
240
+ seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
241
+ # seed_textbox = gr.Textbox(label="Seed", value=100)
242
+ # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
243
+ # seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
244
+
245
+ with gr.Row():
246
+ height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
247
+ width = gr.Slider(label="Width", value=512, minimum=0, maximum=512, step=16, interactive=False)
248
+ with gr.Row():
249
+ txt_cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1, interactive=True)
250
+ motion_bucket_id = gr.Slider(label="Motion Intensity", value=10, minimum=1, maximum=20, step=1, interactive=True)
251
+
252
+ with gr.Row():
253
+ use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
254
+ dct_coefficients = gr.Slider(label="DCT Coefficients", value=0.23, minimum=0, maximum=1, step=0.01, interactive=True)
255
+ noise_level = gr.Slider(label="Noise Level", value=985, minimum=1, maximum=999, step=1, interactive=True)
256
+
257
+ generate_button = gr.Button(value="Generate", variant='primary')
258
+
259
+ with gr.Column():
260
+ with gr.Row():
261
+ input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
262
+ preview_button = gr.Button(value="Preview")
263
+
264
+ with gr.Row():
265
+ input_image = gr.Image(label="Input Image", interactive=True)
266
+ input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height, width], outputs=[input_image_path, input_image])
267
+ result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
268
+
269
+ preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
270
+ input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
271
+
272
+ EXAMPLES = [
273
+ ["./example/aircrafts_flying/0.jpg", "aircrafts flying", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
274
+ ["./example/fireworks/0.jpg", "fireworks", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
275
+ ["./example/flowers_swaying/0.jpg", "flowers swaying", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
276
+ ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach", "", 50, 320, 512, 7.5, True, 0.23, 985, 10, 200],
277
+ ["./example/house_rotating/0.jpg", "house rotating", "", 50, 320, 512, 7.5, True, 0.23, 985, 10, 100],
278
+ ["./example/people_runing/0.jpg", "people runing", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
279
+ ]
280
+
281
+ examples = gr.Examples(
282
+ examples = EXAMPLES,
283
+ fn = gen_video,
284
+ inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
285
+ outputs=[result_video],
286
+ # cache_examples=True,
287
+ cache_examples="lazy",
288
+ )
289
+
290
+ generate_button.click(
291
+ fn=gen_video,
292
+ inputs=[
293
+ input_image,
294
+ prompt_textbox,
295
+ negative_prompt_textbox,
296
+ sample_step_slider,
297
+ height,
298
+ width,
299
+ txt_cfg_scale,
300
+ use_dctinit,
301
+ dct_coefficients,
302
+ noise_level,
303
+ motion_bucket_id,
304
+ seed_textbox,
305
+ ],
306
+ outputs=[result_video]
307
+ )
308
+
309
+ demo.launch(debug=False, share=True)
310
+
311
+ # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
environment.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: cinemo
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python >= 3.10
7
+ - pytorch >= 2.0
8
+ - torchvision
9
+ - pytorch-cuda >= 11.7
10
+ - pip:
11
+ - timm
12
+ - diffusers[torch]==0.24.0
13
+ - accelerate
14
+ - python-hostlist
15
+ - tensorboard
16
+ - einops
17
+ - transformers
18
+ - av
19
+ - scikit-image
20
+ - decord
21
+ - pandas
example/aircrafts_flying/0.jpg ADDED
example/aircrafts_flying/aircrafts_flying.mp4 ADDED
Binary file (533 kB). View file
 
example/car_moving/0.jpg ADDED
example/car_moving/car_moving.mp4 ADDED
Binary file (399 kB). View file
 
example/fireworks/0.jpg ADDED
example/fireworks/fireworks.mp4 ADDED
Binary file (479 kB). View file
 
example/flowers_swaying/0.jpg ADDED
example/flowers_swaying/flowers_swaying.mp4 ADDED
Binary file (469 kB). View file
 
example/girl_walking_on_the_beach/0.jpg ADDED
example/girl_walking_on_the_beach/girl_walking_on_the_beach.mp4 ADDED
Binary file (619 kB). View file
 
example/house_rotating/0.jpg ADDED
example/house_rotating/house_rotating.mp4 ADDED
Binary file (481 kB). View file
 
example/people_runing/0.jpg ADDED
example/people_runing/people_runing.mp4 ADDED
Binary file (482 kB). View file
 
example/shark_swimming/0.jpg ADDED
example/shark_swimming/shark_swimming.mp4 ADDED
Binary file (282 kB). View file
 
example/windmill_turning/0.jpg ADDED
example/windmill_turning/windmill_turning.mp4 ADDED
Binary file (403 kB). View file
 
gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4 ADDED
Binary file (226 kB). View file
 
gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4 ADDED
Binary file (223 kB). View file
 
gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/.nfs6a1237621cfe7a8800009149 ADDED
Binary file (334 kB). View file
 
gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4 ADDED
Binary file (619 kB). View file
 
gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4 ADDED
Binary file (272 kB). View file
 
gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/.nfs88c2a0e49709591000009148 ADDED
Binary file (352 kB). View file
 
gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4 ADDED
Binary file (481 kB). View file
 
gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4 ADDED
Binary file (206 kB). View file
 
gradio_cached_examples/39/indices.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ 0
2
+ 1
3
+ 2
4
+ 3
5
+ 4
6
+ 5
gradio_cached_examples/39/log.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Generated Animation,flag,username,timestamp
2
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/24d44f90b57c00bd6fdccd61cc35a4f4d459388c/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:42:07.603268
3
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/deb9620f616c3681cb074388781099f78a25dc8f/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:42:25.127236
4
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/d9e392300169a439b3f5721579849e3e5ce6abf9/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:44:01.949003
5
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/d1a5066828be61057823cf98edae890db907f358/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:00.196580
6
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/5998620333bb5bdaf52b49f2de86e428df991431/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:28.848377
7
+ "{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/7e3c6838256fd5a89b20ad62ce42288212e62097/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:48.004623
models/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .unet import UNet3DConditionModel
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
def customized_lr_scheduler(optimizer, warmup_steps=5000):  # 5000 from u-vit
    """Wrap ``optimizer`` in a ``LambdaLR`` that applies a linear warmup.

    Args:
        optimizer: Optimizer handed straight to ``LambdaLR``.
        warmup_steps: Number of steps over which the multiplier grows from
            0 to 1. When it is not greater than 0 the multiplier is always 1.

    Returns:
        A ``LambdaLR`` whose multiplier is ``min(step / warmup_steps, 1)``.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def warmup_factor(step):
        # Linear ramp capped at 1; the guard also avoids dividing by zero.
        return min(step / warmup_steps, 1) if warmup_steps > 0 else 1

    return LambdaLR(optimizer, warmup_factor)
16
+
17
+
18
def get_lr_scheduler(optimizer, name, **kwargs):
    """Create the learning-rate scheduler selected by ``name``.

    Args:
        optimizer: Optimizer the scheduler is attached to.
        name: ``'warmup'`` selects ``customized_lr_scheduler``; ``'cosine'``
            selects ``torch.optim.lr_scheduler.CosineAnnealingLR``.
        **kwargs: Forwarded unchanged to the selected scheduler factory.

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: If ``name`` is neither ``'warmup'`` nor
            ``'cosine'``; the exception carries ``name`` as its argument.
    """
    if name == 'warmup':
        return customized_lr_scheduler(optimizer, **kwargs)

    if name == 'cosine':
        # Imported only on this branch, as in the original.
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, **kwargs)

    raise NotImplementedError(name)
26
+
27
def get_models(args):
    """Build the model selected by ``args.model``.

    Args:
        args: Configuration object. This function reads ``args.model`` and,
            for UNet models, ``args.pretrained_model_path``.

    Returns:
        The result of ``UNet3DConditionModel.from_pretrained`` on the
        ``"unet"`` subfolder of ``args.pretrained_model_path`` when
        ``args.model`` contains ``'UNet'``.

    Raises:
        NotImplementedError: If ``args.model`` does not contain ``'UNet'``.
    """
    if 'UNet' in args.model:
        return UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet")

    # Python 3 refuses to raise a plain string ("exceptions must derive from
    # BaseException"), which would hide this message behind a TypeError.
    # Raise a real exception, mirroring get_lr_scheduler's unsupported case.
    raise NotImplementedError('{} Model Not Supported!'.format(args.model))
33
+
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.73 kB). View file
 
models/__pycache__/attention.cpython-312.pyc ADDED
Binary file (20.6 kB). View file
 
models/__pycache__/resnet.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
models/__pycache__/rotary_embedding_torch_mx.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
models/__pycache__/temporal_attention.cpython-312.pyc ADDED
Binary file (21.7 kB). View file