Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +9 -0
- .gitignore +1 -0
- LICENSE +201 -0
- README.md +133 -12
- __pycache__/utils.cpython-312.pyc +0 -0
- animated_images/aircraft.jpg +3 -0
- animated_images/car.jpg +3 -0
- animated_images/fireworks.jpg +0 -0
- animated_images/flowers.jpg +0 -0
- animated_images/forest.jpg +3 -0
- animated_images/shark_unwater.jpg +0 -0
- configs/sample.yaml +38 -0
- datasets/__pycache__/video_transforms.cpython-312.pyc +0 -0
- datasets/video_transforms.py +748 -0
- demo.py +311 -0
- environment.yml +21 -0
- example/aircrafts_flying/0.jpg +0 -0
- example/aircrafts_flying/aircrafts_flying.mp4 +0 -0
- example/car_moving/0.jpg +0 -0
- example/car_moving/car_moving.mp4 +0 -0
- example/fireworks/0.jpg +0 -0
- example/fireworks/fireworks.mp4 +0 -0
- example/flowers_swaying/0.jpg +0 -0
- example/flowers_swaying/flowers_swaying.mp4 +0 -0
- example/girl_walking_on_the_beach/0.jpg +0 -0
- example/girl_walking_on_the_beach/girl_walking_on_the_beach.mp4 +0 -0
- example/house_rotating/0.jpg +0 -0
- example/house_rotating/house_rotating.mp4 +0 -0
- example/people_runing/0.jpg +0 -0
- example/people_runing/people_runing.mp4 +0 -0
- example/shark_swimming/0.jpg +0 -0
- example/shark_swimming/shark_swimming.mp4 +0 -0
- example/windmill_turning/0.jpg +0 -0
- example/windmill_turning/windmill_turning.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/.nfs6a1237621cfe7a8800009149 +0 -0
- gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/.nfs88c2a0e49709591000009148 +0 -0
- gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4 +0 -0
- gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4 +0 -0
- gradio_cached_examples/39/indices.csv +6 -0
- gradio_cached_examples/39/log.csv +7 -0
- models/__init__.py +33 -0
- models/__pycache__/__init__.cpython-312.pyc +0 -0
- models/__pycache__/attention.cpython-312.pyc +0 -0
- models/__pycache__/resnet.cpython-312.pyc +0 -0
- models/__pycache__/rotary_embedding_torch_mx.cpython-312.pyc +0 -0
- models/__pycache__/temporal_attention.cpython-312.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
animated_images/aircraft.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
animated_images/car.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
animated_images/forest.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
visuals/animations/dragon_glowing_eyes/dragon_glowing_eyes.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
visuals/animations/girl_dancing_under_the_stars/girl_dancing_under_the_stars.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
visuals/animations/people_walking/people_walking.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
visuals/animations/sea_swell/sea_swell.gif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
visuals/video_editing/edit/editing_a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif filter=lfs diff=lfs merge=lfs -text
|
44 |
+
visuals/video_editing/origin/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.vscode
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [XIN MA] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,133 @@
|
|
1 |
-
---
|
2 |
-
title: Cinemo
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Cinemo
|
3 |
+
app_file: demo.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.37.2
|
6 |
+
---
|
7 |
+
## Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models<br><sub>Official PyTorch Implementation</sub>
|
8 |
+
|
9 |
+
|
10 |
+
[![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b.svg)](https://arxiv.org/abs/2407.15642)
|
11 |
+
[![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://maxin-cn.github.io/cinemo_project/)
|
12 |
+
|
13 |
+
|
14 |
+
This repo contains pre-trained weights, and sampling code for our paper exploring image animation with motion diffusion models (Cinemo). You can find more visualizations on our [project page](https://maxin-cn.github.io/cinemo_project/).
|
15 |
+
|
16 |
+
In this project, we propose a novel method called Cinemo, which can perform motion-controllable image animation with strong consistency and smoothness. To improve motion smoothness, Cinemo learns the distribution of motion residuals, rather than directly generating subsequent frames. Additionally, a structural similarity index-based method is proposed to control the motion intensity. Furthermore, we propose a noise refinement technique based on discrete cosine transformation to ensure temporal consistency. These three methods help Cinemo generate highly consistent, smooth, and motion-controlled image animation results. Compared to previous methods, Cinemo offers simpler and more precise user control and better generative performance.
|
17 |
+
|
18 |
+
<div align="center">
|
19 |
+
<img src="visuals/pipeline.svg">
|
20 |
+
</div>
|
21 |
+
|
22 |
+
## News
|
23 |
+
|
24 |
+
- (🔥 New) Jul. 23, 2024. 💥 Our paper is released on [arxiv](https://arxiv.org/abs/2407.15642).
|
25 |
+
|
26 |
+
- (🔥 New) Jun. 2, 2024. 💥 The inference code is released. The checkpoint can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main).
|
27 |
+
|
28 |
+
|
29 |
+
## Setup
|
30 |
+
|
31 |
+
First, download and set up the repo:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
git clone https://github.com/maxin-cn/Cinemo
|
35 |
+
cd Cinemo
|
36 |
+
```
|
37 |
+
|
38 |
+
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
|
39 |
+
to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
|
40 |
+
|
41 |
+
```bash
|
42 |
+
conda env create -f environment.yml
|
43 |
+
conda activate cinemo
|
44 |
+
```
|
45 |
+
|
46 |
+
|
47 |
+
## Animation
|
48 |
+
|
49 |
+
You can sample from our **pre-trained Cinemo models** with [`animation.py`](pipelines/animation.py). Weights for our pre-trained Cinemo model can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main). The script has various arguments for adjusting sampling steps, changing the classifier-free guidance scale, etc:
|
50 |
+
|
51 |
+
```bash
|
52 |
+
bash pipelines/animation.sh
|
53 |
+
```
|
54 |
+
|
55 |
+
All related checkpoints will download automatically and then you will get the following results,
|
56 |
+
|
57 |
+
<table style="width:100%; text-align:center;">
|
58 |
+
<tr>
|
59 |
+
<td align="center">Input image</td>
|
60 |
+
<td align="center">Output video</td>
|
61 |
+
<td align="center">Input image</td>
|
62 |
+
<td align="center">Output video</td>
|
63 |
+
</tr>
|
64 |
+
<tr>
|
65 |
+
<td align="center"><img src="visuals/animations/people_walking/0.jpg" width="100%"></td>
|
66 |
+
<td align="center"><img src="visuals/animations/people_walking/people_walking.gif" width="100%"></td>
|
67 |
+
<td align="center"><img src="visuals/animations/sea_swell/0.jpg" width="100%"></td>
|
68 |
+
<td align="center"><img src="visuals/animations/sea_swell/sea_swell.gif" width="100%"></td>
|
69 |
+
</tr>
|
70 |
+
<tr>
|
71 |
+
<td align="center" colspan="2">"People Walking"</td>
|
72 |
+
<td align="center" colspan="2">"Sea Swell"</td>
|
73 |
+
</tr>
|
74 |
+
<tr>
|
75 |
+
<td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/0.jpg" width="100%"></td>
|
76 |
+
<td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/girl_dancing_under_the_stars.gif" width="100%"></td>
|
77 |
+
<td align="center"><img src="visuals/animations/dragon_glowing_eyes/0.jpg" width="100%"></td>
|
78 |
+
<td align="center"><img src="visuals/animations/dragon_glowing_eyes/dragon_glowing_eyes.gif" width="100%"></td>
|
79 |
+
</tr>
|
80 |
+
<tr>
|
81 |
+
<td align="center" colspan="2">"Girl Dancing under the Stars"</td>
|
82 |
+
<td align="center" colspan="2">"Dragon Glowing Eyes"</td>
|
83 |
+
</tr>
|
84 |
+
|
85 |
+
</table>
|
86 |
+
|
87 |
+
|
88 |
+
## Other Applications
|
89 |
+
|
90 |
+
You can also utilize Cinemo for other applications, such as motion transfer and video editing:
|
91 |
+
|
92 |
+
```bash
|
93 |
+
bash pipelines/video_editing.sh
|
94 |
+
```
|
95 |
+
|
96 |
+
All related checkpoints will download automatically and you will get the following results,
|
97 |
+
|
98 |
+
<table style="width:100%; text-align:center;">
|
99 |
+
<tr>
|
100 |
+
<td align="center">Input video</td>
|
101 |
+
<td align="center">First frame</td>
|
102 |
+
<td align="center">Edited first frame</td>
|
103 |
+
<td align="center">Output video</td>
|
104 |
+
</tr>
|
105 |
+
<tr>
|
106 |
+
<td align="center"><img src="visuals/video_editing/origin/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
|
107 |
+
<td align="center"><img src="visuals/video_editing/origin/0.jpg" width="100%"></td>
|
108 |
+
<td align="center"><img src="visuals/video_editing/edit/0.jpg" width="100%"></td>
|
109 |
+
<td align="center"><img src="visuals/video_editing/edit/editing_a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
|
110 |
+
</tr>
|
111 |
+
|
112 |
+
</table>
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
## Citation
|
117 |
+
If you find this work useful for your research, please consider citing it.
|
118 |
+
```bibtex
|
119 |
+
@article{ma2024cinemo,
|
120 |
+
title={Cinemo: Latent Diffusion Transformer for Video Generation},
|
121 |
+
author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
|
122 |
+
journal={arXiv preprint arXiv:2407.15642},
|
123 |
+
year={2024}
|
124 |
+
}
|
125 |
+
```
|
126 |
+
|
127 |
+
|
128 |
+
## Acknowledgments
|
129 |
+
Cinemo has been greatly inspired by the following amazing works and teams: [LaVie](https://github.com/Vchitect/LaVie) and [SEINE](https://github.com/Vchitect/SEINE), we thank all the contributors for open-sourcing.
|
130 |
+
|
131 |
+
|
132 |
+
## License
|
133 |
+
The code and model weights are licensed under [LICENSE](LICENSE).
|
__pycache__/utils.cpython-312.pyc
ADDED
Binary file (6.62 kB). View file
|
|
animated_images/aircraft.jpg
ADDED
Git LFS Details
|
animated_images/car.jpg
ADDED
Git LFS Details
|
animated_images/fireworks.jpg
ADDED
animated_images/flowers.jpg
ADDED
animated_images/forest.jpg
ADDED
Git LFS Details
|
animated_images/shark_unwater.jpg
ADDED
configs/sample.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ckpt
|
2 |
+
ckpt: /mnt/lustre/maxin/work/animation/animation-v6/results_qt_003-UNet-pixabaylaion-Xfor-Gc-320_512_0303000.pt
|
3 |
+
save_img_path: "./sample_videos/"
|
4 |
+
|
5 |
+
# pretrained_model_path: "/mnt/hwfile/gcc/maxin/work/pretrained/stable-diffusion-v1-4/"
|
6 |
+
# pretrained_model_path: "maxin-cn/Cinemo"
|
7 |
+
pretrained_model_path: "./pretrained/Cinemo"
|
8 |
+
|
9 |
+
# model config:
|
10 |
+
model: UNet
|
11 |
+
video_length: 15
|
12 |
+
image_size: [320, 512]
|
13 |
+
# beta schedule
|
14 |
+
beta_start: 0.0001
|
15 |
+
beta_end: 0.02
|
16 |
+
beta_schedule: "linear"
|
17 |
+
|
18 |
+
# model speedup
|
19 |
+
use_compile: False
|
20 |
+
use_fp16: True
|
21 |
+
|
22 |
+
# sample config:
|
23 |
+
seed:
|
24 |
+
run_time: 0
|
25 |
+
use_dct: True
|
26 |
+
guidance_scale: 7.5 #
|
27 |
+
motion_bucket_id: 95 # [0-19] The larger the value, the stronger the motion intensity
|
28 |
+
sample_method: 'DDIM'
|
29 |
+
num_sampling_steps: 50
|
30 |
+
enable_vae_temporal_decoder: True
|
31 |
+
image_prompts: [
|
32 |
+
['aircraft.jpg', 'aircrafts flying'],
|
33 |
+
['car.jpg' ,"car moving"],
|
34 |
+
['fireworks.jpg', 'fireworks'],
|
35 |
+
['flowers.jpg', 'flowers swaying'],
|
36 |
+
['forest.jpg', 'people walking'],
|
37 |
+
['shark_unwater.jpg', 'shark falling into the sea'],
|
38 |
+
]
|
datasets/__pycache__/video_transforms.cpython-312.pyc
ADDED
Binary file (32.3 kB). View file
|
|
datasets/video_transforms.py
ADDED
@@ -0,0 +1,748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numbers
|
4 |
+
from torchvision.transforms import RandomCrop, RandomResizedCrop
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision.utils import _log_api_usage_once
|
7 |
+
|
8 |
+
def _is_tensor_video_clip(clip):
|
9 |
+
if not torch.is_tensor(clip):
|
10 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
11 |
+
|
12 |
+
if not clip.ndimension() == 4:
|
13 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
14 |
+
|
15 |
+
return True
|
16 |
+
|
17 |
+
|
18 |
+
def center_crop_arr(pil_image, image_size):
|
19 |
+
"""
|
20 |
+
Center cropping implementation from ADM.
|
21 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
22 |
+
"""
|
23 |
+
while min(*pil_image.size) >= 2 * image_size:
|
24 |
+
pil_image = pil_image.resize(
|
25 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
26 |
+
)
|
27 |
+
|
28 |
+
scale = image_size / min(*pil_image.size)
|
29 |
+
pil_image = pil_image.resize(
|
30 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
31 |
+
)
|
32 |
+
|
33 |
+
arr = np.array(pil_image)
|
34 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
35 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
36 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
37 |
+
|
38 |
+
|
39 |
+
def crop(clip, i, j, h, w):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
43 |
+
"""
|
44 |
+
if len(clip.size()) != 4:
|
45 |
+
raise ValueError("clip should be a 4D tensor")
|
46 |
+
return clip[..., i : i + h, j : j + w]
|
47 |
+
|
48 |
+
|
49 |
+
def resize(clip, target_size, interpolation_mode):
|
50 |
+
if len(target_size) != 2:
|
51 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
52 |
+
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
53 |
+
|
54 |
+
def resize_scale(clip, target_size, interpolation_mode):
|
55 |
+
if len(target_size) != 2:
|
56 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
57 |
+
H, W = clip.size(-2), clip.size(-1)
|
58 |
+
scale_ = target_size[0] / min(H, W)
|
59 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
60 |
+
|
61 |
+
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
|
62 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
|
63 |
+
|
64 |
+
def resize_scale_with_height(clip, target_size, interpolation_mode):
|
65 |
+
H, W = clip.size(-2), clip.size(-1)
|
66 |
+
scale_ = target_size / H
|
67 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
68 |
+
|
69 |
+
def resize_scale_with_weight(clip, target_size, interpolation_mode):
|
70 |
+
H, W = clip.size(-2), clip.size(-1)
|
71 |
+
scale_ = target_size / W
|
72 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
73 |
+
|
74 |
+
|
75 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
76 |
+
"""
|
77 |
+
Do spatial cropping and resizing to the video clip
|
78 |
+
Args:
|
79 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
80 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
81 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
82 |
+
h (int): Height of the cropped region.
|
83 |
+
w (int): Width of the cropped region.
|
84 |
+
size (tuple(int, int)): height and width of resized clip
|
85 |
+
Returns:
|
86 |
+
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
87 |
+
"""
|
88 |
+
if not _is_tensor_video_clip(clip):
|
89 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
90 |
+
clip = crop(clip, i, j, h, w)
|
91 |
+
clip = resize(clip, size, interpolation_mode)
|
92 |
+
return clip
|
93 |
+
|
94 |
+
|
95 |
+
def center_crop(clip, crop_size):
|
96 |
+
if not _is_tensor_video_clip(clip):
|
97 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
98 |
+
h, w = clip.size(-2), clip.size(-1)
|
99 |
+
# print(clip.shape)
|
100 |
+
th, tw = crop_size
|
101 |
+
if h < th or w < tw:
|
102 |
+
# print(h, w)
|
103 |
+
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
|
104 |
+
|
105 |
+
i = int(round((h - th) / 2.0))
|
106 |
+
j = int(round((w - tw) / 2.0))
|
107 |
+
return crop(clip, i, j, th, tw), i, j
|
108 |
+
|
109 |
+
|
110 |
+
def center_crop_using_short_edge(clip):
|
111 |
+
if not _is_tensor_video_clip(clip):
|
112 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
113 |
+
h, w = clip.size(-2), clip.size(-1)
|
114 |
+
if h < w:
|
115 |
+
th, tw = h, h
|
116 |
+
i = 0
|
117 |
+
j = int(round((w - tw) / 2.0))
|
118 |
+
else:
|
119 |
+
th, tw = w, w
|
120 |
+
i = int(round((h - th) / 2.0))
|
121 |
+
j = 0
|
122 |
+
return crop(clip, i, j, th, tw)
|
123 |
+
|
124 |
+
|
125 |
+
def random_shift_crop(clip):
|
126 |
+
'''
|
127 |
+
Slide along the long edge, with the short edge as crop size
|
128 |
+
'''
|
129 |
+
if not _is_tensor_video_clip(clip):
|
130 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
131 |
+
h, w = clip.size(-2), clip.size(-1)
|
132 |
+
|
133 |
+
if h <= w:
|
134 |
+
long_edge = w
|
135 |
+
short_edge = h
|
136 |
+
else:
|
137 |
+
long_edge = h
|
138 |
+
short_edge =w
|
139 |
+
|
140 |
+
th, tw = short_edge, short_edge
|
141 |
+
|
142 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
143 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
144 |
+
return crop(clip, i, j, th, tw), i, j
|
145 |
+
|
146 |
+
def random_crop(clip, crop_size):
|
147 |
+
if not _is_tensor_video_clip(clip):
|
148 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
149 |
+
h, w = clip.size(-2), clip.size(-1)
|
150 |
+
th, tw = crop_size[-2], crop_size[-1]
|
151 |
+
|
152 |
+
if h < th or w < tw:
|
153 |
+
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
|
154 |
+
|
155 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
156 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
157 |
+
clip_crop = crop(clip, i, j, th, tw)
|
158 |
+
return clip_crop, i, j
|
159 |
+
|
160 |
+
|
161 |
+
def to_tensor(clip):
|
162 |
+
"""
|
163 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
164 |
+
permute the dimensions of clip tensor
|
165 |
+
Args:
|
166 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
167 |
+
Return:
|
168 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
169 |
+
"""
|
170 |
+
_is_tensor_video_clip(clip)
|
171 |
+
if not clip.dtype == torch.uint8:
|
172 |
+
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
173 |
+
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
174 |
+
return clip.float() / 255.0
|
175 |
+
|
176 |
+
|
177 |
+
def normalize(clip, mean, std, inplace=False):
|
178 |
+
"""
|
179 |
+
Args:
|
180 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
181 |
+
mean (tuple): pixel RGB mean. Size is (3)
|
182 |
+
std (tuple): pixel standard deviation. Size is (3)
|
183 |
+
Returns:
|
184 |
+
normalized clip (torch.tensor): Size is (T, C, H, W)
|
185 |
+
"""
|
186 |
+
if not _is_tensor_video_clip(clip):
|
187 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
188 |
+
if not inplace:
|
189 |
+
clip = clip.clone()
|
190 |
+
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
191 |
+
# print(mean)
|
192 |
+
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
193 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
194 |
+
return clip
|
195 |
+
|
196 |
+
|
197 |
+
def hflip(clip):
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
201 |
+
Returns:
|
202 |
+
flipped clip (torch.tensor): Size is (T, C, H, W)
|
203 |
+
"""
|
204 |
+
if not _is_tensor_video_clip(clip):
|
205 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
206 |
+
return clip.flip(-1)
|
207 |
+
|
208 |
+
|
209 |
+
class RandomCropVideo:
|
210 |
+
def __init__(self, size):
|
211 |
+
if isinstance(size, numbers.Number):
|
212 |
+
self.size = (int(size), int(size))
|
213 |
+
else:
|
214 |
+
self.size = size
|
215 |
+
|
216 |
+
def __call__(self, clip):
|
217 |
+
"""
|
218 |
+
Args:
|
219 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
220 |
+
Returns:
|
221 |
+
torch.tensor: randomly cropped video clip.
|
222 |
+
size is (T, C, OH, OW)
|
223 |
+
"""
|
224 |
+
i, j, h, w = self.get_params(clip)
|
225 |
+
return crop(clip, i, j, h, w)
|
226 |
+
|
227 |
+
def get_params(self, clip):
|
228 |
+
h, w = clip.shape[-2:]
|
229 |
+
th, tw = self.size
|
230 |
+
|
231 |
+
if h < th or w < tw:
|
232 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
233 |
+
|
234 |
+
if w == tw and h == th:
|
235 |
+
return 0, 0, h, w
|
236 |
+
|
237 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
238 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
239 |
+
|
240 |
+
return i, j, th, tw
|
241 |
+
|
242 |
+
def __repr__(self) -> str:
|
243 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
244 |
+
|
245 |
+
class CenterCropResizeVideo:
|
246 |
+
'''
|
247 |
+
First use the short side for cropping length,
|
248 |
+
center crop video, then resize to the specified size
|
249 |
+
'''
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
size,
|
253 |
+
interpolation_mode="bilinear",
|
254 |
+
):
|
255 |
+
if isinstance(size, tuple):
|
256 |
+
if len(size) != 2:
|
257 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
258 |
+
self.size = size
|
259 |
+
else:
|
260 |
+
self.size = (size, size)
|
261 |
+
|
262 |
+
self.interpolation_mode = interpolation_mode
|
263 |
+
|
264 |
+
|
265 |
+
def __call__(self, clip):
|
266 |
+
"""
|
267 |
+
Args:
|
268 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
269 |
+
Returns:
|
270 |
+
torch.tensor: scale resized / center cropped video clip.
|
271 |
+
size is (T, C, crop_size, crop_size)
|
272 |
+
"""
|
273 |
+
# print(clip.shape)
|
274 |
+
clip_center_crop = center_crop_using_short_edge(clip)
|
275 |
+
# print(clip_center_crop.shape) 320 512
|
276 |
+
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
277 |
+
return clip_center_crop_resize
|
278 |
+
|
279 |
+
def __repr__(self) -> str:
|
280 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
281 |
+
|
282 |
+
|
283 |
+
class SDXL:
|
284 |
+
def __init__(
|
285 |
+
self,
|
286 |
+
size,
|
287 |
+
interpolation_mode="bilinear",
|
288 |
+
):
|
289 |
+
if isinstance(size, tuple):
|
290 |
+
if len(size) != 2:
|
291 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
292 |
+
self.size = size
|
293 |
+
else:
|
294 |
+
self.size = (size, size)
|
295 |
+
|
296 |
+
self.interpolation_mode = interpolation_mode
|
297 |
+
|
298 |
+
def __call__(self, clip):
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
302 |
+
Returns:
|
303 |
+
torch.tensor: scale resized / center cropped video clip.
|
304 |
+
size is (T, C, crop_size, crop_size)
|
305 |
+
"""
|
306 |
+
# add aditional one pixel for avoiding error in center crop
|
307 |
+
ori_h, ori_w = clip.size(-2), clip.size(-1)
|
308 |
+
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
|
309 |
+
|
310 |
+
# if ori_h >= tar_h and ori_w >= tar_w:
|
311 |
+
# clip_tar_crop, i, j = random_crop(clip=clip, crop_size=self.size)
|
312 |
+
# else:
|
313 |
+
# tar_h_div_ori_h = tar_h / ori_h
|
314 |
+
# tar_w_div_ori_w = tar_w / ori_w
|
315 |
+
# if tar_h_div_ori_h > tar_w_div_ori_w:
|
316 |
+
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
|
317 |
+
# else:
|
318 |
+
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
|
319 |
+
# clip_tar_crop, i, j = random_crop(clip, self.size)
|
320 |
+
if ori_h >= tar_h and ori_w >= tar_w:
|
321 |
+
tar_h_div_ori_h = tar_h / ori_h
|
322 |
+
tar_w_div_ori_w = tar_w / ori_w
|
323 |
+
if tar_h_div_ori_h > tar_w_div_ori_w:
|
324 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
|
325 |
+
else:
|
326 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
|
327 |
+
ori_h, ori_w = clip.size(-2), clip.size(-1)
|
328 |
+
clip_tar_crop, i, j = random_crop(clip, self.size)
|
329 |
+
else:
|
330 |
+
tar_h_div_ori_h = tar_h / ori_h
|
331 |
+
tar_w_div_ori_w = tar_w / ori_w
|
332 |
+
if tar_h_div_ori_h > tar_w_div_ori_w:
|
333 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
|
334 |
+
else:
|
335 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
|
336 |
+
clip_tar_crop, i, j = random_crop(clip, self.size)
|
337 |
+
return clip_tar_crop, ori_h, ori_w, i, j
|
338 |
+
|
339 |
+
def __repr__(self) -> str:
|
340 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
341 |
+
|
342 |
+
|
343 |
+
class SDXLCenterCrop:
|
344 |
+
def __init__(
|
345 |
+
self,
|
346 |
+
size,
|
347 |
+
interpolation_mode="bilinear",
|
348 |
+
):
|
349 |
+
if isinstance(size, tuple):
|
350 |
+
if len(size) != 2:
|
351 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
352 |
+
self.size = size
|
353 |
+
else:
|
354 |
+
self.size = (size, size)
|
355 |
+
|
356 |
+
self.interpolation_mode = interpolation_mode
|
357 |
+
|
358 |
+
|
359 |
+
def __call__(self, clip):
|
360 |
+
"""
|
361 |
+
Args:
|
362 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
363 |
+
Returns:
|
364 |
+
torch.tensor: scale resized / center cropped video clip.
|
365 |
+
size is (T, C, crop_size, crop_size)
|
366 |
+
"""
|
367 |
+
# add aditional one pixel for avoiding error in center crop
|
368 |
+
ori_h, ori_w = clip.size(-2), clip.size(-1)
|
369 |
+
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
|
370 |
+
tar_h_div_ori_h = tar_h / ori_h
|
371 |
+
tar_w_div_ori_w = tar_w / ori_w
|
372 |
+
# print('before resize', clip.shape)
|
373 |
+
if tar_h_div_ori_h > tar_w_div_ori_w:
|
374 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
|
375 |
+
# print('after h resize', clip.shape)
|
376 |
+
else:
|
377 |
+
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
|
378 |
+
# print('after resize', clip.shape)
|
379 |
+
# print(clip.shape)
|
380 |
+
# clip_tar_crop, i, j = random_crop(clip, self.size)
|
381 |
+
clip_tar_crop, i, j = center_crop(clip, self.size)
|
382 |
+
# print('after crop', clip_tar_crop.shape)
|
383 |
+
|
384 |
+
return clip_tar_crop, ori_h, ori_w, i, j
|
385 |
+
|
386 |
+
def __repr__(self) -> str:
|
387 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
388 |
+
|
389 |
+
|
390 |
+
class InternVideo320512:
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
size,
|
394 |
+
interpolation_mode="bilinear",
|
395 |
+
):
|
396 |
+
if isinstance(size, tuple):
|
397 |
+
if len(size) != 2:
|
398 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
399 |
+
self.size = size
|
400 |
+
else:
|
401 |
+
self.size = (size, size)
|
402 |
+
|
403 |
+
self.interpolation_mode = interpolation_mode
|
404 |
+
|
405 |
+
|
406 |
+
def __call__(self, clip):
|
407 |
+
"""
|
408 |
+
Args:
|
409 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
410 |
+
Returns:
|
411 |
+
torch.tensor: scale resized / center cropped video clip.
|
412 |
+
size is (T, C, crop_size, crop_size)
|
413 |
+
"""
|
414 |
+
# add aditional one pixel for avoiding error in center crop
|
415 |
+
h, w = clip.size(-2), clip.size(-1)
|
416 |
+
# print('before resize', clip.shape)
|
417 |
+
if h < 320:
|
418 |
+
clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
|
419 |
+
# print('after h resize', clip.shape)
|
420 |
+
if w < 512:
|
421 |
+
clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
|
422 |
+
# print('after w resize', clip.shape)
|
423 |
+
# print(clip.shape)
|
424 |
+
clip_center_crop = center_crop(clip, self.size)
|
425 |
+
clip_center_crop_no_subtitles = center_crop(clip, (220, 352))
|
426 |
+
clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
427 |
+
# print(clip_center_crop.shape)
|
428 |
+
return clip_center_resize
|
429 |
+
|
430 |
+
def __repr__(self) -> str:
|
431 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
432 |
+
|
433 |
+
class CenterCropVideo:
|
434 |
+
'''
|
435 |
+
First scale to the specified size in equal proportion to the short edge,
|
436 |
+
then center cropping
|
437 |
+
'''
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
size,
|
441 |
+
interpolation_mode="bilinear",
|
442 |
+
):
|
443 |
+
if isinstance(size, tuple):
|
444 |
+
if len(size) != 2:
|
445 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
446 |
+
self.size = size
|
447 |
+
else:
|
448 |
+
self.size = (size, size)
|
449 |
+
|
450 |
+
self.interpolation_mode = interpolation_mode
|
451 |
+
|
452 |
+
|
453 |
+
def __call__(self, clip):
|
454 |
+
"""
|
455 |
+
Args:
|
456 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
457 |
+
Returns:
|
458 |
+
torch.tensor: scale resized / center cropped video clip.
|
459 |
+
size is (T, C, crop_size, crop_size)
|
460 |
+
"""
|
461 |
+
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
462 |
+
clip_center_crop = center_crop(clip_resize, self.size)
|
463 |
+
return clip_center_crop
|
464 |
+
|
465 |
+
def __repr__(self) -> str:
|
466 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
467 |
+
|
468 |
+
class KineticsRandomCropResizeVideo:
|
469 |
+
'''
|
470 |
+
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
|
471 |
+
'''
|
472 |
+
def __init__(
|
473 |
+
self,
|
474 |
+
size,
|
475 |
+
interpolation_mode="bilinear",
|
476 |
+
):
|
477 |
+
if isinstance(size, tuple):
|
478 |
+
if len(size) != 2:
|
479 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
480 |
+
self.size = size
|
481 |
+
else:
|
482 |
+
self.size = (size, size)
|
483 |
+
|
484 |
+
self.interpolation_mode = interpolation_mode
|
485 |
+
|
486 |
+
def __call__(self, clip):
|
487 |
+
clip_random_crop = random_shift_crop(clip)
|
488 |
+
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
|
489 |
+
return clip_resize
|
490 |
+
|
491 |
+
class ResizeVideo():
|
492 |
+
'''
|
493 |
+
First use the short side for cropping length,
|
494 |
+
center crop video, then resize to the specified size
|
495 |
+
'''
|
496 |
+
def __init__(
|
497 |
+
self,
|
498 |
+
size,
|
499 |
+
interpolation_mode="bilinear",
|
500 |
+
):
|
501 |
+
if isinstance(size, tuple):
|
502 |
+
if len(size) != 2:
|
503 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
504 |
+
self.size = size
|
505 |
+
else:
|
506 |
+
self.size = (size, size)
|
507 |
+
|
508 |
+
self.interpolation_mode = interpolation_mode
|
509 |
+
|
510 |
+
|
511 |
+
def __call__(self, clip):
|
512 |
+
"""
|
513 |
+
Args:
|
514 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
515 |
+
Returns:
|
516 |
+
torch.tensor: scale resized / center cropped video clip.
|
517 |
+
size is (T, C, crop_size, crop_size)
|
518 |
+
"""
|
519 |
+
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
520 |
+
return clip_resize
|
521 |
+
|
522 |
+
def __repr__(self) -> str:
|
523 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
524 |
+
|
525 |
+
class CenterCropVideo:
|
526 |
+
def __init__(
|
527 |
+
self,
|
528 |
+
size,
|
529 |
+
interpolation_mode="bilinear",
|
530 |
+
):
|
531 |
+
if isinstance(size, tuple):
|
532 |
+
if len(size) != 2:
|
533 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
534 |
+
self.size = size
|
535 |
+
else:
|
536 |
+
self.size = (size, size)
|
537 |
+
|
538 |
+
self.interpolation_mode = interpolation_mode
|
539 |
+
|
540 |
+
|
541 |
+
def __call__(self, clip):
|
542 |
+
"""
|
543 |
+
Args:
|
544 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
545 |
+
Returns:
|
546 |
+
torch.tensor: center cropped video clip.
|
547 |
+
size is (T, C, crop_size, crop_size)
|
548 |
+
"""
|
549 |
+
clip_center_crop = center_crop(clip, self.size)
|
550 |
+
return clip_center_crop
|
551 |
+
|
552 |
+
def __repr__(self) -> str:
|
553 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
554 |
+
|
555 |
+
|
556 |
+
class NormalizeVideo:
|
557 |
+
"""
|
558 |
+
Normalize the video clip by mean subtraction and division by standard deviation
|
559 |
+
Args:
|
560 |
+
mean (3-tuple): pixel RGB mean
|
561 |
+
std (3-tuple): pixel RGB standard deviation
|
562 |
+
inplace (boolean): whether do in-place normalization
|
563 |
+
"""
|
564 |
+
|
565 |
+
def __init__(self, mean, std, inplace=False):
|
566 |
+
self.mean = mean
|
567 |
+
self.std = std
|
568 |
+
self.inplace = inplace
|
569 |
+
|
570 |
+
def __call__(self, clip):
|
571 |
+
"""
|
572 |
+
Args:
|
573 |
+
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
574 |
+
"""
|
575 |
+
return normalize(clip, self.mean, self.std, self.inplace)
|
576 |
+
|
577 |
+
def __repr__(self) -> str:
|
578 |
+
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
579 |
+
|
580 |
+
|
581 |
+
class ToTensorVideo:
|
582 |
+
"""
|
583 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
584 |
+
permute the dimensions of clip tensor
|
585 |
+
"""
|
586 |
+
|
587 |
+
def __init__(self):
|
588 |
+
pass
|
589 |
+
|
590 |
+
def __call__(self, clip):
|
591 |
+
"""
|
592 |
+
Args:
|
593 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
594 |
+
Return:
|
595 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
596 |
+
"""
|
597 |
+
return to_tensor(clip)
|
598 |
+
|
599 |
+
def __repr__(self) -> str:
|
600 |
+
return self.__class__.__name__
|
601 |
+
|
602 |
+
|
603 |
+
class RandomHorizontalFlipVideo:
|
604 |
+
"""
|
605 |
+
Flip the video clip along the horizontal direction with a given probability
|
606 |
+
Args:
|
607 |
+
p (float): probability of the clip being flipped. Default value is 0.5
|
608 |
+
"""
|
609 |
+
|
610 |
+
def __init__(self, p=0.5):
|
611 |
+
self.p = p
|
612 |
+
|
613 |
+
def __call__(self, clip):
|
614 |
+
"""
|
615 |
+
Args:
|
616 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
617 |
+
Return:
|
618 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
619 |
+
"""
|
620 |
+
if random.random() < self.p:
|
621 |
+
clip = hflip(clip)
|
622 |
+
return clip
|
623 |
+
|
624 |
+
def __repr__(self) -> str:
|
625 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
626 |
+
|
627 |
+
class Compose:
|
628 |
+
"""Composes several transforms together. This transform does not support torchscript.
|
629 |
+
Please, see the note below.
|
630 |
+
|
631 |
+
Args:
|
632 |
+
transforms (list of ``Transform`` objects): list of transforms to compose.
|
633 |
+
|
634 |
+
Example:
|
635 |
+
>>> transforms.Compose([
|
636 |
+
>>> transforms.CenterCrop(10),
|
637 |
+
>>> transforms.PILToTensor(),
|
638 |
+
>>> transforms.ConvertImageDtype(torch.float),
|
639 |
+
>>> ])
|
640 |
+
|
641 |
+
.. note::
|
642 |
+
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
|
643 |
+
|
644 |
+
>>> transforms = torch.nn.Sequential(
|
645 |
+
>>> transforms.CenterCrop(10),
|
646 |
+
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
647 |
+
>>> )
|
648 |
+
>>> scripted_transforms = torch.jit.script(transforms)
|
649 |
+
|
650 |
+
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
|
651 |
+
`lambda` functions or ``PIL.Image``.
|
652 |
+
|
653 |
+
"""
|
654 |
+
|
655 |
+
def __init__(self, transforms):
|
656 |
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
657 |
+
_log_api_usage_once(self)
|
658 |
+
self.transforms = transforms
|
659 |
+
|
660 |
+
def __call__(self, img):
|
661 |
+
for t in self.transforms:
|
662 |
+
if isinstance(t, SDXLCenterCrop) or isinstance(t, SDXL):
|
663 |
+
img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img)
|
664 |
+
else:
|
665 |
+
img = t(img)
|
666 |
+
return img, ori_h, ori_w, crops_coords_top, crops_coords_left
|
667 |
+
|
668 |
+
def __repr__(self) -> str:
|
669 |
+
format_string = self.__class__.__name__ + "("
|
670 |
+
for t in self.transforms:
|
671 |
+
format_string += "\n"
|
672 |
+
format_string += f" {t}"
|
673 |
+
format_string += "\n)"
|
674 |
+
return format_string
|
675 |
+
|
676 |
+
# ------------------------------------------------------------
|
677 |
+
# --------------------- Sampling ---------------------------
|
678 |
+
# ------------------------------------------------------------
|
679 |
+
class TemporalRandomCrop(object):
|
680 |
+
"""Temporally crop the given frame indices at a random location.
|
681 |
+
|
682 |
+
Args:
|
683 |
+
size (int): Desired length of frames will be seen in the model.
|
684 |
+
"""
|
685 |
+
|
686 |
+
def __init__(self, size):
|
687 |
+
self.size = size
|
688 |
+
|
689 |
+
def __call__(self, total_frames):
|
690 |
+
rand_end = max(0, total_frames - self.size - 1)
|
691 |
+
begin_index = random.randint(0, rand_end)
|
692 |
+
end_index = min(begin_index + self.size, total_frames)
|
693 |
+
return begin_index, end_index
|
694 |
+
|
695 |
+
|
696 |
+
if __name__ == '__main__':
|
697 |
+
from torchvision import transforms
|
698 |
+
import torchvision.io as io
|
699 |
+
import numpy as np
|
700 |
+
from torchvision.utils import save_image
|
701 |
+
import os
|
702 |
+
|
703 |
+
vframes, aframes, info = io.read_video(
|
704 |
+
filename='./v_Archery_g01_c03.avi',
|
705 |
+
pts_unit='sec',
|
706 |
+
output_format='TCHW'
|
707 |
+
)
|
708 |
+
|
709 |
+
trans = transforms.Compose([
|
710 |
+
ToTensorVideo(),
|
711 |
+
RandomHorizontalFlipVideo(),
|
712 |
+
UCFCenterCropVideo(512),
|
713 |
+
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
714 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
715 |
+
])
|
716 |
+
|
717 |
+
target_video_len = 32
|
718 |
+
frame_interval = 1
|
719 |
+
total_frames = len(vframes)
|
720 |
+
print(total_frames)
|
721 |
+
|
722 |
+
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
|
723 |
+
|
724 |
+
|
725 |
+
# Sampling video frames
|
726 |
+
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
|
727 |
+
# print(start_frame_ind)
|
728 |
+
# print(end_frame_ind)
|
729 |
+
assert end_frame_ind - start_frame_ind >= target_video_len
|
730 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
|
731 |
+
print(frame_indice)
|
732 |
+
|
733 |
+
select_vframes = vframes[frame_indice]
|
734 |
+
print(select_vframes.shape)
|
735 |
+
print(select_vframes.dtype)
|
736 |
+
|
737 |
+
select_vframes_trans = trans(select_vframes)
|
738 |
+
print(select_vframes_trans.shape)
|
739 |
+
print(select_vframes_trans.dtype)
|
740 |
+
|
741 |
+
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
|
742 |
+
print(select_vframes_trans_int.dtype)
|
743 |
+
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
|
744 |
+
|
745 |
+
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
|
746 |
+
|
747 |
+
for i in range(target_video_len):
|
748 |
+
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
|
demo.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import torchvision
|
6 |
+
|
7 |
+
|
8 |
+
from pipelines.pipeline_videogen import VideoGenPipeline
|
9 |
+
from diffusers.schedulers import DDIMScheduler
|
10 |
+
from diffusers.models import AutoencoderKL
|
11 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
12 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
|
15 |
+
import os, sys
|
16 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
17 |
+
from models import get_models
|
18 |
+
import imageio
|
19 |
+
from PIL import Image
|
20 |
+
import numpy as np
|
21 |
+
from datasets import video_transforms
|
22 |
+
from torchvision import transforms
|
23 |
+
from einops import rearrange, repeat
|
24 |
+
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
|
25 |
+
from copy import deepcopy
|
26 |
+
import spaces
|
27 |
+
import requests
|
28 |
+
from datetime import datetime
|
29 |
+
import random
|
30 |
+
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
|
33 |
+
args = parser.parse_args()
|
34 |
+
args = OmegaConf.load(args.config)
|
35 |
+
|
36 |
+
torch.set_grad_enabled(False)
|
37 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
38 |
+
dtype = torch.float16 # torch.float16
|
39 |
+
|
40 |
+
unet = get_models(args).to(device, dtype=dtype)
|
41 |
+
|
42 |
+
if args.enable_vae_temporal_decoder:
|
43 |
+
if args.use_dct:
|
44 |
+
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
|
45 |
+
else:
|
46 |
+
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
|
47 |
+
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
|
48 |
+
else:
|
49 |
+
vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
|
50 |
+
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
|
51 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
52 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
|
53 |
+
|
54 |
+
# set eval mode
|
55 |
+
unet.eval()
|
56 |
+
vae.eval()
|
57 |
+
text_encoder.eval()
|
58 |
+
|
59 |
+
basedir = os.getcwd()
|
60 |
+
savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
61 |
+
savedir_sample = os.path.join(savedir, "sample")
|
62 |
+
os.makedirs(savedir, exist_ok=True)
|
63 |
+
|
64 |
+
def update_and_resize_image(input_image_path, height_slider, width_slider):
|
65 |
+
if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
|
66 |
+
pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
|
67 |
+
else:
|
68 |
+
pil_image = Image.open(input_image_path).convert('RGB')
|
69 |
+
|
70 |
+
original_width, original_height = pil_image.size
|
71 |
+
|
72 |
+
if original_height == height_slider and original_width == width_slider:
|
73 |
+
return gr.Image(value=np.array(pil_image))
|
74 |
+
|
75 |
+
ratio1 = height_slider / original_height
|
76 |
+
ratio2 = width_slider / original_width
|
77 |
+
|
78 |
+
if ratio1 > ratio2:
|
79 |
+
new_width = int(original_width * ratio1)
|
80 |
+
new_height = int(original_height * ratio1)
|
81 |
+
else:
|
82 |
+
new_width = int(original_width * ratio2)
|
83 |
+
new_height = int(original_height * ratio2)
|
84 |
+
|
85 |
+
pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
|
86 |
+
|
87 |
+
left = (new_width - width_slider) / 2
|
88 |
+
top = (new_height - height_slider) / 2
|
89 |
+
right = left + width_slider
|
90 |
+
bottom = top + height_slider
|
91 |
+
|
92 |
+
pil_image = pil_image.crop((left, top, right, bottom))
|
93 |
+
|
94 |
+
return gr.Image(value=np.array(pil_image))
|
95 |
+
|
96 |
+
|
97 |
+
def update_textbox_and_save_image(input_image, height_slider, width_slider):
|
98 |
+
pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
|
99 |
+
|
100 |
+
original_width, original_height = pil_image.size
|
101 |
+
|
102 |
+
ratio1 = height_slider / original_height
|
103 |
+
ratio2 = width_slider / original_width
|
104 |
+
|
105 |
+
if ratio1 > ratio2:
|
106 |
+
new_width = int(original_width * ratio1)
|
107 |
+
new_height = int(original_height * ratio1)
|
108 |
+
else:
|
109 |
+
new_width = int(original_width * ratio2)
|
110 |
+
new_height = int(original_height * ratio2)
|
111 |
+
|
112 |
+
pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
|
113 |
+
|
114 |
+
left = (new_width - width_slider) / 2
|
115 |
+
top = (new_height - height_slider) / 2
|
116 |
+
right = left + width_slider
|
117 |
+
bottom = top + height_slider
|
118 |
+
|
119 |
+
pil_image = pil_image.crop((left, top, right, bottom))
|
120 |
+
|
121 |
+
img_path = os.path.join(savedir, "input_image.png")
|
122 |
+
pil_image.save(img_path)
|
123 |
+
|
124 |
+
return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
|
125 |
+
|
126 |
+
def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
|
127 |
+
image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
|
128 |
+
image = transform_video(image)
|
129 |
+
image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
|
130 |
+
image = image.unsqueeze(2)
|
131 |
+
return image
|
132 |
+
|
133 |
+
|
134 |
+
@spaces.GPU
|
135 |
+
def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
|
136 |
+
|
137 |
+
torch.manual_seed(seed)
|
138 |
+
|
139 |
+
scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
|
140 |
+
subfolder="scheduler",
|
141 |
+
beta_start=args.beta_start,
|
142 |
+
beta_end=args.beta_end,
|
143 |
+
beta_schedule=args.beta_schedule)
|
144 |
+
|
145 |
+
videogen_pipeline = VideoGenPipeline(vae=vae,
|
146 |
+
text_encoder=text_encoder,
|
147 |
+
tokenizer=tokenizer,
|
148 |
+
scheduler=scheduler,
|
149 |
+
unet=unet).to(device)
|
150 |
+
# videogen_pipeline.enable_xformers_memory_efficient_attention()
|
151 |
+
|
152 |
+
transform_video = transforms.Compose([
|
153 |
+
video_transforms.ToTensorVideo(),
|
154 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
155 |
+
])
|
156 |
+
|
157 |
+
if args.use_dct:
|
158 |
+
base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
|
159 |
+
else:
|
160 |
+
base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
|
161 |
+
|
162 |
+
if use_dctinit:
|
163 |
+
# filter params
|
164 |
+
print("Using DCT!")
|
165 |
+
base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
|
166 |
+
|
167 |
+
# define filter
|
168 |
+
freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
|
169 |
+
|
170 |
+
noise = torch.randn(1, 4, 15, 40, 64).to(device)
|
171 |
+
|
172 |
+
# add noise to base_content
|
173 |
+
diffuse_timesteps = torch.full((1,),int(noise_level))
|
174 |
+
diffuse_timesteps = diffuse_timesteps.long()
|
175 |
+
|
176 |
+
# 3d content
|
177 |
+
base_content_noise = scheduler.add_noise(
|
178 |
+
original_samples=base_content_repeat.to(device),
|
179 |
+
noise=noise,
|
180 |
+
timesteps=diffuse_timesteps.to(device))
|
181 |
+
|
182 |
+
# 3d content
|
183 |
+
latents = exchanged_mixed_dct_freq(noise=noise,
|
184 |
+
base_content=base_content_noise,
|
185 |
+
LPF_3d=freq_filter).to(dtype=torch.float16)
|
186 |
+
|
187 |
+
base_content = base_content.to(dtype=torch.float16)
|
188 |
+
|
189 |
+
videos = videogen_pipeline(prompt,
|
190 |
+
negative_prompt=negative_prompt,
|
191 |
+
latents=latents if use_dctinit else None,
|
192 |
+
base_content=base_content,
|
193 |
+
video_length=15,
|
194 |
+
height=height,
|
195 |
+
width=width,
|
196 |
+
num_inference_steps=diffusion_step,
|
197 |
+
guidance_scale=scfg_scale,
|
198 |
+
motion_bucket_id=100-motion_bucket_id,
|
199 |
+
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
|
200 |
+
|
201 |
+
save_path = args.save_img_path + 'temp' + '.mp4'
|
202 |
+
# torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
|
203 |
+
imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
|
204 |
+
return save_path
|
205 |
+
|
206 |
+
|
207 |
+
if not os.path.exists(args.save_img_path):
|
208 |
+
os.makedirs(args.save_img_path)
|
209 |
+
|
210 |
+
|
211 |
+
with gr.Blocks() as demo:
|
212 |
+
|
213 |
+
gr.Markdown("<font color=red size=6.5><center>Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models</center></font>")
|
214 |
+
gr.Markdown(
|
215 |
+
"""<div style="display: flex;align-items: center;justify-content: center">
|
216 |
+
[<a href="https://arxiv.org/abs/2407.15642">Arxiv Report</a>] | [<a href="https://https://maxin-cn.github.io/cinemo_project/">Project Page</a>] | [<a href="https://github.com/maxin-cn/Cinemo">Github</a>]</div>
|
217 |
+
"""
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
with gr.Column(variant="panel"):
|
222 |
+
gr.Markdown(
|
223 |
+
"""
|
224 |
+
- Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
|
225 |
+
- Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
|
226 |
+
- After setting the input image path, press the "Preview" button to visualize the resized input image.
|
227 |
+
"""
|
228 |
+
)
|
229 |
+
|
230 |
+
with gr.Row():
|
231 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=1)
|
232 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
|
233 |
+
|
234 |
+
with gr.Row(equal_height=False):
|
235 |
+
with gr.Column():
|
236 |
+
with gr.Row():
|
237 |
+
sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
|
241 |
+
# seed_textbox = gr.Textbox(label="Seed", value=100)
|
242 |
+
# seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
243 |
+
# seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
|
244 |
+
|
245 |
+
with gr.Row():
|
246 |
+
height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
|
247 |
+
width = gr.Slider(label="Width", value=512, minimum=0, maximum=512, step=16, interactive=False)
|
248 |
+
with gr.Row():
|
249 |
+
txt_cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1, interactive=True)
|
250 |
+
motion_bucket_id = gr.Slider(label="Motion Intensity", value=10, minimum=1, maximum=20, step=1, interactive=True)
|
251 |
+
|
252 |
+
with gr.Row():
|
253 |
+
use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
|
254 |
+
dct_coefficients = gr.Slider(label="DCT Coefficients", value=0.23, minimum=0, maximum=1, step=0.01, interactive=True)
|
255 |
+
noise_level = gr.Slider(label="Noise Level", value=985, minimum=1, maximum=999, step=1, interactive=True)
|
256 |
+
|
257 |
+
generate_button = gr.Button(value="Generate", variant='primary')
|
258 |
+
|
259 |
+
with gr.Column():
|
260 |
+
with gr.Row():
|
261 |
+
input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
|
262 |
+
preview_button = gr.Button(value="Preview")
|
263 |
+
|
264 |
+
with gr.Row():
|
265 |
+
input_image = gr.Image(label="Input Image", interactive=True)
|
266 |
+
input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height, width], outputs=[input_image_path, input_image])
|
267 |
+
result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
|
268 |
+
|
269 |
+
preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
|
270 |
+
input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
|
271 |
+
|
272 |
+
EXAMPLES = [
|
273 |
+
["./example/aircrafts_flying/0.jpg", "aircrafts flying", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
|
274 |
+
["./example/fireworks/0.jpg", "fireworks", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
|
275 |
+
["./example/flowers_swaying/0.jpg", "flowers swaying", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
|
276 |
+
["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach", "", 50, 320, 512, 7.5, True, 0.23, 985, 10, 200],
|
277 |
+
["./example/house_rotating/0.jpg", "house rotating", "", 50, 320, 512, 7.5, True, 0.23, 985, 10, 100],
|
278 |
+
["./example/people_runing/0.jpg", "people runing", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
|
279 |
+
]
|
280 |
+
|
281 |
+
examples = gr.Examples(
|
282 |
+
examples = EXAMPLES,
|
283 |
+
fn = gen_video,
|
284 |
+
inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
|
285 |
+
outputs=[result_video],
|
286 |
+
# cache_examples=True,
|
287 |
+
cache_examples="lazy",
|
288 |
+
)
|
289 |
+
|
290 |
+
generate_button.click(
|
291 |
+
fn=gen_video,
|
292 |
+
inputs=[
|
293 |
+
input_image,
|
294 |
+
prompt_textbox,
|
295 |
+
negative_prompt_textbox,
|
296 |
+
sample_step_slider,
|
297 |
+
height,
|
298 |
+
width,
|
299 |
+
txt_cfg_scale,
|
300 |
+
use_dctinit,
|
301 |
+
dct_coefficients,
|
302 |
+
noise_level,
|
303 |
+
motion_bucket_id,
|
304 |
+
seed_textbox,
|
305 |
+
],
|
306 |
+
outputs=[result_video]
|
307 |
+
)
|
308 |
+
|
309 |
+
demo.launch(debug=False, share=True)
|
310 |
+
|
311 |
+
# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
|
environment.yml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: cinemo
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
dependencies:
|
6 |
+
- python >= 3.10
|
7 |
+
- pytorch >= 2.0
|
8 |
+
- torchvision
|
9 |
+
- pytorch-cuda >= 11.7
|
10 |
+
- pip:
|
11 |
+
- timm
|
12 |
+
- diffusers[torch]==0.24.0
|
13 |
+
- accelerate
|
14 |
+
- python-hostlist
|
15 |
+
- tensorboard
|
16 |
+
- einops
|
17 |
+
- transformers
|
18 |
+
- av
|
19 |
+
- scikit-image
|
20 |
+
- decord
|
21 |
+
- pandas
|
example/aircrafts_flying/0.jpg
ADDED
example/aircrafts_flying/aircrafts_flying.mp4
ADDED
Binary file (533 kB). View file
|
|
example/car_moving/0.jpg
ADDED
example/car_moving/car_moving.mp4
ADDED
Binary file (399 kB). View file
|
|
example/fireworks/0.jpg
ADDED
example/fireworks/fireworks.mp4
ADDED
Binary file (479 kB). View file
|
|
example/flowers_swaying/0.jpg
ADDED
example/flowers_swaying/flowers_swaying.mp4
ADDED
Binary file (469 kB). View file
|
|
example/girl_walking_on_the_beach/0.jpg
ADDED
example/girl_walking_on_the_beach/girl_walking_on_the_beach.mp4
ADDED
Binary file (619 kB). View file
|
|
example/house_rotating/0.jpg
ADDED
example/house_rotating/house_rotating.mp4
ADDED
Binary file (481 kB). View file
|
|
example/people_runing/0.jpg
ADDED
example/people_runing/people_runing.mp4
ADDED
Binary file (482 kB). View file
|
|
example/shark_swimming/0.jpg
ADDED
example/shark_swimming/shark_swimming.mp4
ADDED
Binary file (282 kB). View file
|
|
example/windmill_turning/0.jpg
ADDED
example/windmill_turning/windmill_turning.mp4
ADDED
Binary file (403 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4
ADDED
Binary file (226 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4
ADDED
Binary file (223 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/.nfs6a1237621cfe7a8800009149
ADDED
Binary file (334 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4
ADDED
Binary file (619 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4
ADDED
Binary file (272 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/.nfs88c2a0e49709591000009148
ADDED
Binary file (352 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4
ADDED
Binary file (481 kB). View file
|
|
gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4
ADDED
Binary file (206 kB). View file
|
|
gradio_cached_examples/39/indices.csv
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
0
|
2 |
+
1
|
3 |
+
2
|
4 |
+
3
|
5 |
+
4
|
6 |
+
5
|
gradio_cached_examples/39/log.csv
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Generated Animation,flag,username,timestamp
|
2 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/de8039b347e55995b4cb/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/24d44f90b57c00bd6fdccd61cc35a4f4d459388c/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:42:07.603268
|
3 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/b54545fbdd15c944208e/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/deb9620f616c3681cb074388781099f78a25dc8f/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:42:25.127236
|
4 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/5e69f32e801f7ae77024/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/d9e392300169a439b3f5721579849e3e5ce6abf9/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:44:01.949003
|
5 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/b12875c4b9b633b752c4/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/d1a5066828be61057823cf98edae890db907f358/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:00.196580
|
6 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/cf8ea2ef6e0b7eeb7fe6/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/5998620333bb5bdaf52b49f2de86e428df991431/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:28.848377
|
7 |
+
"{""video"": {""path"": ""gradio_cached_examples/39/Generated Animation/98ce26b896864325a1dd/temp.mp4"", ""url"": ""/file=/data/pe1/000scratch/slurm_tmpdir/20240727_job_53250001.VBWa/gradio/7e3c6838256fd5a89b20ad62ce42288212e62097/temp.mp4"", ""size"": null, ""orig_name"": ""temp.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-27 14:45:48.004623
|
models/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.split(sys.path[0])[0])
|
4 |
+
|
5 |
+
from .unet import UNet3DConditionModel
|
6 |
+
from torch.optim.lr_scheduler import LambdaLR
|
7 |
+
|
8 |
+
def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
def fn(step):
|
11 |
+
if warmup_steps > 0:
|
12 |
+
return min(step / warmup_steps, 1)
|
13 |
+
else:
|
14 |
+
return 1
|
15 |
+
return LambdaLR(optimizer, fn)
|
16 |
+
|
17 |
+
|
18 |
+
def get_lr_scheduler(optimizer, name, **kwargs):
|
19 |
+
if name == 'warmup':
|
20 |
+
return customized_lr_scheduler(optimizer, **kwargs)
|
21 |
+
elif name == 'cosine':
|
22 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
23 |
+
return CosineAnnealingLR(optimizer, **kwargs)
|
24 |
+
else:
|
25 |
+
raise NotImplementedError(name)
|
26 |
+
|
27 |
+
def get_models(args):
|
28 |
+
|
29 |
+
if 'UNet' in args.model:
|
30 |
+
return UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet")
|
31 |
+
else:
|
32 |
+
raise '{} Model Not Supported!'.format(args.model)
|
33 |
+
|
models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.73 kB). View file
|
|
models/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (20.6 kB). View file
|
|
models/__pycache__/resnet.cpython-312.pyc
ADDED
Binary file (14.4 kB). View file
|
|
models/__pycache__/rotary_embedding_torch_mx.cpython-312.pyc
ADDED
Binary file (10.6 kB). View file
|
|
models/__pycache__/temporal_attention.cpython-312.pyc
ADDED
Binary file (21.7 kB). View file
|
|