bubbliiiing committed
Commit • e262715 • 1 Parent(s): 788d423

update v3
Files changed:
- .gitignore +160 -0
- app.py +3 -3
- easyanimate/api/api.py +38 -4
- easyanimate/api/post_infer.py +9 -7
- easyanimate/data/dataset_image_video.py +64 -3
- easyanimate/models/attention.py +196 -139
- easyanimate/models/autoencoder_magvit.py +9 -3
- easyanimate/models/motion_module.py +146 -277
- easyanimate/models/norm.py +97 -0
- easyanimate/models/patch.py +1 -1
- easyanimate/models/transformer3d.py +81 -75
- easyanimate/pipeline/pipeline_easyanimate.py +1 -1
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +257 -91
- easyanimate/ui/ui.py +810 -173
- easyanimate/utils/utils.py +107 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+#   .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
app.py
CHANGED
@@ -11,9 +11,9 @@ if __name__ == "__main__":
    server_port = 7860

    # Params below is used when ui_mode = "modelscope"
-   edition = "
-   config_path = "config/
-   model_name = "models/Diffusion_Transformer/
+   edition = "v3"
+   config_path = "config/easyanimate_video_slicevae_motion_module_v3.yaml"
+   model_name = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512"
    savedir_sample = "samples"

    if ui_mode == "modelscope":
easyanimate/api/api.py
CHANGED
@@ -1,10 +1,14 @@
 import io
+import gc
 import base64
 import torch
 import gradio as gr
+import tempfile
+import hashlib

 from fastapi import FastAPI
 from io import BytesIO
+from PIL import Image

 # Function to encode a file to Base64
 def encode_file_to_base64(file_path):
@@ -59,16 +63,34 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
        lora_model_path = datas.get('lora_model_path', 'none')
        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
        prompt_textbox = datas.get('prompt_textbox', None)
-       negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
+       negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.')
        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
        sample_step_slider = datas.get('sample_step_slider', 30)
+       resize_method = datas.get('resize_method', "Generate by")
        width_slider = datas.get('width_slider', 672)
        height_slider = datas.get('height_slider', 384)
+       base_resolution = datas.get('base_resolution', 512)
        is_image = datas.get('is_image', False)
+       generation_method = datas.get('generation_method', False)
        length_slider = datas.get('length_slider', 144)
+       overlap_video_length = datas.get('overlap_video_length', 4)
+       partial_video_length = datas.get('partial_video_length', 72)
        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+       start_image = datas.get('start_image', None)
+       end_image = datas.get('end_image', None)
        seed_textbox = datas.get("seed_textbox", 43)

+       generation_method = "Image Generation" if is_image else generation_method
+
+       temp_directory = tempfile.gettempdir()
+       if start_image is not None:
+           start_image = base64.b64decode(start_image)
+           start_image = [Image.open(BytesIO(start_image))]
+
+       if end_image is not None:
+           end_image = base64.b64decode(end_image)
+           end_image = [Image.open(BytesIO(end_image))]
+
        try:
            save_sample_path, comment = controller.generate(
                "",
@@ -80,17 +102,29 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
+               resize_method,
                width_slider,
                height_slider,
-
+               base_resolution,
+               generation_method,
                length_slider,
+               overlap_video_length,
+               partial_video_length,
                cfg_scale_slider,
+               start_image,
+               end_image,
                seed_textbox,
                is_api = True,
            )
        except Exception as e:
+           gc.collect()
            torch.cuda.empty_cache()
+           torch.cuda.ipc_collect()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"
-
-
+           return {"message": comment}
+
+       if save_sample_path != "":
+           return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+       else:
+           return {"message": comment, "save_sample_path": save_sample_path}
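A minimal client-side sketch of the new image-to-video fields read in the hunk above. The field names mirror the datas.get(...) calls in this commit; the route path "/easyanimate/infer_forward" and the request wrapper are assumptions not shown on this page (post_infer.py below wraps the same call).

import base64
import json
import requests

# Hypothetical request; only the field names are taken from the diff above.
with open("start.png", "rb") as f:
    start_image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt_textbox": "A dog runs on the grass.",
    "generation_method": "Video Generation",
    "length_slider": 72,
    "overlap_video_length": 4,
    "partial_video_length": 72,
    "start_image": start_image_b64,   # decoded with base64 + PIL on the server side
    "seed_textbox": 43,
}
# Assumed endpoint path; adjust to however the FastAPI route is registered.
response = requests.post("http://127.0.0.1:7860/easyanimate/infer_forward", data=json.dumps(payload))
outputs = json.loads(response.content.decode("utf-8"))
print(outputs.get("message"))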
easyanimate/api/post_infer.py
CHANGED
@@ -26,7 +26,7 @@ def post_update_edition(edition, url='http://0.0.0.0:7860'):
    data = r.content.decode('utf-8')
    return data

-def post_infer(
+def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
    datas = json.dumps({
        "base_model_path": "none",
        "motion_module_path": "none",
@@ -38,7 +38,7 @@ def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
        "sample_step_slider": 30,
        "width_slider": 672,
        "height_slider": 384,
-       "
+       "generation_method": "Video Generation",
        "length_slider": length_slider,
        "cfg_scale_slider": 6,
        "seed_textbox": 43,
@@ -55,29 +55,31 @@ if __name__ == '__main__':
    # -------------------------- #
    #  Step 1: update edition
    # -------------------------- #
-   edition = "
+   edition = "v3"
    outputs = post_update_edition(edition)
    print('Output update edition: ', outputs)

    # -------------------------- #
    #  Step 2: update edition
    # -------------------------- #
-   diffusion_transformer_path = "
+   diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-512x512"
    outputs = post_diffusion_transformer(diffusion_transformer_path)
    print('Output update edition: ', outputs)

    # -------------------------- #
    #  Step 3: infer
    # -------------------------- #
-
-
-
+   # "Video Generation" and "Image Generation"
+   generation_method = "Video Generation"
+   length_slider = 72
+   outputs = post_infer(generation_method, length_slider)

    # Get decoded data
    outputs = json.loads(outputs)
    base64_encoding = outputs["base64_encoding"]
    decoded_data = base64.b64decode(base64_encoding)

+   is_image = True if generation_method == "Image Generation" else False
    if is_image or length_slider == 1:
        file_path = "1.png"
    else:
easyanimate/data/dataset_image_video.py
CHANGED
@@ -12,6 +12,7 @@ import gc
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+
 from func_timeout import func_timeout, FunctionTimedOut
 from decord import VideoReader
 from PIL import Image
@@ -21,6 +22,52 @@ from contextlib import contextmanager

 VIDEO_READER_TIMEOUT = 20

+def get_random_mask(shape):
+    f, c, h, w = shape
+
+    if f != 1:
+        mask_index = np.random.randint(1, 4)
+    else:
+        mask_index = np.random.randint(1, 2)
+    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+
+    if mask_index == 0:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # block width range
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # block height range
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+        mask[:, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 1:
+        mask[:, :, :, :] = 1
+    elif mask_index == 2:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:, :, :, :] = 1
+    elif mask_index == 3:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+    elif mask_index == 4:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # block width range
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # block height range
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+
+        mask_frame_before = np.random.randint(0, f // 2)
+        mask_frame_after = np.random.randint(f // 2, f)
+        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    else:
+        raise ValueError(f"The mask_index {mask_index} is not define")
+    return mask
+
 class ImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

@@ -88,10 +135,11 @@ class ImageVideoDataset(Dataset):
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
-       text_drop_ratio
+       text_drop_ratio=-1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
+       enable_inpaint=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
@@ -120,6 +168,8 @@ class ImageVideoDataset(Dataset):
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
+       self.enable_inpaint = enable_inpaint
+
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

@@ -165,7 +215,7 @@ class ImageVideoDataset(Dataset):

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
-               start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
+               start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
@@ -230,6 +280,17 @@ class ImageVideoDataset(Dataset):
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)
+
+       if self.enable_inpaint and not self.enable_bucket:
+           mask = get_random_mask(pixel_values.size())
+           mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+           sample["mask_pixel_values"] = mask_pixel_values
+           sample["mask"] = mask
+
+           clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+           clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+           sample["clip_pixel_values"] = clip_pixel_values
+
        return sample

 if __name__ == "__main__":
@@ -238,4 +299,4 @@ if __name__ == "__main__":
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
    for idx, batch in enumerate(dataloader):
-       print(batch["pixel_values"].shape, len(batch["text"]))
+       print(batch["pixel_values"].shape, len(batch["text"]))
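A small sketch of what the new inpaint branch above produces: get_random_mask returns a per-frame binary mask and the masked pixels are pushed to -1, exactly as in the __getitem__ hunk. The tensor shape is illustrative.

import torch
from easyanimate.data.dataset_image_video import get_random_mask

# 16 frames, 3 channels, 256x256, values already normalized to [-1, 1]
pixel_values = torch.rand(16, 3, 256, 256) * 2 - 1
mask = get_random_mask(pixel_values.size())            # (16, 1, 256, 256), uint8 in {0, 1}
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
print(mask.float().mean(), mask_pixel_values.shape)    # masked region is filled with -1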
easyanimate/models/attention.py
CHANGED
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Any, Dict, Optional

+import diffusers
+import pkg_resources
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
-
+
+installed_version = diffusers.__version__
+
+if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+    from diffusers.models.attention_processor import (Attention,
+                                                       AttnProcessor2_0,
+                                                       HunyuanAttnProcessor2_0)
+else:
+    from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+
 from diffusers.models.attention import AdaLayerNorm, FeedForward
-from diffusers.models.attention_processor import Attention
 from diffusers.models.embeddings import SinusoidalPositionalEmbedding
-from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
 from diffusers.utils import USE_PEFT_BACKEND
 from diffusers.utils.import_utils import is_xformers_available
@@ -29,7 +37,8 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange, repeat
 from torch import nn

-from .motion_module import get_motion_module
+from .motion_module import PositionalEncoding, get_motion_module
+from .norm import FP32LayerNorm

 if is_xformers_available():
     import xformers
@@ -38,6 +47,13 @@ else:
     xformers = None


+def zero_module(module):
+    # Zero out the parameters of a module and return it.
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
     r"""
@@ -59,8 +75,8 @@ class GatedSelfAttentionDense(nn.Module):
        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

-       self.norm1 =
-       self.norm2 =
+       self.norm1 = FP32LayerNorm(query_dim)
+       self.norm2 = FP32LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
@@ -80,14 +96,6 @@ class GatedSelfAttentionDense(nn.Module):
        return x


-def zero_module(module):
-    # Zero out the parameters of a module and return it.
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-
 class KVCompressionCrossAttention(nn.Module):
     r"""
     A cross attention layer.
@@ -154,7 +162,7 @@ class KVCompressionCrossAttention(nn.Module):
                stride=2,
                bias=True
            )
-           self.kv_compression_norm =
+           self.kv_compression_norm = FP32LayerNorm(query_dim)
            init.constant_(self.kv_compression.weight, 1 / 4)
            if self.kv_compression.bias is not None:
                init.constant_(self.kv_compression.bias, 0)
@@ -410,6 +418,8 @@ class TemporalTransformerBlock(nn.Module):
        # motion module kwargs
        motion_module_type = "VanillaGrid",
        motion_module_kwargs = None,
+       qk_norm = False,
+       after_norm = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
@@ -442,7 +452,7 @@ class TemporalTransformerBlock(nn.Module):
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
-           self.norm1 =
+           self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.kvcompression = kvcompression
        if kvcompression:
@@ -456,16 +466,28 @@ class TemporalTransformerBlock(nn.Module):
                upcast_attention=upcast_attention,
            )
        else:
-
-
-
-
-
-
-
-
-
-
+           if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+               self.attn1 = Attention(
+                   query_dim=dim,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                   upcast_attention=upcast_attention,
+                   qk_norm="layer_norm" if qk_norm else None,
+                   processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+               )
+           else:
+               self.attn1 = Attention(
+                   query_dim=dim,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                   upcast_attention=upcast_attention,
+               )

        self.attn_temporal = get_motion_module(
            in_channels = dim,
@@ -481,27 +503,45 @@ class TemporalTransformerBlock(nn.Module):
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
-               else
+               else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
-
-
-
-
-
-
-
-
-
+           if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+                   qk_norm="layer_norm" if qk_norm else None,
+                   processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+               )  # is self-attn if encoder_hidden_states is none
+           else:
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+               )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if not self.use_ada_layer_norm_single:
-           self.norm3 =
+           self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

+       if after_norm:
+           self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+       else:
+           self.norm4 = None
+
        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -654,6 +694,9 @@ class TemporalTransformerBlock(nn.Module):
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+       if self.norm4 is not None:
+           ff_output = self.norm4(ff_output)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -723,6 +766,8 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
+       qk_norm = False,
+       after_norm = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
@@ -755,17 +800,30 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
-           self.norm1 =
+           self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-
-
-
-
-
-
-
-
-
+       if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+           self.attn1 = Attention(
+               query_dim=dim,
+               heads=num_attention_heads,
+               dim_head=attention_head_dim,
+               dropout=dropout,
+               bias=attention_bias,
+               cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+               upcast_attention=upcast_attention,
+               qk_norm="layer_norm" if qk_norm else None,
+               processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+           )
+       else:
+           self.attn1 = Attention(
+               query_dim=dim,
+               heads=num_attention_heads,
+               dim_head=attention_head_dim,
+               dropout=dropout,
+               bias=attention_bias,
+               cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+               upcast_attention=upcast_attention,
+           )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
@@ -775,27 +833,45 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
-               else
+               else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
-
-
-
-
-
-
-
-
-
+           if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+                   qk_norm="layer_norm" if qk_norm else None,
+                   processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+               )  # is self-attn if encoder_hidden_states is none
+           else:
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+               )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if not self.use_ada_layer_norm_single:
-           self.norm3 =
+           self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

+       if after_norm:
+           self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+       else:
+           self.norm4 = None
+
        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -927,6 +1003,9 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+       if self.norm4 is not None:
+           ff_output = self.norm4(ff_output)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -997,6 +1076,8 @@ class KVCompressionTransformerBlock(nn.Module):
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        kvcompression: Optional[bool] = False,
+       qk_norm = False,
+       after_norm = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
@@ -1029,7 +1110,7 @@ class KVCompressionTransformerBlock(nn.Module):
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
-           self.norm1 =
+           self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.kvcompression = kvcompression
        if kvcompression:
@@ -1043,16 +1124,28 @@ class KVCompressionTransformerBlock(nn.Module):
                upcast_attention=upcast_attention,
            )
        else:
-
-
-
-
-
-
-
-
-
-
+           if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+               self.attn1 = Attention(
+                   query_dim=dim,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                   upcast_attention=upcast_attention,
+                   qk_norm="layer_norm" if qk_norm else None,
+                   processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+               )
+           else:
+               self.attn1 = Attention(
+                   query_dim=dim,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                   upcast_attention=upcast_attention,
+               )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
@@ -1062,27 +1155,45 @@ class KVCompressionTransformerBlock(nn.Module):
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
-               else
+               else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
-
-
-
-
-
-
-
-
-
+           if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+                   qk_norm="layer_norm" if qk_norm else None,
+                   processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+               )  # is self-attn if encoder_hidden_states is none
+           else:
+               self.attn2 = Attention(
+                   query_dim=dim,
+                   cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                   heads=num_attention_heads,
+                   dim_head=attention_head_dim,
+                   dropout=dropout,
+                   bias=attention_bias,
+                   upcast_attention=upcast_attention,
+               )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if not self.use_ada_layer_norm_single:
-           self.norm3 =
+           self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

+       if after_norm:
+           self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+       else:
+           self.norm4 = None
+
        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -1229,6 +1340,9 @@ class KVCompressionTransformerBlock(nn.Module):
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+       if self.norm4 is not None:
+           ff_output = self.norm4(ff_output)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -1239,61 +1353,4 @@ class KVCompressionTransformerBlock(nn.Module):
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

-       return hidden_states
-
-
-class FeedForward(nn.Module):
-    r"""
-    A feed-forward layer.
-
-    Parameters:
-        dim (`int`): The number of channels in the input.
-        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
-        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        dim_out: Optional[int] = None,
-        mult: int = 4,
-        dropout: float = 0.0,
-        activation_fn: str = "geglu",
-        final_dropout: bool = False,
-    ):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
-        if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
-        if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
-        elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
-        elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
-
-        self.net = nn.ModuleList([])
-        # project in
-        self.net.append(act_fn)
-        # project dropout
-        self.net.append(nn.Dropout(dropout))
-        # project out
-        self.net.append(linear_cls(inner_dim, dim_out))
-        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
-        if final_dropout:
-            self.net.append(nn.Dropout(dropout))
-
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
-        for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
-        return hidden_states
+       return hidden_states
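The new easyanimate/models/norm.py (+97 lines) is not rendered on this page. The FP32LayerNorm used throughout the hunks above is presumably a LayerNorm that normalizes in float32 and casts back to the input dtype; the sketch below is an assumption, not the committed code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32LayerNorm(nn.LayerNorm):
    # Assumed behaviour: run LayerNorm in float32 for numerical stability,
    # then return the result in the caller's dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)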
easyanimate/models/autoencoder_magvit.py
CHANGED
@@ -17,7 +17,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-
+
+try:
+    from diffusers.loaders import FromOriginalVAEMixin
+except:
+    from diffusers.loaders import FromOriginalModelMixin as FromOriginalVAEMixin
+
 from diffusers.models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
     AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
@@ -93,6 +98,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        norm_num_groups: int = 32,
        scaling_factor: float = 0.1825,
        slice_compression_vae=False,
+       use_tiling=False,
        mini_batch_encoder=9,
        mini_batch_decoder=3,
    ):
@@ -145,8 +151,8 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
        self.mini_batch_encoder = mini_batch_encoder
        self.mini_batch_decoder = mini_batch_decoder
        self.use_slicing = False
-       self.use_tiling =
-       self.tile_sample_min_size =
+       self.use_tiling = use_tiling
+       self.tile_sample_min_size = 384
        self.tile_overlap_factor = 0.25
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
        self.scaling_factor = scaling_factor
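A hedged usage sketch for the new use_tiling flag: diffusers models accept config overrides through from_pretrained, so tiling could be switched on at load time roughly as below. The checkpoint path and the "vae" subfolder are assumptions, not taken from this diff.

from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit

# Hypothetical checkpoint path; use_tiling=True makes samples larger than
# tile_sample_min_size (384) get decoded in overlapping tiles, per the hunk above.
vae = AutoencoderKLMagvit.from_pretrained(
    "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512",
    subfolder="vae",
    use_tiling=True,
)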
easyanimate/models/motion_module.py
CHANGED
@@ -1,248 +1,33 @@
 """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
 """
 import math
-from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
-
 from diffusers.models.attention import FeedForward
 from diffusers.utils.import_utils import is_xformers_available
 from einops import rearrange, repeat
 from torch import nn

 if is_xformers_available():
     import xformers
     import xformers.ops
 else:
     xformers = None

-class CrossAttention(nn.Module):
-    r"""
-    A cross attention layer.
-
-    Parameters:
-        query_dim (`int`): The number of channels in the query.
-        cross_attention_dim (`int`, *optional*):
-            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
-        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
-        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        bias (`bool`, *optional*, defaults to False):
-            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
-    """
-
-    def __init__(
-        self,
-        query_dim: int,
-        cross_attention_dim: Optional[int] = None,
-        heads: int = 8,
-        dim_head: int = 64,
-        dropout: float = 0.0,
-        bias=False,
-        upcast_attention: bool = False,
-        upcast_softmax: bool = False,
-        added_kv_proj_dim: Optional[int] = None,
-        norm_num_groups: Optional[int] = None,
-    ):
-        super().__init__()
-        inner_dim = dim_head * heads
-        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
-        self.upcast_attention = upcast_attention
-        self.upcast_softmax = upcast_softmax
-
-        self.scale = dim_head**-0.5
-
-        self.heads = heads
-        # for slice_size > 0 the attention score computation
-        # is split across the batch axis to save memory
-        # You can set slice_size with `set_attention_slice`
-        self.sliceable_head_dim = heads
-        self._slice_size = None
-        self._use_memory_efficient_attention_xformers = False
-        self.added_kv_proj_dim = added_kv_proj_dim
-
-        if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
-        else:
-            self.group_norm = None
-
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
-        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
-
-        if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-
-        self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(inner_dim, query_dim))
-        self.to_out.append(nn.Dropout(dropout))
-
-    def set_use_memory_efficient_attention_xformers(
-        self, valid: bool, attention_op: Optional[Callable] = None
-    ) -> None:
-        self._use_memory_efficient_attention_xformers = valid
-
-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
-    def set_attention_slice(self, slice_size):
-        if slice_size is not None and slice_size > self.sliceable_head_dim:
-            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
-
-        self._slice_size = slice_size
-
-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        encoder_hidden_states = encoder_hidden_states
-
-        if self.group_norm is not None:
-            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = self.to_q(hidden_states)
-        dim = query.shape[-1]
-        query = self.reshape_heads_to_batch_dim(query)
-
-        if self.added_kv_proj_dim is not None:
-            key = self.to_k(hidden_states)
-            value = self.to_v(hidden_states)
-            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
-
-            key = self.reshape_heads_to_batch_dim(key)
-            value = self.reshape_heads_to_batch_dim(value)
-            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
-            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
-
-            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
-            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
-        else:
-            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-            key = self.to_k(encoder_hidden_states)
-            value = self.to_v(encoder_hidden_states)
-
-            key = self.reshape_heads_to_batch_dim(key)
-            value = self.reshape_heads_to_batch_dim(value)
-
-        if attention_mask is not None:
-            if attention_mask.shape[-1] != query.shape[1]:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-            attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
-        # attention, what we cannot get enough of
-        if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
-            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
-            hidden_states = hidden_states.to(query.dtype)
-        else:
-            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value, attention_mask)
-            else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
-
-        # linear proj
-        hidden_states = self.to_out[0](hidden_states)
-
-        # dropout
-        hidden_states = self.to_out[1](hidden_states)
-        return hidden_states
-
-    def _attention(self, query, key, value, attention_mask=None):
-        if self.upcast_attention:
-            query = query.float()
-            key = key.float()
-
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
-
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
-        if self.upcast_softmax:
-            attention_scores = attention_scores.float()
-
-        attention_probs = attention_scores.softmax(dim=-1)
-
-        # cast back to the original dtype
-        attention_probs = attention_probs.to(value.dtype)
-
-        # compute attention output
-        hidden_states = torch.bmm(attention_probs, value)
-
-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-        return hidden_states
-
-    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
-        batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-        )
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-        for i in range(hidden_states.shape[0] // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-
-            query_slice = query[start_idx:end_idx]
-            key_slice = key[start_idx:end_idx]
-
-            if self.upcast_attention:
-                query_slice = query_slice.float()
-                key_slice = key_slice.float()
-
-            attn_slice = torch.baddbmm(
|
212 |
-
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
213 |
-
query_slice,
|
214 |
-
key_slice.transpose(-1, -2),
|
215 |
-
beta=0,
|
216 |
-
alpha=self.scale,
|
217 |
-
)
|
218 |
-
|
219 |
-
if attention_mask is not None:
|
220 |
-
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
221 |
-
|
222 |
-
if self.upcast_softmax:
|
223 |
-
attn_slice = attn_slice.float()
|
224 |
-
|
225 |
-
attn_slice = attn_slice.softmax(dim=-1)
|
226 |
-
|
227 |
-
# cast back to the original dtype
|
228 |
-
attn_slice = attn_slice.to(value.dtype)
|
229 |
-
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
230 |
-
|
231 |
-
hidden_states[start_idx:end_idx] = attn_slice
|
232 |
-
|
233 |
-
# reshape hidden_states
|
234 |
-
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
235 |
-
return hidden_states
|
236 |
-
|
237 |
-
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
238 |
-
# TODO attention_mask
|
239 |
-
query = query.contiguous()
|
240 |
-
key = key.contiguous()
|
241 |
-
value = value.contiguous()
|
242 |
-
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
243 |
-
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
244 |
-
return hidden_states
|
245 |
-
|
246 |
def zero_module(module):
|
247 |
# Zero out the parameters of a module and return it.
|
248 |
for p in module.parameters():
|
@@ -275,6 +60,11 @@ class VanillaTemporalModule(nn.Module):
         zero_initialize = True,
         block_size = 1,
         grid = False,
     ):
         super().__init__()
 
@@ -289,17 +79,87 @@ class VanillaTemporalModule(nn.Module):
             temporal_position_encoding_max_len=temporal_position_encoding_max_len,
             grid=grid,
             block_size=block_size,
         )
         if zero_initialize:
             self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
 
     def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
         hidden_states = input_tensor
         hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
 
         output = hidden_states
         return output
 
 class TemporalTransformer3DModel(nn.Module):
     def __init__(
         self,
@@ -321,6 +181,8 @@ class TemporalTransformer3DModel(nn.Module):
         temporal_position_encoding_max_len = 4096,
         grid = False,
         block_size = 1,
     ):
         super().__init__()
 
@@ -348,6 +210,8 @@ class TemporalTransformer3DModel(nn.Module):
                     temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                     block_size=block_size,
                     grid=grid,
                 )
                 for d in range(num_layers)
             ]
@@ -398,6 +262,8 @@ class TemporalTransformerBlock(nn.Module):
         temporal_position_encoding_max_len = 4096,
         block_size = 1,
         grid = False,
     ):
         super().__init__()
 
@@ -422,15 +288,36 @@ class TemporalTransformerBlock(nn.Module):
                     temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                     block_size=block_size,
                     grid=grid,
                 )
             )
-            norms.append(nn.LayerNorm(dim))
 
         self.attention_blocks = nn.ModuleList(attention_blocks)
         self.norms = nn.ModuleList(norms)
 
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
-        self.ff_norm = nn.LayerNorm(dim)
 
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
         for attention_block, norm in zip(self.attention_blocks, self.norms):
@@ -468,7 +355,7 @@ class PositionalEncoding(nn.Module):
         x = x + self.pe[:, :x.size(1)]
         return self.dropout(x)
 
-class VersatileAttention(CrossAttention):
     def __init__(
         self,
         attention_mode = None,
@@ -477,21 +364,23 @@ class VersatileAttention(CrossAttention):
         temporal_position_encoding_max_len = 4096,
         grid = False,
         block_size = 1,
         *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
-        assert attention_mode == "Temporal"
 
         self.attention_mode = attention_mode
         self.is_cross_attention = kwargs["cross_attention_dim"] is not None
 
         self.block_size = block_size
         self.grid = grid
         self.pos_encoder = PositionalEncoding(
             kwargs["query_dim"],
             dropout=0.,
             max_len=temporal_position_encoding_max_len
-        ) if (temporal_position_encoding and attention_mode == "Temporal") else None
 
     def extra_repr(self):
         return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
@@ -503,8 +392,13 @@ class VersatileAttention(CrossAttention):
             # for add pos_encoder
             _, before_d, _c = hidden_states.size()
             hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
-            if self.pos_encoder is not None:
-                hidden_states = self.pos_encoder(hidden_states)
 
             if self.grid:
                 hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
@@ -515,61 +409,36 @@ class VersatileAttention(CrossAttention):
             else:
                 d = before_d
                 encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
         else:
             raise NotImplementedError
 
-        encoder_hidden_states = encoder_hidden_states
-
-        if self.group_norm is not None:
-            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = self.to_q(hidden_states)
-        dim = query.shape[-1]
-        query = self.reshape_heads_to_batch_dim(query)
-
-        if self.added_kv_proj_dim is not None:
-            raise NotImplementedError
-
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = self.to_k(encoder_hidden_states)
-        value = self.to_v(encoder_hidden_states)
-
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
-
-        if attention_mask is not None:
-            if attention_mask.shape[-1] != query.shape[1]:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
         bs = 512
         new_hidden_states = []
-        for i in range(0, query.shape[0], bs):
-            # attention, what we cannot get enough of
-            if self._use_memory_efficient_attention_xformers:
-                hidden_states = self._memory_efficient_attention_xformers(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask)
-                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
-                hidden_states = hidden_states.to(query.dtype)
-            else:
-                if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
-                    hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
-                else:
-                    hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
-            new_hidden_states.append(hidden_states)
         hidden_states = torch.cat(new_hidden_states, dim = 0)
 
-        # linear proj
-        hidden_states = self.to_out[0](hidden_states)
-
-        # dropout
-        hidden_states = self.to_out[1](hidden_states)
-
         if self.attention_mode == "Temporal":
             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
             if self.grid:
                 hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
                 hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
                 hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
 
         return hidden_states
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
|
2 |
"""
|
3 |
import math
|
|
|
4 |
|
5 |
+
import diffusers
|
6 |
+
import pkg_resources
|
7 |
import torch
|
8 |
+
|
9 |
+
installed_version = diffusers.__version__
|
10 |
+
|
11 |
+
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
|
12 |
+
from diffusers.models.attention_processor import (Attention,
|
13 |
+
AttnProcessor2_0,
|
14 |
+
HunyuanAttnProcessor2_0)
|
15 |
+
else:
|
16 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
17 |
+
|
18 |
from diffusers.models.attention import FeedForward
|
19 |
from diffusers.utils.import_utils import is_xformers_available
|
20 |
from einops import rearrange, repeat
|
21 |
from torch import nn
|
22 |
|
23 |
+
from .norm import FP32LayerNorm
|
24 |
+
|
25 |
if is_xformers_available():
|
26 |
import xformers
|
27 |
import xformers.ops
|
28 |
else:
|
29 |
xformers = None
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def zero_module(module):
|
32 |
# Zero out the parameters of a module and return it.
|
33 |
for p in module.parameters():
|
|
|
60 |
zero_initialize = True,
|
61 |
block_size = 1,
|
62 |
grid = False,
|
63 |
+
remove_time_embedding_in_photo = False,
|
64 |
+
|
65 |
+
global_num_attention_heads = 16,
|
66 |
+
global_attention = False,
|
67 |
+
qk_norm = False,
|
68 |
):
|
69 |
super().__init__()
|
70 |
|
|
|
79 |
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
80 |
grid=grid,
|
81 |
block_size=block_size,
|
82 |
+
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
|
83 |
+
qk_norm=qk_norm,
|
84 |
)
|
85 |
+
self.global_transformer = GlobalTransformer3DModel(
|
86 |
+
in_channels=in_channels,
|
87 |
+
num_attention_heads=global_num_attention_heads,
|
88 |
+
attention_head_dim=in_channels // global_num_attention_heads // temporal_attention_dim_div,
|
89 |
+
qk_norm=qk_norm,
|
90 |
+
) if global_attention else None
|
91 |
if zero_initialize:
|
92 |
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
93 |
+
if global_attention:
|
94 |
+
self.global_transformer.proj_out = zero_module(self.global_transformer.proj_out)
|
95 |
|
96 |
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
|
97 |
hidden_states = input_tensor
|
98 |
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
99 |
+
if self.global_transformer is not None:
|
100 |
+
hidden_states = self.global_transformer(hidden_states)
|
101 |
|
102 |
output = hidden_states
|
103 |
return output
|
104 |
|
105 |
+
class GlobalTransformer3DModel(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
in_channels,
|
109 |
+
num_attention_heads,
|
110 |
+
attention_head_dim,
|
111 |
+
dropout = 0.0,
|
112 |
+
attention_bias = False,
|
113 |
+
upcast_attention = False,
|
114 |
+
qk_norm = False,
|
115 |
+
):
|
116 |
+
super().__init__()
|
117 |
+
|
118 |
+
inner_dim = num_attention_heads * attention_head_dim
|
119 |
+
|
120 |
+
self.norm1 = FP32LayerNorm(inner_dim)
|
121 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
122 |
+
self.norm2 = FP32LayerNorm(inner_dim)
|
123 |
+
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
|
124 |
+
self.attention = Attention(
|
125 |
+
query_dim=inner_dim,
|
126 |
+
heads=num_attention_heads,
|
127 |
+
dim_head=attention_head_dim,
|
128 |
+
dropout=dropout,
|
129 |
+
bias=attention_bias,
|
130 |
+
upcast_attention=upcast_attention,
|
131 |
+
qk_norm="layer_norm" if qk_norm else None,
|
132 |
+
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
self.attention = Attention(
|
136 |
+
query_dim=inner_dim,
|
137 |
+
heads=num_attention_heads,
|
138 |
+
dim_head=attention_head_dim,
|
139 |
+
dropout=dropout,
|
140 |
+
bias=attention_bias,
|
141 |
+
upcast_attention=upcast_attention,
|
142 |
+
)
|
143 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
144 |
+
|
145 |
+
def forward(self, hidden_states):
|
146 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
147 |
+
video_length, height, width = hidden_states.shape[2], hidden_states.shape[3], hidden_states.shape[4]
|
148 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
149 |
+
|
150 |
+
residual = hidden_states
|
151 |
+
hidden_states = self.norm1(hidden_states)
|
152 |
+
hidden_states = self.proj_in(hidden_states)
|
153 |
+
|
154 |
+
# Attention Blocks
|
155 |
+
hidden_states = self.norm2(hidden_states)
|
156 |
+
hidden_states = self.attention(hidden_states)
|
157 |
+
hidden_states = self.proj_out(hidden_states)
|
158 |
+
|
159 |
+
output = hidden_states + residual
|
160 |
+
output = rearrange(output, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
|
161 |
+
return output
|
162 |
+
|
163 |
class TemporalTransformer3DModel(nn.Module):
|
164 |
def __init__(
|
165 |
self,
|
|
|
181 |
temporal_position_encoding_max_len = 4096,
|
182 |
grid = False,
|
183 |
block_size = 1,
|
184 |
+
remove_time_embedding_in_photo = False,
|
185 |
+
qk_norm = False,
|
186 |
):
|
187 |
super().__init__()
|
188 |
|
|
|
210 |
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
211 |
block_size=block_size,
|
212 |
grid=grid,
|
213 |
+
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
|
214 |
+
qk_norm=qk_norm
|
215 |
)
|
216 |
for d in range(num_layers)
|
217 |
]
|
|
|
262 |
temporal_position_encoding_max_len = 4096,
|
263 |
block_size = 1,
|
264 |
grid = False,
|
265 |
+
remove_time_embedding_in_photo = False,
|
266 |
+
qk_norm = False,
|
267 |
):
|
268 |
super().__init__()
|
269 |
|
|
|
288 |
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
289 |
block_size=block_size,
|
290 |
grid=grid,
|
291 |
+
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
|
292 |
+
qk_norm="layer_norm" if qk_norm else None,
|
293 |
+
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
|
294 |
+
) if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2") else \
|
295 |
+
VersatileAttention(
|
296 |
+
attention_mode=block_name.split("_")[0],
|
297 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
298 |
+
|
299 |
+
query_dim=dim,
|
300 |
+
heads=num_attention_heads,
|
301 |
+
dim_head=attention_head_dim,
|
302 |
+
dropout=dropout,
|
303 |
+
bias=attention_bias,
|
304 |
+
upcast_attention=upcast_attention,
|
305 |
+
|
306 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
307 |
+
temporal_position_encoding=temporal_position_encoding,
|
308 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
309 |
+
block_size=block_size,
|
310 |
+
grid=grid,
|
311 |
+
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
|
312 |
)
|
313 |
)
|
314 |
+
norms.append(FP32LayerNorm(dim))
|
315 |
|
316 |
self.attention_blocks = nn.ModuleList(attention_blocks)
|
317 |
self.norms = nn.ModuleList(norms)
|
318 |
|
319 |
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
320 |
+
self.ff_norm = FP32LayerNorm(dim)
|
321 |
|
322 |
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
|
323 |
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
|
|
355 |
x = x + self.pe[:, :x.size(1)]
|
356 |
return self.dropout(x)
|
357 |
|
358 |
+
class VersatileAttention(Attention):
|
359 |
def __init__(
|
360 |
self,
|
361 |
attention_mode = None,
|
|
|
364 |
temporal_position_encoding_max_len = 4096,
|
365 |
grid = False,
|
366 |
block_size = 1,
|
367 |
+
remove_time_embedding_in_photo = False,
|
368 |
*args, **kwargs
|
369 |
):
|
370 |
super().__init__(*args, **kwargs)
|
371 |
+
assert attention_mode == "Temporal" or attention_mode == "Global"
|
372 |
|
373 |
self.attention_mode = attention_mode
|
374 |
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
375 |
|
376 |
self.block_size = block_size
|
377 |
self.grid = grid
|
378 |
+
self.remove_time_embedding_in_photo = remove_time_embedding_in_photo
|
379 |
self.pos_encoder = PositionalEncoding(
|
380 |
kwargs["query_dim"],
|
381 |
dropout=0.,
|
382 |
max_len=temporal_position_encoding_max_len
|
383 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") or (temporal_position_encoding and attention_mode == "Global") else None
|
384 |
|
385 |
def extra_repr(self):
|
386 |
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
|
|
392 |
# for add pos_encoder
|
393 |
_, before_d, _c = hidden_states.size()
|
394 |
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
395 |
+
|
396 |
+
if self.remove_time_embedding_in_photo:
|
397 |
+
if self.pos_encoder is not None and video_length > 1:
|
398 |
+
hidden_states = self.pos_encoder(hidden_states)
|
399 |
+
else:
|
400 |
+
if self.pos_encoder is not None:
|
401 |
+
hidden_states = self.pos_encoder(hidden_states)
|
402 |
|
403 |
if self.grid:
|
404 |
hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
|
|
|
409 |
else:
|
410 |
d = before_d
|
411 |
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
412 |
+
elif self.attention_mode == "Global":
|
413 |
+
# for add pos_encoder
|
414 |
+
_, d, _c = hidden_states.size()
|
415 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
416 |
+
if self.pos_encoder is not None:
|
417 |
+
hidden_states = self.pos_encoder(hidden_states)
|
418 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", f=video_length, d=d)
|
419 |
else:
|
420 |
raise NotImplementedError
|
421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
bs = 512
|
425 |
new_hidden_states = []
|
426 |
+
for i in range(0, hidden_states.shape[0], bs):
|
427 |
+
__hidden_states = super().forward(
|
428 |
+
hidden_states[i : i + bs],
|
429 |
+
encoder_hidden_states=encoder_hidden_states[i : i + bs],
|
430 |
+
attention_mask=attention_mask
|
431 |
+
)
|
432 |
+
new_hidden_states.append(__hidden_states)
|
|
|
|
|
|
|
|
|
|
|
433 |
hidden_states = torch.cat(new_hidden_states, dim = 0)
|
434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
if self.attention_mode == "Temporal":
|
436 |
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
437 |
if self.grid:
|
438 |
hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
|
439 |
hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
|
440 |
hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
|
441 |
+
elif self.attention_mode == "Global":
|
442 |
+
hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length, d=d)
|
443 |
|
444 |
return hidden_states
|
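Note on the chunked attention above: VersatileAttention flattens every spatial position into the batch axis and then runs the inherited Attention forward over slices of at most 512 rows (bs = 512) to bound peak memory. A minimal, self-contained sketch of that pattern follows; it uses torch.nn.MultiheadAttention as a stand-in for the diffusers Attention class, and the shapes and chunk size are illustrative only.

import torch
import torch.nn as nn

def chunked_self_attention(attn: nn.MultiheadAttention, x: torch.Tensor, chunk: int = 512) -> torch.Tensor:
    # x: (batch, seq_len, dim), where batch is the flattened (b * h * w) axis.
    outputs = []
    for i in range(0, x.shape[0], chunk):
        part = x[i : i + chunk]                       # process at most `chunk` batch rows at once
        out, _ = attn(part, part, part, need_weights=False)
        outputs.append(out)
    return torch.cat(outputs, dim=0)                  # same shape as x

attn = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(1024, 16, 64)                         # 1024 spatial positions, 16 frames, 64 channels
print(chunked_self_attention(attn, x).shape)          # torch.Size([1024, 16, 64])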
easyanimate/models/norm.py
ADDED
@@ -0,0 +1,97 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from torch import nn
+
+
+def zero_module(module):
+    # Zero out the parameters of a module and return it.
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+class FP32LayerNorm(nn.LayerNorm):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        origin_dtype = inputs.dtype
+        if hasattr(self, 'weight') and self.weight is not None:
+            return F.layer_norm(
+                inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
+            ).to(origin_dtype)
+        else:
+            return F.layer_norm(
+                inputs.float(), self.normalized_shape, None, None, self.eps
+            ).to(origin_dtype)
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+    """
+    For PixArt-Alpha.
+
+    Reference:
+    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+    """
+
+    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.outdim = size_emb_dim
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.use_additional_conditions = use_additional_conditions
+        if use_additional_conditions:
+            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+            self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
+            self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
+
+    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        if self.use_additional_conditions:
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+        else:
+            conditioning = timesteps_emb
+
+        return conditioning
+
+class AdaLayerNormSingle(nn.Module):
+    r"""
+    Norm layer adaptive layer norm single (adaLN-single).
+
+    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+    """
+
+    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+        )
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # No modulation happening here.
+        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+        return self.linear(self.silu(embedded_timestep)), embedded_timestep
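A short usage sketch for the FP32LayerNorm defined above, assuming the easyanimate package layout introduced by this commit is importable: the normalization is computed in float32 and the result is cast back to the activation dtype, so half-precision activations pass through without the reduction being done in fp16.

import torch
from easyanimate.models.norm import FP32LayerNorm   # import path assumed from this commit

norm = FP32LayerNorm(320).half()                     # parameters stored in fp16
x = torch.randn(2, 1024, 320, dtype=torch.float16)   # fp16 activations
y = norm(x)                                          # the layer norm itself runs in fp32 internally
print(y.dtype)                                       # torch.float16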
easyanimate/models/patch.py
CHANGED
@@ -1,10 +1,10 @@
+import math
 from typing import Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
-import math
 from einops import rearrange
 from torch import nn
 
easyanimate/models/transformer3d.py
CHANGED
@@ -15,26 +15,30 @@ import json
 import math
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention import BasicTransformerBlock
-from diffusers.models.embeddings import PatchEmbed,
 from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.normalization import
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
 from einops import rearrange
 from torch import nn
-from typing import Dict, Optional, Tuple
 
 from .attention import (SelfAttentionTemporalTransformerBlock,
                         TemporalTransformerBlock)
-from .
 
 try:
     from diffusers.models.embeddings import PixArtAlphaTextProjection
@@ -48,77 +52,25 @@ def zero_module(module):
         p.detach().zero_()
     return module
 
-class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
-    """
-    For PixArt-Alpha.
 
-    Reference:
-    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
     """
 
-    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
-        super().__init__()
-
-        self.outdim = size_emb_dim
-        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-
-        self.use_additional_conditions = use_additional_conditions
-        if use_additional_conditions:
-            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
-            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
-            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
-
-            self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
-            self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
-
-    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
-        timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
-
-        if self.use_additional_conditions:
-            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
-            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
-            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
-            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
-            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
-        else:
-            conditioning = timesteps_emb
-
-        return conditioning
-
-class AdaLayerNormSingle(nn.Module):
-    r"""
-    Norm layer adaptive layer norm single (adaLN-single).
-
-    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
-
-    Parameters:
-        embedding_dim (`int`): The size of each embedding vector.
-        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
     """
 
-    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()
-
-        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
-            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
-        )
-
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
-
-    def forward(
-        self,
-        timestep: torch.Tensor,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-        batch_size: Optional[int] = None,
-        hidden_dtype: Optional[torch.dtype] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # No modulation happening here.
-        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
-        return self.linear(self.silu(embedded_timestep)), embedded_timestep
-
 
 class TimePositionalEncoding(nn.Module):
     def __init__(
@@ -229,9 +181,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,
 
         # time position encoding
-        time_position_encoding_before_transformer = False
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -320,6 +277,35 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                     attention_type=attention_type,
                     motion_module_type=motion_module_type,
                     motion_module_kwargs=motion_module_kwargs,
                 )
                 for d in range(num_layers)
             ]
@@ -346,6 +332,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                     kvcompression=False if d < 14 else True,
                     motion_module_type=motion_module_type,
                     motion_module_kwargs=motion_module_kwargs,
                 )
                 for d in range(num_layers)
             ]
@@ -369,6 +357,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                     attention_type=attention_type,
                 )
                 for d in range(num_layers)
             ]
@@ -438,8 +428,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
 
         self.caption_projection = None
         if caption_channels is not None:
             self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
 
         self.gradient_checkpointing = False
 
@@ -456,12 +449,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Dict[str, torch.Tensor] = None,
         class_labels: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
@@ -520,6 +515,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
@@ -560,6 +557,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
 
         skips = []
         skip_index = 0
         for index, block in enumerate(self.transformer_blocks):
@@ -590,7 +594,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                 args = {
                     "basic": [],
                     "motionmodule": [video_length, height, width],
-                    "
                     "kvcompression_motionmodule": [video_length, height, width],
                 }[self.basic_block_type]
                 hidden_states = torch.utils.checkpoint.checkpoint(
@@ -609,7 +614,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                 kwargs = {
                     "basic": {},
                     "motionmodule": {"num_frames":video_length, "height":height, "width":width},
-                    "
                     "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                 }[self.basic_block_type]
                 hidden_states = block(

 import math
 import os
 from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
 from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import BasicTransformerBlock, FeedForward
+from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
+                                         TimestepEmbedding, Timesteps)
 from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous
+from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
+                             logging)
+from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange
 from torch import nn
 
 from .attention import (SelfAttentionTemporalTransformerBlock,
                         TemporalTransformerBlock)
+from .norm import AdaLayerNormSingle
+from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D,
+                    TemporalUpsampler3D, UnPatch1D)
 
 try:
     from diffusers.models.embeddings import PixArtAlphaTextProjection
 
     p.detach().zero_()
     return module
 
 
+class CLIPProjection(nn.Module):
     """
+    Projects caption embeddings. Also handles dropout for classifier-free guidance.
 
+    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """
 
+    def __init__(self, in_features, hidden_size, num_tokens=120):
         super().__init__()
+        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+        self.act_1 = nn.GELU(approximate="tanh")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+        self.linear_2 = zero_module(self.linear_2)
+    def forward(self, caption):
+        hidden_states = self.linear_1(caption)
+        hidden_states = self.act_1(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
 
 class TimePositionalEncoding(nn.Module):
     def __init__(
 
         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,
+        motion_module_kwargs_odd = None,
+        motion_module_kwargs_even = None,
 
         # time position encoding
+        time_position_encoding_before_transformer = False,
+
+        qk_norm = False,
+        after_norm = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
 
                     attention_type=attention_type,
                     motion_module_type=motion_module_type,
                     motion_module_kwargs=motion_module_kwargs,
+                    qk_norm=qk_norm,
+                    after_norm=after_norm,
+                )
+                for d in range(num_layers)
+            ]
+        )
+        elif self.basic_block_type == "global_motionmodule":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    TemporalTransformerBlock(
+                        inner_dim,
+                        num_attention_heads,
+                        attention_head_dim,
+                        dropout=dropout,
+                        cross_attention_dim=cross_attention_dim,
+                        activation_fn=activation_fn,
+                        num_embeds_ada_norm=num_embeds_ada_norm,
+                        attention_bias=attention_bias,
+                        only_cross_attention=only_cross_attention,
+                        double_self_attention=double_self_attention,
+                        upcast_attention=upcast_attention,
+                        norm_type=norm_type,
+                        norm_elementwise_affine=norm_elementwise_affine,
+                        norm_eps=norm_eps,
+                        attention_type=attention_type,
+                        motion_module_type=motion_module_type,
+                        motion_module_kwargs=motion_module_kwargs_even if d % 2 == 0 else motion_module_kwargs_odd,
+                        qk_norm=qk_norm,
+                        after_norm=after_norm,
                 )
                 for d in range(num_layers)
             ]
 
                     kvcompression=False if d < 14 else True,
                     motion_module_type=motion_module_type,
                     motion_module_kwargs=motion_module_kwargs,
+                    qk_norm=qk_norm,
+                    after_norm=after_norm,
                 )
                 for d in range(num_layers)
             ]
 
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                     attention_type=attention_type,
+                    qk_norm=qk_norm,
+                    after_norm=after_norm,
                 )
                 for d in range(num_layers)
             ]
 
             self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
 
         self.caption_projection = None
+        self.clip_projection = None
         if caption_channels is not None:
             self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+            if in_channels == 12:
+                self.clip_projection = CLIPProjection(in_features=768, hidden_size=inner_dim * 8)
 
         self.gradient_checkpointing = False
 
         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
+        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Dict[str, torch.Tensor] = None,
         class_labels: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
+        clip_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
 
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)
 
+        if clip_attention_mask is not None:
+            encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1)
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
 
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
 
+        if clip_encoder_hidden_states is not None and encoder_hidden_states is not None:
+            batch_size = hidden_states.shape[0]
+            clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
+            clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+            encoder_hidden_states = torch.cat([encoder_hidden_states, clip_encoder_hidden_states], dim = 1)
+
         skips = []
         skip_index = 0
         for index, block in enumerate(self.transformer_blocks):
 
                 args = {
                     "basic": [],
                     "motionmodule": [video_length, height, width],
+                    "global_motionmodule": [video_length, height, width],
+                    "selfattentiontemporal": [],
                     "kvcompression_motionmodule": [video_length, height, width],
                 }[self.basic_block_type]
                 hidden_states = torch.utils.checkpoint.checkpoint(
 
                 kwargs = {
                     "basic": {},
                     "motionmodule": {"num_frames":video_length, "height":height, "width":width},
+                    "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
+                    "selfattentiontemporal": {},
                     "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                 }[self.basic_block_type]
                 hidden_states = block(
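Illustrative sketch of the new CLIP conditioning path above: the pooled CLIP image embedding is projected to inner_dim * 8 features, reshaped into 8 extra tokens, appended after the projected text tokens, and the CLIP attention mask is concatenated onto the text mask before the combined mask is turned into an additive bias. The module and tensor names below are stand-ins, not the repository's exact API.

import torch
import torch.nn as nn

batch, inner_dim = 2, 1152
text_tokens = torch.randn(batch, 120, inner_dim)          # projected T5 text tokens
clip_embeds = torch.randn(batch, 768)                      # pooled CLIP image embedding

clip_projection = nn.Sequential(                           # stand-in for CLIPProjection above
    nn.Linear(768, inner_dim * 8),
    nn.GELU(approximate="tanh"),
    nn.Linear(inner_dim * 8, inner_dim * 8),
)
clip_tokens = clip_projection(clip_embeds).view(batch, -1, inner_dim)   # -> 8 tokens per sample

encoder_hidden_states = torch.cat([text_tokens, clip_tokens], dim=1)    # (2, 128, 1152)
encoder_attention_mask = torch.cat(
    [torch.ones(batch, 120), torch.ones(batch, clip_tokens.shape[1])], dim=1
)
attention_bias = (1 - encoder_attention_mask) * -10000.0                # additive bias, as in forward()
print(encoder_hidden_states.shape, attention_bias.shape)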
easyanimate/pipeline/pipeline_easyanimate.py
CHANGED
@@ -578,7 +578,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
 
     def decode_latents(self, latents):
         video_length = latents.shape[2]
-        latents = 1 /
+        latents = 1 / self.vae.config.scaling_factor * latents
         if self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
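The decode_latents change above divides the sampled latents by the VAE scaling factor before decoding. A minimal hedged sketch of the same step with a stock diffusers AutoencoderKL (the checkpoint name is only an example and is downloaded on first use):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
latents = torch.randn(1, 4, 64, 64)                  # sampled latents, already multiplied by scaling_factor
latents = 1 / vae.config.scaling_factor * latents    # undo the scaling, as decode_latents now does
with torch.no_grad():
    image = vae.decode(latents).sample               # (1, 3, 512, 512), roughly in [-1, 1]
image = (image.clamp(-1, 1) + 1) / 2                 # map to [0, 1]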
easyanimate/pipeline/pipeline_easyanimate_inpaint.py
CHANGED
@@ -15,13 +15,16 @@
|
|
15 |
import html
|
16 |
import inspect
|
17 |
import re
|
|
|
18 |
import copy
|
19 |
import urllib.parse as ul
|
20 |
from dataclasses import dataclass
|
|
|
21 |
from typing import Callable, List, Optional, Tuple, Union
|
22 |
|
23 |
import numpy as np
|
24 |
import torch
|
|
|
25 |
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
26 |
from diffusers.image_processor import VaeImageProcessor
|
27 |
from diffusers.models import AutoencoderKL
|
@@ -33,6 +36,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
|
33 |
from einops import rearrange
|
34 |
from tqdm import tqdm
|
35 |
from transformers import T5EncoderModel, T5Tokenizer
|
|
|
36 |
|
37 |
from ..models.transformer3d import Transformer3DModel
|
38 |
|
@@ -109,11 +113,15 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
109 |
vae: AutoencoderKL,
|
110 |
transformer: Transformer3DModel,
|
111 |
scheduler: DPMSolverMultistepScheduler,
|
|
|
|
|
112 |
):
|
113 |
super().__init__()
|
114 |
|
115 |
self.register_modules(
|
116 |
-
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
|
|
|
|
117 |
)
|
118 |
|
119 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
@@ -503,41 +511,64 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
503 |
return_video_latents=False,
|
504 |
):
|
505 |
if self.vae.quant_conv.weight.ndim==5:
|
506 |
-
|
|
|
|
|
507 |
else:
|
508 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
|
|
509 |
if isinstance(generator, list) and len(generator) != batch_size:
|
510 |
raise ValueError(
|
511 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
512 |
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
513 |
)
|
514 |
-
|
515 |
if return_video_latents or (latents is None and not is_strength_max):
|
516 |
-
video = video.to(device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
517 |
|
518 |
-
if video.shape[1] == 4:
|
519 |
-
video_latents = video
|
520 |
else:
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
|
|
|
|
|
|
526 |
|
527 |
if latents is None:
|
528 |
-
|
529 |
-
|
530 |
-
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
531 |
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
532 |
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
|
|
|
|
|
533 |
else:
|
534 |
noise = latents.to(device)
|
535 |
-
|
536 |
-
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
537 |
-
latents = latents.to(device)
|
538 |
|
539 |
# scale the initial noise by the standard deviation required by the scheduler
|
540 |
-
latents = latents * self.scheduler.init_noise_sigma
|
541 |
outputs = (latents,)
|
542 |
|
543 |
if return_noise:
|
@@ -548,33 +579,61 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
548 |
|
549 |
return outputs
|
550 |
|
551 |
-
def
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
558 |
for i in range(0, latents.shape[2], mini_batch_decoder):
|
559 |
with torch.no_grad():
|
560 |
start_index = i
|
561 |
end_index = i + mini_batch_decoder
|
562 |
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
video = video.clamp(-1, 1)
|
|
|
575 |
else:
|
576 |
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
577 |
-
# video = self.vae.decode(latents).sample
|
578 |
video = []
|
579 |
for frame_idx in tqdm(range(latents.shape[0])):
|
580 |
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
@@ -599,6 +658,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
599 |
|
600 |
return image_latents
|
601 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
def prepare_mask_latents(
|
603 |
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
604 |
):
|
@@ -610,19 +679,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
610 |
mask = mask.to(device=device, dtype=self.vae.dtype)
|
611 |
if self.vae.quant_conv.weight.ndim==5:
|
612 |
bs = 1
|
|
|
613 |
new_mask = []
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
for j in range(0, mask.shape[2], mini_batch):
|
618 |
-
mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
|
619 |
mask_bs = self.vae.encode(mask_bs)[0]
|
620 |
mask_bs = mask_bs.sample()
|
621 |
-
|
622 |
-
|
623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
mask = torch.cat(new_mask, dim = 0)
|
625 |
-
mask = mask *
|
626 |
|
627 |
else:
|
628 |
if mask.shape[1] == 4:
|
@@ -636,19 +712,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
|
636 |
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
637 |
if self.vae.quant_conv.weight.ndim==5:
|
638 |
bs = 1
|
|
|
639 |
new_mask_pixel_values = []
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
for j in range(0, masked_image.shape[2], mini_batch):
|
644 |
-
mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
|
645 |
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
646 |
[Old-side column of the split diff for easyanimate/pipeline/pipeline_easyanimate_inpaint.py: the tail of the prepare_mask_latents hunk plus the hunks listed below. The removed-line text was truncated during extraction; the unchanged context rows and all updated lines are rendered in full in the new-side column that follows.]
@@ -693,7 +776,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -767,6 +852,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -806,11 +893,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -825,7 +914,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -857,30 +946,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -890,12 +1032,12 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -912,21 +1054,25 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -949,7 +1095,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -964,6 +1112,17 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
@@ -971,9 +1130,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
15      import html
16      import inspect
17      import re
18    + import gc
19      import copy
20      import urllib.parse as ul
21      from dataclasses import dataclass
22    + from PIL import Image
23      from typing import Callable, List, Optional, Tuple, Union
24
25      import numpy as np
26      import torch
27    + import torch.nn.functional as F
28      from diffusers import DiffusionPipeline, ImagePipelineOutput
29      from diffusers.image_processor import VaeImageProcessor
30      from diffusers.models import AutoencoderKL
...
36      from einops import rearrange
37      from tqdm import tqdm
38      from transformers import T5EncoderModel, T5Tokenizer
39    + from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
40
41      from ..models.transformer3d import Transformer3DModel
42
113      vae: AutoencoderKL,
114      transformer: Transformer3DModel,
115      scheduler: DPMSolverMultistepScheduler,
116    + clip_image_processor:CLIPImageProcessor = None,
117    + clip_image_encoder:CLIPVisionModelWithProjection = None,
118      ):
119      super().__init__()
120
121      self.register_modules(
122    +     tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
123    +     scheduler=scheduler,
124    +     clip_image_processor=clip_image_processor, clip_image_encoder=clip_image_encoder,
125      )
126
127      self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
|
511 |
return_video_latents=False,
|
512 |
):
|
513 |
if self.vae.quant_conv.weight.ndim==5:
|
514 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
515 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
516 |
+
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
517 |
else:
|
518 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
519 |
+
|
520 |
if isinstance(generator, list) and len(generator) != batch_size:
|
521 |
raise ValueError(
|
522 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
523 |
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
524 |
)
|
525 |
+
|
526 |
if return_video_latents or (latents is None and not is_strength_max):
|
527 |
+
video = video.to(device=device, dtype=self.vae.dtype)
|
528 |
+
if self.vae.quant_conv.weight.ndim==5:
|
529 |
+
bs = 1
|
530 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
531 |
+
new_video = []
|
532 |
+
if self.vae.slice_compression_vae:
|
533 |
+
for i in range(0, video.shape[0], bs):
|
534 |
+
video_bs = video[i : i + bs]
|
535 |
+
video_bs = self.vae.encode(video_bs)[0]
|
536 |
+
video_bs = video_bs.sample()
|
537 |
+
new_video.append(video_bs)
|
538 |
+
else:
|
539 |
+
for i in range(0, video.shape[0], bs):
|
540 |
+
new_video_mini_batch = []
|
541 |
+
for j in range(0, video.shape[2], mini_batch_encoder):
|
542 |
+
video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
543 |
+
video_bs = self.vae.encode(video_bs)[0]
|
544 |
+
video_bs = video_bs.sample()
|
545 |
+
new_video_mini_batch.append(video_bs)
|
546 |
+
new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2)
|
547 |
+
new_video.append(new_video_mini_batch)
|
548 |
+
video = torch.cat(new_video, dim = 0)
|
549 |
+
video = video * self.vae.config.scaling_factor
|
550 |
|
|
|
|
|
551 |
else:
|
552 |
+
if video.shape[1] == 4:
|
553 |
+
video = video
|
554 |
+
else:
|
555 |
+
video_length = video.shape[2]
|
556 |
+
video = rearrange(video, "b c f h w -> (b f) c h w")
|
557 |
+
video = self._encode_vae_image(video, generator=generator)
|
558 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
559 |
+
video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
|
560 |
|
561 |
if latents is None:
|
562 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
|
|
|
|
563 |
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
564 |
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
|
565 |
+
# if pure noise then scale the initial latents by the Scheduler's init sigma
|
566 |
+
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
|
567 |
else:
|
568 |
noise = latents.to(device)
|
569 |
+
latents = noise * self.scheduler.init_noise_sigma
|
|
|
|
|
570 |
|
571 |
# scale the initial noise by the standard deviation required by the scheduler
|
|
|
572 |
outputs = (latents,)
|
573 |
|
574 |
if return_noise:
|
|
|
579 |
|
580 |
return outputs
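
A note on the prepare_latents hunk above: when the MagVit/Slice VAE is active (its quant_conv weight is 5-D), the reference video is encoded in temporal chunks of mini_batch_encoder frames and the chunk latents are concatenated back along the frame axis before scaling. A minimal sketch of that pattern, with a stand-in encode_fn in place of self.vae.encode(...)[0].sample() and an assumed scaling factor:

import torch

def encode_in_temporal_chunks(video, encode_fn, mini_batch_encoder=8, scaling_factor=0.18215):
    # video: (b, c, f, h, w); encode_fn maps one frame chunk to its latent tensor.
    # scaling_factor is an assumption here; the pipeline reads it from vae.config.scaling_factor.
    chunks = []
    for j in range(0, video.shape[2], mini_batch_encoder):
        chunk = video[:, :, j:j + mini_batch_encoder]   # slice along the frame axis
        chunks.append(encode_fn(chunk))                 # encode one mini-batch of frames
    latents = torch.cat(chunks, dim=2)                  # stitch the chunk latents back together
    return latents * scaling_factor

# Toy usage with a dummy "encoder" that halves the frame count, just to show the shapes involved.
video = torch.randn(1, 3, 16, 64, 64)
latents = encode_in_temporal_chunks(video, encode_fn=lambda x: x[:, :, ::2], mini_batch_encoder=8)
assert latents.shape[2] == 8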
|
581 |
|
582 |
+
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
583 |
+
if video.size()[2] <= mini_batch_encoder:
|
584 |
+
return video
|
585 |
+
prefix_index_before = mini_batch_encoder // 2
|
586 |
+
prefix_index_after = mini_batch_encoder - prefix_index_before
|
587 |
+
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
588 |
+
|
589 |
+
if self.vae.slice_compression_vae:
|
590 |
+
latents = self.vae.encode(pixel_values)[0]
|
591 |
+
latents = latents.sample()
|
592 |
+
else:
|
593 |
+
new_pixel_values = []
|
594 |
+
for i in range(0, pixel_values.shape[2], mini_batch_encoder):
|
595 |
+
with torch.no_grad():
|
596 |
+
pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
|
597 |
+
pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
|
598 |
+
pixel_values_bs = pixel_values_bs.sample()
|
599 |
+
new_pixel_values.append(pixel_values_bs)
|
600 |
+
latents = torch.cat(new_pixel_values, dim = 2)
|
601 |
+
|
602 |
+
if self.vae.slice_compression_vae:
|
603 |
+
middle_video = self.vae.decode(latents)[0]
|
604 |
+
else:
|
605 |
+
middle_video = []
|
606 |
for i in range(0, latents.shape[2], mini_batch_decoder):
|
607 |
with torch.no_grad():
|
608 |
start_index = i
|
609 |
end_index = i + mini_batch_decoder
|
610 |
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
611 |
+
middle_video.append(latents_bs)
|
612 |
+
middle_video = torch.cat(middle_video, 2)
|
613 |
+
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
614 |
+
return video
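
The smooth_output method added above reduces seams between VAE mini-batches: it re-encodes and re-decodes a window offset by half a mini-batch, so the original chunk boundaries fall in the middle of the second pass, then averages the two decodes. A rough sketch of the offset-and-average idea, with a generic roundtrip_fn standing in for the VAE encode/decode calls (not the pipeline's exact implementation):

import torch

def smooth_chunk_seams(video, roundtrip_fn, mini_batch_encoder=8):
    # video: (b, c, f, h, w) decoded frames; roundtrip_fn re-encodes and re-decodes a clip.
    if video.size(2) <= mini_batch_encoder:
        return video
    prefix = mini_batch_encoder // 2
    suffix = mini_batch_encoder - prefix
    middle = roundtrip_fn(video[:, :, prefix:-suffix])               # second pass, offset by half a chunk
    video[:, :, prefix:-suffix] = (video[:, :, prefix:-suffix] + middle) / 2
    return video

# With an identity roundtrip the averaging is a no-op, which makes the logic easy to sanity-check.
clip = torch.randn(1, 3, 24, 8, 8)
assert torch.allclose(smooth_chunk_seams(clip.clone(), lambda x: x), clip)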
|
615 |
+
|
616 |
+
def decode_latents(self, latents):
|
617 |
+
video_length = latents.shape[2]
|
618 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
619 |
+
if self.vae.quant_conv.weight.ndim==5:
|
620 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
621 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
622 |
+
if self.vae.slice_compression_vae:
|
623 |
+
video = self.vae.decode(latents)[0]
|
624 |
+
else:
|
625 |
+
video = []
|
626 |
+
for i in range(0, latents.shape[2], mini_batch_decoder):
|
627 |
+
with torch.no_grad():
|
628 |
+
start_index = i
|
629 |
+
end_index = i + mini_batch_decoder
|
630 |
+
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
631 |
+
video.append(latents_bs)
|
632 |
+
video = torch.cat(video, 2)
|
633 |
video = video.clamp(-1, 1)
|
634 |
+
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
635 |
else:
|
636 |
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
|
|
637 |
video = []
|
638 |
for frame_idx in tqdm(range(latents.shape[0])):
|
639 |
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
|
|
658 |
|
659 |
return image_latents
|
660 |
|
661 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
662 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
663 |
+
# get the original timestep using init_timestep
|
664 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
665 |
+
|
666 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
667 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
668 |
+
|
669 |
+
return timesteps, num_inference_steps - t_start
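
For reference, the strength argument consumed by get_timesteps controls how much of the schedule is actually run: the closer strength is to 1, the earlier (noisier) the starting timestep. A small worked check of the arithmetic above, assuming scheduler.order == 1:

def steps_actually_run(num_inference_steps, strength):
    # Mirrors get_timesteps: keep only the tail of the schedule selected by `strength`.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

assert steps_actually_run(50, 0.4) == 20   # light edit: only the last 20 of 50 steps are run
assert steps_actually_run(50, 1.0) == 50   # strength 1.0: full schedule, starting from pure noise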
|
670 |
+
|
671 |
def prepare_mask_latents(
|
672 |
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
673 |
):
|
|
|
679 |
mask = mask.to(device=device, dtype=self.vae.dtype)
|
680 |
if self.vae.quant_conv.weight.ndim==5:
|
681 |
bs = 1
|
682 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
683 |
new_mask = []
|
684 |
+
if self.vae.slice_compression_vae:
|
685 |
+
for i in range(0, mask.shape[0], bs):
|
686 |
+
mask_bs = mask[i : i + bs]
|
|
|
|
|
687 |
mask_bs = self.vae.encode(mask_bs)[0]
|
688 |
mask_bs = mask_bs.sample()
|
689 |
+
new_mask.append(mask_bs)
|
690 |
+
else:
|
691 |
+
for i in range(0, mask.shape[0], bs):
|
692 |
+
new_mask_mini_batch = []
|
693 |
+
for j in range(0, mask.shape[2], mini_batch_encoder):
|
694 |
+
mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
695 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
696 |
+
mask_bs = mask_bs.sample()
|
697 |
+
new_mask_mini_batch.append(mask_bs)
|
698 |
+
new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
|
699 |
+
new_mask.append(new_mask_mini_batch)
|
700 |
mask = torch.cat(new_mask, dim = 0)
|
701 |
+
mask = mask * self.vae.config.scaling_factor
|
702 |
|
703 |
else:
|
704 |
if mask.shape[1] == 4:
|
|
|
712 |
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
713 |
if self.vae.quant_conv.weight.ndim==5:
|
714 |
bs = 1
|
715 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
716 |
new_mask_pixel_values = []
|
717 |
+
if self.vae.slice_compression_vae:
|
718 |
+
for i in range(0, masked_image.shape[0], bs):
|
719 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
|
|
|
|
720 |
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
721 |
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
722 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
723 |
+
else:
|
724 |
+
for i in range(0, masked_image.shape[0], bs):
|
725 |
+
new_mask_pixel_values_mini_batch = []
|
726 |
+
for j in range(0, masked_image.shape[2], mini_batch_encoder):
|
727 |
+
mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :]
|
728 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
729 |
+
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
730 |
+
new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
|
731 |
+
new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
|
732 |
+
new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
|
733 |
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
734 |
+
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
735 |
|
736 |
else:
|
737 |
if masked_image.shape[1] == 4:
|
|
|
776 |
callback_steps: int = 1,
|
777 |
clean_caption: bool = True,
|
778 |
mask_feature: bool = True,
|
779 |
+
max_sequence_length: int = 120,
|
780 |
+
clip_image: Image = None,
|
781 |
+
clip_apply_ratio: float = 0.50,
|
782 |
) -> Union[EasyAnimatePipelineOutput, Tuple]:
|
783 |
"""
|
784 |
Function invoked when calling the pipeline for generation.
|
|
|
852 |
# 1. Check inputs. Raise error if not correct
|
853 |
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
854 |
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
855 |
+
height = int(height // 16 * 16)
|
856 |
+
width = int(width // 16 * 16)
|
857 |
|
858 |
# 2. Default height and width to transformer
|
859 |
if prompt is not None and isinstance(prompt, str):
|
|
|
893 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
894 |
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
895 |
|
896 |
+
# 4. set timesteps
|
897 |
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
898 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
899 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
900 |
+
)
|
901 |
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
902 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
903 |
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
904 |
is_strength_max = strength == 1.0
|
905 |
|
|
|
914 |
# Prepare latent variables
|
915 |
num_channels_latents = self.vae.config.latent_channels
|
916 |
num_channels_transformer = self.transformer.config.in_channels
|
917 |
+
return_image_latents = True # num_channels_transformer == 4
|
918 |
|
919 |
# 5. Prepare latents.
|
920 |
latents_outputs = self.prepare_latents(
|
|
|
946 |
mask_condition = mask_condition.to(dtype=torch.float32)
|
947 |
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
948 |
|
949 |
+
if num_channels_transformer == 12:
|
950 |
+
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
|
951 |
+
if masked_video_latents is None:
|
952 |
+
masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
|
953 |
+
else:
|
954 |
+
masked_video = masked_video_latents
|
955 |
+
|
956 |
+
mask_latents, masked_video_latents = self.prepare_mask_latents(
|
957 |
+
mask_condition_tile,
|
958 |
+
masked_video,
|
959 |
+
batch_size,
|
960 |
+
height,
|
961 |
+
width,
|
962 |
+
prompt_embeds.dtype,
|
963 |
+
device,
|
964 |
+
generator,
|
965 |
+
do_classifier_free_guidance,
|
966 |
+
)
|
967 |
+
mask = torch.tile(mask_condition, [1, num_channels_transformer // 3, 1, 1, 1])
|
968 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
969 |
+
|
970 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
971 |
+
masked_video_latents_input = (
|
972 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
973 |
+
)
|
974 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
975 |
else:
|
976 |
+
mask = torch.tile(mask_condition, [1, num_channels_transformer, 1, 1, 1])
|
977 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
978 |
+
|
979 |
+
inpaint_latents = None
|
980 |
+
else:
|
981 |
+
if num_channels_transformer == 12:
|
982 |
+
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
983 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
984 |
+
|
985 |
+
mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
986 |
+
masked_video_latents_input = (
|
987 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
988 |
+
)
|
989 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
990 |
+
else:
|
991 |
+
mask = torch.zeros_like(init_video[:, :1])
|
992 |
+
mask = torch.tile(mask, [1, num_channels_transformer, 1, 1, 1])
|
993 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
994 |
+
|
995 |
+
inpaint_latents = None
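
Summarizing the branch above for the 12-channel transformer: the conditioning passed as inpaint_latents is the VAE-encoded tiled mask stacked with the VAE-encoded masked video, and the channel check further down requires that, together with the 4 noised latent channels, they add up to transformer.config.in_channels. A shape-only sketch with hypothetical sizes:

import torch

b, f, h, w = 1, 16, 48, 84                              # hypothetical latent batch/frames/height/width
latents              = torch.randn(b, 4, f, h, w)       # noised video latents
mask_latents         = torch.randn(b, 4, f, h, w)       # VAE-encoded mask (tiled to 3 channels before encoding)
masked_video_latents = torch.randn(b, 4, f, h, w)       # VAE-encoded masked input video

inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
assert latents.shape[1] + inpaint_latents.shape[1] == 12   # matches the in_channels check below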
|
996 |
+
|
997 |
+
if clip_image is not None:
|
998 |
+
inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
|
999 |
+
inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
|
1000 |
+
clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds
|
1001 |
+
clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
|
1002 |
+
|
1003 |
+
clip_attention_mask = torch.ones([batch_size, 8]).to(latents.device, dtype=latents.dtype)
|
1004 |
+
clip_attention_mask_neg = torch.zeros([batch_size, 8]).to(latents.device, dtype=latents.dtype)
|
1005 |
+
|
1006 |
+
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if do_classifier_free_guidance else clip_encoder_hidden_states
|
1007 |
+
clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if do_classifier_free_guidance else clip_attention_mask
|
1008 |
+
|
1009 |
+
elif clip_image is None and num_channels_transformer == 12:
|
1010 |
+
clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
|
1011 |
+
|
1012 |
+
clip_attention_mask = torch.zeros([batch_size, 8])
|
1013 |
+
clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
|
1014 |
+
|
1015 |
+
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if do_classifier_free_guidance else clip_encoder_hidden_states
|
1016 |
+
clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if do_classifier_free_guidance else clip_attention_mask
|
1017 |
+
|
1018 |
else:
|
1019 |
+
clip_encoder_hidden_states_input = None
|
1020 |
+
clip_attention_mask_input = None
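
The block above wires in optional image conditioning: a reference frame is turned into a projected CLIP image embedding with an all-ones attention mask, while an all-zeros embedding/mask pair stands for "no reference image" and doubles as the negative branch under classifier-free guidance. A standalone sketch of that encoding step; the checkpoint name is only an example, since the pipeline loads its encoder from the model's image_encoder subfolder:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

model_id  = "openai/clip-vit-large-patch14"              # example checkpoint with projection dim 768
processor = CLIPImageProcessor.from_pretrained(model_id)
encoder   = CLIPVisionModelWithProjection.from_pretrained(model_id)

image  = Image.new("RGB", (512, 512))                    # stand-in reference frame
inputs = processor(images=image, return_tensors="pt")
embeds = encoder(**inputs).image_embeds                  # (1, 768) projected CLIP image embedding
uncond = torch.zeros_like(embeds)                        # zero embedding = no reference image

clip_hidden_states = torch.cat([uncond, embeds], dim=0)  # negative + positive branch, as in the hunk above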
|
1021 |
|
1022 |
# Check that sizes of mask, masked image and latents match
|
1023 |
if num_channels_transformer == 12:
|
1024 |
# default case for runwayml/stable-diffusion-inpainting
|
1025 |
+
num_channels_mask = mask_latents.shape[1]
|
1026 |
num_channels_masked_image = masked_video_latents.shape[1]
|
1027 |
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
|
1028 |
raise ValueError(
|
|
|
1032 |
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
1033 |
" `pipeline.transformer` or your `mask_image` or `image` input."
|
1034 |
)
|
1035 |
+
elif num_channels_transformer != 4:
|
1036 |
raise ValueError(
|
1037 |
f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
|
1038 |
)
|
1039 |
|
1040 |
+
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1041 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1042 |
|
1043 |
# 6.1 Prepare micro-conditions.
|
|
|
1054 |
|
1055 |
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
1056 |
|
1057 |
+
gc.collect()
|
1058 |
+
torch.cuda.empty_cache()
|
1059 |
+
torch.cuda.ipc_collect()
|
1060 |
|
1061 |
+
# 10. Denoising loop
|
1062 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1063 |
+
self._num_timesteps = len(timesteps)
|
1064 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1065 |
for i, t in enumerate(timesteps):
|
1066 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
1067 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1068 |
|
1069 |
+
if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
|
1070 |
+
clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
|
1071 |
+
clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
|
1072 |
+
else:
|
1073 |
+
clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
|
1074 |
+
clip_attention_mask_actual_input = clip_attention_mask_input
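
The gate above applies the reference-image embedding only to the tail of the denoising schedule: for the first (1 - clip_apply_ratio) fraction of steps the CLIP inputs are replaced with zeros, presumably so the text prompt and mask settle the global layout before the reference image sharpens appearance. The condition reduces to:

def clip_applied_at_step(i, total_steps, clip_apply_ratio=0.5):
    # Mirrors the branch above: CLIP inputs are zeroed while i < total_steps * (1 - clip_apply_ratio).
    return i >= total_steps * (1 - clip_apply_ratio)

assert not clip_applied_at_step(10, 50)   # early step: CLIP inputs zeroed
assert clip_applied_at_step(40, 50)       # late step: reference-image embedding is used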
|
1075 |
+
|
1076 |
current_timestep = t
|
1077 |
if not torch.is_tensor(current_timestep):
|
1078 |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
|
|
1095 |
encoder_attention_mask=prompt_attention_mask,
|
1096 |
timestep=current_timestep,
|
1097 |
added_cond_kwargs=added_cond_kwargs,
|
1098 |
+
inpaint_latents=inpaint_latents,
|
1099 |
+
clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
|
1100 |
+
clip_attention_mask=clip_attention_mask_actual_input,
|
1101 |
return_dict=False,
|
1102 |
)[0]
|
1103 |
|
|
|
1112 |
# compute previous image: x_t -> x_t-1
|
1113 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1114 |
|
1115 |
+
if num_channels_transformer == 4:
|
1116 |
+
init_latents_proper = image_latents
|
1117 |
+
init_mask = mask
|
1118 |
+
if i < len(timesteps) - 1:
|
1119 |
+
noise_timestep = timesteps[i + 1]
|
1120 |
+
init_latents_proper = self.scheduler.add_noise(
|
1121 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
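
The block above handles checkpoints without mask channels (num_channels_transformer == 4): since the model cannot see the mask, known regions are enforced after every scheduler step by re-noising the reference latents to the next timestep and pasting them back outside the mask, in the spirit of RePaint-style latent compositing. The compositing itself is a lerp by the mask:

import torch

def keep_known_region(denoised, reference, mask):
    # mask == 1 where the model may repaint, 0 where the original content must be kept.
    return (1 - mask) * reference + mask * denoised

denoised  = torch.zeros(1, 4, 2, 4, 4)     # toy shapes
reference = torch.ones(1, 4, 2, 4, 4)      # reference latents, re-noised to the next timestep
mask      = torch.zeros(1, 4, 2, 4, 4)
mask[..., :2] = 1                          # left half of each frame is editable
out = keep_known_region(denoised, reference, mask)
assert out[..., :2].eq(0).all() and out[..., 2:].eq(1).all()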
|
1125 |
+
|
1126 |
# call the callback, if provided
|
1127 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1128 |
progress_bar.update()
|
|
|
1130 |
step_idx = i // getattr(self.scheduler, "order", 1)
|
1131 |
callback(step_idx, t, latents)
|
1132 |
|
1133 |
+
gc.collect()
|
1134 |
+
torch.cuda.empty_cache()
|
1135 |
+
torch.cuda.ipc_collect()
|
1136 |
+
|
1137 |
# Post-processing
|
1138 |
video = self.decode_latents(latents)
|
1139 |
+
|
1140 |
+
gc.collect()
|
1141 |
+
torch.cuda.empty_cache()
|
1142 |
+
torch.cuda.ipc_collect()
|
1143 |
# Convert to tensor
|
1144 |
if output_type == "latent":
|
1145 |
video = torch.from_numpy(video)
|
easyanimate/ui/ui.py
CHANGED
[Old-side column of the split diff for easyanimate/ui/ui.py; the removed-line text was truncated during extraction, and the updated file is rendered in the new-side column that follows. Hunks covered:]
@@ -1,35 +1,40 @@
@@ -60,8 +65,8 @@ class EasyAnimateController:
@@ -85,14 +90,14 @@ class EasyAnimateController:
@@ -100,19 +105,24 @@ class EasyAnimateController:
@@ -130,25 +140,42 @@ class EasyAnimateController:
@@ -160,16 +187,16 @@ class EasyAnimateController:
@@ -178,16 +205,16 @@ class EasyAnimateController:
@@ -200,15 +227,24 @@ class EasyAnimateController:
@@ -221,6 +257,39 @@ class EasyAnimateController:
@@ -235,16 +304,98 @@ class EasyAnimateController:
@@ -254,7 +405,11 @@ class EasyAnimateController:
@@ -296,7 +451,10 @@ class EasyAnimateController:
@@ -304,7 +462,10 @@ class EasyAnimateController:
@@ -325,24 +486,24 @@ def ui():
@@ -356,12 +517,12 @@ def ui():
@@ -371,78 +532,139 @@ def ui():
@@ -451,7 +673,6 @@ def ui():
@@ -469,11 +690,17 @@ def ui():
@@ -483,11 +710,18 @@ def ui():
@@ -513,32 +747,107 @@ class EasyAnimateController_Modelscope:
@@ -546,21 +855,52 @@ class EasyAnimateController_Modelscope:
@@ -578,11 +918,23 @@ class EasyAnimateController_Modelscope:
@@ -601,71 +953,197 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):
@@ -674,31 +1152,51 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):
@@ -710,23 +1208,42 @@ class EasyAnimateController_EAS:
@@ -768,35 +1285,134 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
@@ -819,24 +1435,45 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
|
|
1 |
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
2 |
"""
|
3 |
+
import base64
|
4 |
import gc
|
5 |
import json
|
6 |
import os
|
7 |
import random
|
|
|
|
|
|
|
8 |
from datetime import datetime
|
9 |
from glob import glob
|
10 |
|
11 |
import gradio as gr
|
|
|
12 |
import numpy as np
|
13 |
+
import pkg_resources
|
14 |
+
import requests
|
15 |
+
import torch
|
16 |
from diffusers import (AutoencoderKL, DDIMScheduler,
|
17 |
DPMSolverMultistepScheduler,
|
18 |
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
19 |
PNDMScheduler)
|
|
|
20 |
from diffusers.utils.import_utils import is_xformers_available
|
21 |
from omegaconf import OmegaConf
|
22 |
+
from PIL import Image
|
23 |
from safetensors import safe_open
|
24 |
+
from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
|
25 |
+
T5EncoderModel, T5Tokenizer)
|
26 |
|
27 |
+
from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
28 |
+
from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
|
29 |
from easyanimate.models.transformer3d import Transformer3DModel
|
30 |
from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
|
31 |
+
from easyanimate.pipeline.pipeline_easyanimate_inpaint import \
|
32 |
+
EasyAnimateInpaintPipeline
|
33 |
from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
|
34 |
+
from easyanimate.utils.utils import (
|
35 |
+
get_image_to_video_latent,
|
36 |
+
get_width_and_height_from_image_and_base_resolution, save_videos_grid)
|
37 |
|
|
|
38 |
scheduler_dict = {
|
39 |
"Euler": EulerDiscreteScheduler,
|
40 |
"Euler A": EulerAncestralDiscreteScheduler,
|
|
|
65 |
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
66 |
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
67 |
self.savedir_sample = os.path.join(self.savedir, "sample")
|
68 |
+
self.edition = "v3"
|
69 |
+
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
|
70 |
os.makedirs(self.savedir, exist_ok=True)
|
71 |
|
72 |
self.diffusion_transformer_list = []
|
|
|
90 |
self.weight_dtype = torch.bfloat16
|
91 |
|
92 |
def refresh_diffusion_transformer(self):
|
93 |
+
self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
|
94 |
|
95 |
def refresh_motion_module(self):
|
96 |
+
motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
|
97 |
self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
|
98 |
|
99 |
def refresh_personalized_model(self):
|
100 |
+
personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
|
101 |
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
|
102 |
|
103 |
def update_edition(self, edition):
|
|
|
105 |
self.edition = edition
|
106 |
if edition == "v1":
|
107 |
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
|
108 |
+
return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
|
109 |
+
gr.update(value=512, minimum=384, maximum=704, step=32), \
|
110 |
gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
|
111 |
+
elif edition == "v2":
|
112 |
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
|
113 |
+
return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
|
114 |
+
gr.update(value=672, minimum=128, maximum=1280, step=16), \
|
115 |
gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
|
116 |
+
else:
|
117 |
+
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
|
118 |
+
return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
|
119 |
+
gr.update(value=672, minimum=128, maximum=1280, step=16), \
|
120 |
+
gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
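
The edition handler above (and the other dropdown callbacks in this file) now returns plain gr.update(...) objects, one per wired output component, presumably to stay compatible with newer Gradio releases where the component-specific .update() helpers are deprecated; the tuple order has to match the outputs= list of the event binding. A condensed sketch of the pattern:

import gradio as gr

def on_edition_change(edition):
    # One gr.update(...) per output component, in the same order as `outputs=[...]`.
    if edition == "v1":
        return gr.update(value=512, minimum=384, maximum=704, step=32), \
               gr.update(value=80, minimum=40, maximum=80, step=1)
    return gr.update(value=384, minimum=128, maximum=1280, step=16), \
           gr.update(value=144, minimum=8, maximum=144, step=8)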
|
121 |
|
122 |
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
123 |
print("Update diffusion transformer")
|
124 |
if diffusion_transformer_dropdown == "none":
|
125 |
+
return gr.update()
|
126 |
if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
|
127 |
Choosen_AutoencoderKL = AutoencoderKLMagvit
|
128 |
else:
|
|
|
140 |
self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
|
141 |
|
142 |
# Get pipeline
|
143 |
+
if self.transformer.config.in_channels != 12:
|
144 |
+
self.pipeline = EasyAnimatePipeline(
|
145 |
+
vae=self.vae,
|
146 |
+
text_encoder=self.text_encoder,
|
147 |
+
tokenizer=self.tokenizer,
|
148 |
+
transformer=self.transformer,
|
149 |
+
scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
153 |
+
diffusion_transformer_dropdown, subfolder="image_encoder"
|
154 |
+
).to("cuda", self.weight_dtype)
|
155 |
+
clip_image_processor = CLIPImageProcessor.from_pretrained(
|
156 |
+
diffusion_transformer_dropdown, subfolder="image_encoder"
|
157 |
+
)
|
158 |
+
self.pipeline = EasyAnimateInpaintPipeline(
|
159 |
+
vae=self.vae,
|
160 |
+
text_encoder=self.text_encoder,
|
161 |
+
tokenizer=self.tokenizer,
|
162 |
+
transformer=self.transformer,
|
163 |
+
scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
|
164 |
+
clip_image_encoder=clip_image_encoder,
|
165 |
+
clip_image_processor=clip_image_processor,
|
166 |
+
)
|
167 |
+
|
168 |
print("Update diffusion transformer done")
|
169 |
+
return gr.update()
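
The controller above chooses the pipeline class from the checkpoint itself: a transformer with 12 input channels is the v3 inpaint / image-to-video variant and additionally needs the CLIP image encoder and processor from the model's image_encoder subfolder; otherwise the plain text-to-video pipeline is built. A condensed sketch of that decision, where load_clip stands in for the two from_pretrained calls shown above and common_kwargs carries vae, text_encoder, tokenizer and scheduler:

from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
from easyanimate.pipeline.pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline

def build_pipeline(transformer, load_clip, **common_kwargs):
    if transformer.config.in_channels == 12:          # inpaint / image-to-video checkpoint
        clip_image_encoder, clip_image_processor = load_clip()
        return EasyAnimateInpaintPipeline(
            transformer=transformer,
            clip_image_encoder=clip_image_encoder,
            clip_image_processor=clip_image_processor,
            **common_kwargs,
        )
    return EasyAnimatePipeline(transformer=transformer, **common_kwargs)   # text-to-video only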
|
170 |
|
171 |
def update_motion_module(self, motion_module_dropdown):
|
172 |
self.motion_module_path = motion_module_dropdown
|
173 |
print("Update motion module")
|
174 |
if motion_module_dropdown == "none":
|
175 |
+
return gr.update()
|
176 |
if self.transformer is None:
|
177 |
gr.Info(f"Please select a pretrained model path.")
|
178 |
+
return gr.update(value=None)
|
179 |
else:
|
180 |
motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
|
181 |
if motion_module_dropdown.endswith(".safetensors"):
|
|
|
187 |
motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
|
188 |
missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
|
189 |
print("Update motion module done.")
|
190 |
+
return gr.update()
|
191 |
|
192 |
def update_base_model(self, base_model_dropdown):
|
193 |
self.base_model_path = base_model_dropdown
|
194 |
print("Update base model")
|
195 |
if base_model_dropdown == "none":
|
196 |
+
return gr.update()
|
197 |
if self.transformer is None:
|
198 |
gr.Info(f"Please select a pretrained model path.")
|
199 |
+
return gr.update(value=None)
|
200 |
else:
|
201 |
base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
|
202 |
base_model_state_dict = {}
|
|
|
205 |
base_model_state_dict[key] = f.get_tensor(key)
|
206 |
self.transformer.load_state_dict(base_model_state_dict, strict=False)
|
207 |
print("Update base done")
|
208 |
+
return gr.update()
|
209 |
|
210 |
def update_lora_model(self, lora_model_dropdown):
|
211 |
print("Update lora model")
|
212 |
if lora_model_dropdown == "none":
|
213 |
self.lora_model_path = "none"
|
214 |
+
return gr.update()
|
215 |
lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
|
216 |
self.lora_model_path = lora_model_dropdown
|
217 |
+
return gr.update()
|
218 |
|
219 |
def generate(
|
220 |
self,
|
|
|
227 |
negative_prompt_textbox,
|
228 |
sampler_dropdown,
|
229 |
sample_step_slider,
|
230 |
+
resize_method,
|
231 |
width_slider,
|
232 |
height_slider,
|
233 |
+
base_resolution,
|
234 |
+
generation_method,
|
235 |
length_slider,
|
236 |
+
overlap_video_length,
|
237 |
+
partial_video_length,
|
238 |
cfg_scale_slider,
|
239 |
+
start_image,
|
240 |
+
end_image,
|
241 |
seed_textbox,
|
242 |
is_api = False,
|
243 |
):
|
244 |
+
gc.collect()
|
245 |
+
torch.cuda.empty_cache()
|
246 |
+
torch.cuda.ipc_collect()
|
247 |
+
|
248 |
if self.transformer is None:
|
249 |
raise gr.Error(f"Please select a pretrained model path.")
|
250 |
|
|
|
257 |
if self.lora_model_path != lora_model_dropdown:
|
258 |
print("Update lora model")
|
259 |
self.update_lora_model(lora_model_dropdown)
|
260 |
+
|
261 |
+
if resize_method == "Resize to the Start Image":
|
262 |
+
if start_image is None:
|
263 |
+
if is_api:
|
264 |
+
return "", f"Please upload an image when using \"Resize to the Start Image\"."
|
265 |
+
else:
|
266 |
+
raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")
|
267 |
+
|
268 |
+
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
269 |
+
|
270 |
+
original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
|
271 |
+
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
272 |
+
height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
|
273 |
+
|
274 |
+
if self.transformer.config.in_channels != 12 and start_image is not None:
|
275 |
+
if is_api:
|
276 |
+
return "", f"Please select an image to video pretrained model while using image to video."
|
277 |
+
else:
|
278 |
+
raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
|
279 |
+
|
280 |
+
if self.transformer.config.in_channels != 12 and generation_method == "Long Video Generation":
|
281 |
+
if is_api:
|
282 |
+
return "", f"Please select an image to video pretrained model while using long video generation."
|
283 |
+
else:
|
284 |
+
raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
|
285 |
+
|
286 |
+
if start_image is None and end_image is not None:
|
287 |
+
if is_api:
|
288 |
+
return "", f"If specifying the ending image of the video, please specify a starting image of the video."
|
289 |
+
else:
|
290 |
+
raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
|
291 |
+
|
292 |
+
is_image = True if generation_method == "Image Generation" else False
|
293 |
|
294 |
if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
|
295 |
|
|
|
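The "Resize to the Start Image" branch above scales the ASPECT_RATIO_512 bucket table to the chosen base resolution, picks the bucket whose aspect ratio is closest to the start image, and snaps the result down to a multiple of 16. A minimal sketch of that selection logic, with an abbreviated, illustrative bucket table and a hypothetical `get_closest_ratio_sketch` standing in for the repository's `get_closest_ratio` helper:

```python
# Hedged sketch of the aspect-ratio bucketing used above; the bucket table is abbreviated.
ASPECT_RATIO_512_EXAMPLE = {
    "0.5":  [352.0, 704.0],
    "0.75": [448.0, 576.0],
    "1.0":  [512.0, 512.0],
    "1.33": [576.0, 448.0],
    "2.0":  [704.0, 352.0],
}

def get_closest_ratio_sketch(height, width, ratios):
    # Pick the bucket whose h/w ratio is closest to the input image's ratio.
    aspect_ratio = height / width
    closest_key = min(ratios, key=lambda k: abs(float(k) - aspect_ratio))
    return ratios[closest_key], float(closest_key)

base_resolution = 768
# Rescale every 512-based bucket to the selected base resolution, as in the UI code.
aspect_ratio_sample_size = {k: [x / 512 * base_resolution for x in v]
                            for k, v in ASPECT_RATIO_512_EXAMPLE.items()}

closest_size, closest_ratio = get_closest_ratio_sketch(720, 1280, aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
print(height, width)  # 528 1056 with this toy table (720/1280 is closest to the 0.5 bucket)
```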
304           generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
305
306           try:
307 +             if self.transformer.config.in_channels == 12:
308 +                 if generation_method == "Long Video Generation":
309 +                     init_frames = 0
310 +                     last_frames = init_frames + partial_video_length
311 +                     while init_frames < length_slider:
312 +                         if last_frames >= length_slider:
313 +                             if self.pipeline.vae.quant_conv.weight.ndim==5:
314 +                                 mini_batch_encoder = self.pipeline.vae.mini_batch_encoder
315 +                                 _partial_video_length = length_slider - init_frames
316 +                                 _partial_video_length = int(_partial_video_length // mini_batch_encoder * mini_batch_encoder)
317 +                             else:
318 +                                 _partial_video_length = length_slider - init_frames
319 +
320 +                             if _partial_video_length <= 0:
321 +                                 break
322 +                         else:
323 +                             _partial_video_length = partial_video_length
324 +
325 +                         if last_frames >= length_slider:
326 +                             input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
327 +                         else:
328 +                             input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
329 +
330 +                         with torch.no_grad():
331 +                             sample = self.pipeline(
332 +                                 prompt_textbox,
333 +                                 negative_prompt = negative_prompt_textbox,
334 +                                 num_inference_steps = sample_step_slider,
335 +                                 guidance_scale = cfg_scale_slider,
336 +                                 width = width_slider,
337 +                                 height = height_slider,
338 +                                 video_length = _partial_video_length,
339 +                                 generator = generator,
340 +
341 +                                 video = input_video,
342 +                                 mask_video = input_video_mask,
343 +                                 clip_image = clip_image,
344 +                                 strength = 1,
345 +                             ).videos
346 +
347 +                         if init_frames != 0:
348 +                             mix_ratio = torch.from_numpy(
349 +                                 np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
350 +                             ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
351 +
352 +                             new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
353 +                                 sample[:, :, :overlap_video_length] * mix_ratio
354 +                             new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
355 +
356 +                             sample = new_sample
357 +                         else:
358 +                             new_sample = sample
359 +
360 +                         if last_frames >= length_slider:
361 +                             break
362 +
363 +                         start_image = [
364 +                             Image.fromarray(
365 +                                 (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
366 +                             ) for _index in range(-overlap_video_length, 0)
367 +                         ]
368 +
369 +                         init_frames = init_frames + _partial_video_length - overlap_video_length
370 +                         last_frames = init_frames + _partial_video_length
371 +                 else:
372 +                     input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
373 +
374 +                     sample = self.pipeline(
375 +                         prompt_textbox,
376 +                         negative_prompt = negative_prompt_textbox,
377 +                         num_inference_steps = sample_step_slider,
378 +                         guidance_scale = cfg_scale_slider,
379 +                         width = width_slider,
380 +                         height = height_slider,
381 +                         video_length = length_slider if not is_image else 1,
382 +                         generator = generator,
383 +
384 +                         video = input_video,
385 +                         mask_video = input_video_mask,
386 +                         clip_image = clip_image,
387 +                     ).videos
388 +             else:
389 +                 sample = self.pipeline(
390 +                     prompt_textbox,
391 +                     negative_prompt = negative_prompt_textbox,
392 +                     num_inference_steps = sample_step_slider,
393 +                     guidance_scale = cfg_scale_slider,
394 +                     width = width_slider,
395 +                     height = height_slider,
396 +                     video_length = length_slider if not is_image else 1,
397 +                     generator = generator
398 +                 ).videos
399           except Exception as e:
400               gc.collect()
401               torch.cuda.empty_cache()

405               if is_api:
406                   return "", f"Error. error information is {str(e)}"
407               else:
408 +                 return gr.update(), gr.update(), f"Error. error information is {str(e)}"
409 +
410 +         gc.collect()
411 +         torch.cuda.empty_cache()
412 +         torch.cuda.ipc_collect()
413
414           # lora part
415           if self.lora_model_path != "none":

451               if is_api:
452                   return save_sample_path, "Success"
453               else:
454 +                 if gradio_version_is_above_4:
455 +                     return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
456 +                 else:
457 +                     return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
458           else:
459               save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
460               save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)

462               if is_api:
463                   return save_sample_path, "Success"
464               else:
465 +                 if gradio_version_is_above_4:
466 +                     return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
467 +                 else:
468 +                     return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
469
470
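For "Long Video Generation", the generate() method above renders the clip in windows of partial_video_length frames, cross-fades each new window into the previous result over overlap_video_length frames, and reuses the last overlapping frames as the start images of the next window. A self-contained sketch of that linear cross-fade on dummy tensors (names mirror the loop above; this is not the full pipeline call):

```python
import numpy as np
import torch

# Hedged sketch of the overlap blending used in the long-video loop above.
overlap_video_length = 4
# Pretend these came from two consecutive pipeline calls: (B, C, T, H, W) in [0, 1].
new_sample = torch.rand(1, 3, 72, 64, 64)   # frames accumulated so far
sample     = torch.rand(1, 3, 72, 64, 64)   # freshly generated window

# Linear ramp 0/N, 1/N, ..., (N-1)/N, broadcastable over (B, C, T, H, W).
mix_ratio = torch.from_numpy(
    np.array([i / overlap_video_length for i in range(overlap_video_length)], np.float32)
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

# Fade the tail of the accumulated clip into the head of the new window...
new_sample[:, :, -overlap_video_length:] = (
    new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio)
    + sample[:, :, :overlap_video_length] * mix_ratio
)
# ...then append the non-overlapping part of the new window.
new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim=2)
print(new_sample.shape)  # torch.Size([1, 3, 140, 64, 64]) -> 72 + (72 - 4) frames
```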
471   def ui():

486       with gr.Column(variant="panel"):
487           gr.Markdown(
488               """
489 +             ### 1. EasyAnimate Edition (EasyAnimate版本).
490               """
491           )
492           with gr.Row():
493               easyanimate_edition_dropdown = gr.Dropdown(
494 +                 label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
495 +                 choices=["v1", "v2", "v3"],
496 +                 value="v3",
497                   interactive=True,
498               )
499           gr.Markdown(
500               """
501 +             ### 2. Model checkpoints (模型路径).
502               """
503           )
504           with gr.Row():
505               diffusion_transformer_dropdown = gr.Dropdown(
506 +                 label="Pretrained Model Path (预训练模型路径)",
507                   choices=controller.diffusion_transformer_list,
508                   value="none",
509                   interactive=True,

517               diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
518               def refresh_diffusion_transformer():
519                   controller.refresh_diffusion_transformer()
520 +                 return gr.update(choices=controller.diffusion_transformer_list)
521               diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
522
523           with gr.Row():
524               motion_module_dropdown = gr.Dropdown(
525 +                 label="Select motion module (选择运动模块[非必需])",
526                   choices=controller.motion_module_list,
527                   value="none",
528                   interactive=True,

532               motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
533               def update_motion_module():
534                   controller.refresh_motion_module()
535 +                 return gr.update(choices=controller.motion_module_list)
536               motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
537
538               base_model_dropdown = gr.Dropdown(
539 +                 label="Select base Dreambooth model (选择基模型[非必需])",
540                   choices=controller.personalized_model_list,
541                   value="none",
542                   interactive=True,
543               )
544
545               lora_model_dropdown = gr.Dropdown(
546 +                 label="Select LoRA model (选择LoRA模型[非必需])",
547                   choices=["none"] + controller.personalized_model_list,
548                   value="none",
549                   interactive=True,
550               )
551
552 +             lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
553
554               personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
555               def update_personalized_model():
556                   controller.refresh_personalized_model()
557                   return [
558 +                     gr.update(choices=controller.personalized_model_list),
559 +                     gr.update(choices=["none"] + controller.personalized_model_list)
560                   ]
561               personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
562
563       with gr.Column(variant="panel"):
564           gr.Markdown(
565               """
566 +             ### 3. Configs for Generation (生成参数配置).
567               """
568           )
569
570 +         prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
571 +         negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." )
572
573           with gr.Row():
574               with gr.Column():
575                   with gr.Row():
576 +                     sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
577 +                     sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=30, minimum=10, maximum=100, step=1)
578
579 +                 resize_method = gr.Radio(
580 +                     ["Generate by", "Resize to the Start Image"],
581 +                     value="Generate by",
582 +                     show_label=False,
583 +                 )
584 +                 width_slider     = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16)
585 +                 height_slider    = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16)
586 +                 base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
587 +
588 +                 with gr.Group():
589 +                     generation_method = gr.Radio(
590 +                         ["Video Generation", "Image Generation", "Long Video Generation"],
591 +                         value="Video Generation",
592 +                         show_label=False,
593 +                     )
594 +                     with gr.Row():
595 +                         length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8)
596 +                         overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
597 +                         partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False)
598 +
599 +                 with gr.Accordion("Image to Video (图片到视频)", open=False):
600 +                     start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
601 +
602 +                     template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
603 +                     def select_template(evt: gr.SelectData):
604 +                         text = {
605 +                             "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
606 +                             "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
607 +                             "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
608 +                             "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
609 +                             "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
610 +                         }[template_gallery_path[evt.index]]
611 +                         return template_gallery_path[evt.index], text
612 +
613 +                     template_gallery = gr.Gallery(
614 +                         template_gallery_path,
615 +                         columns=5, rows=1,
616 +                         height=140,
617 +                         allow_preview=False,
618 +                         container=False,
619 +                         label="Template Examples",
620 +                     )
621 +                     template_gallery.select(select_template, None, [start_image, prompt_textbox])
622 +
623 +                     with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
624 +                         end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
625 +
626 +                 cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
627
628                   with gr.Row():
629 +                     seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
630                       seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
631 +                     seed_button.click(
632 +                         fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
633 +                         inputs=[],
634 +                         outputs=[seed_textbox]
635 +                     )
636
637 +                 generate_button = gr.Button(value="Generate (生成)", variant='primary')
638
639               with gr.Column():
640 +                 result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
641 +                 result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
642                   infer_progress = gr.Textbox(
643 +                     label="Generation Info (生成信息)",
644                       value="No task currently",
645                       interactive=False
646                   )
647
648 +         def upload_generation_method(generation_method):
649 +             if generation_method == "Video Generation":
650 +                 return [gr.update(visible=True, maximum=144, value=144), gr.update(visible=False), gr.update(visible=False)]
651 +             elif generation_method == "Image Generation":
652 +                 return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
653 +             else:
654 +                 return [gr.update(visible=True, maximum=1440), gr.update(visible=True), gr.update(visible=True)]
655 +         generation_method.change(
656 +             upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
657 +         )
658 +
659 +         def upload_resize_method(resize_method):
660 +             if resize_method == "Generate by":
661 +                 return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
662 +             else:
663 +                 return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
664 +         resize_method.change(
665 +             upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
666           )
667 +
668           easyanimate_edition_dropdown.change(
669               fn=controller.update_edition,
670               inputs=[easyanimate_edition_dropdown],

673                   diffusion_transformer_dropdown,
674                   motion_module_dropdown,
675                   motion_module_refresh_button,

676                   width_slider,
677                   height_slider,
678                   length_slider,

690                   negative_prompt_textbox,
691                   sampler_dropdown,
692                   sample_step_slider,
693 +                 resize_method,
694                   width_slider,
695                   height_slider,
696 +                 base_resolution,
697 +                 generation_method,
698                   length_slider,
699 +                 overlap_video_length,
700 +                 partial_video_length,
701                   cfg_scale_slider,
702 +                 start_image,
703 +                 end_image,
704                   seed_textbox,
705               ],
706               outputs=[result_image, result_video, infer_progress]
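Several callbacks above branch on gradio_version_is_above_4 because Gradio 4 removed the per-component `Component.update()` helpers in favour of returning a component instance (or `gr.update`). A small hedged helper that captures the same pattern, assuming a boolean flag computed the way a version check usually is (the flag name mirrors the one used in this file):

```python
import random
import gradio as gr

# Assumption: the flag is derived from the installed Gradio major version.
gradio_version_is_above_4 = int(gr.__version__.split(".")[0]) >= 4

def new_seed_value():
    # Gradio 4.x: return a fresh component (or gr.update); 3.x: use Component.update().
    value = random.randint(1, 10**8)
    if gradio_version_is_above_4:
        return gr.Textbox(value=value)
    return gr.Textbox.update(value=value)
```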
710
711   class EasyAnimateController_Modelscope:
712       def __init__(self, edition, config_path, model_name, savedir_sample):
713 +         # Weight Dtype
714 +         weight_dtype = torch.bfloat16
715 +
716 +         # Basic dir
717 +         self.basedir                = os.getcwd()
718 +         self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
719 +         self.lora_model_path        = "none"
720 +         self.savedir_sample         = savedir_sample
721 +         self.refresh_personalized_model()
722           os.makedirs(self.savedir_sample, exist_ok=True)
723
724 +         # Config and model path
725           self.edition = edition
726           self.inference_config = OmegaConf.load(config_path)
727           # Get Transformer

747               subfolder="text_encoder",
748               torch_dtype=weight_dtype
749           )
750 +         # Get pipeline
751 +         if self.transformer.config.in_channels != 12:
752 +             self.pipeline = EasyAnimatePipeline(
753 +                 vae=self.vae,
754 +                 text_encoder=self.text_encoder,
755 +                 tokenizer=self.tokenizer,
756 +                 transformer=self.transformer,
757 +                 scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
758 +             )
759 +         else:
760 +             clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
761 +                 model_name, subfolder="image_encoder"
762 +             ).to("cuda", weight_dtype)
763 +             clip_image_processor = CLIPImageProcessor.from_pretrained(
764 +                 model_name, subfolder="image_encoder"
765 +             )
766 +             self.pipeline = EasyAnimateInpaintPipeline(
767 +                 vae=self.vae,
768 +                 text_encoder=self.text_encoder,
769 +                 tokenizer=self.tokenizer,
770 +                 transformer=self.transformer,
771 +                 scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
772 +                 clip_image_encoder=clip_image_encoder,
773 +                 clip_image_processor=clip_image_processor,
774 +             )
775 +
776           print("Update diffusion transformer done")
777
778 +
779 +     def refresh_personalized_model(self):
780 +         personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
781 +         self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
782 +
783 +
784 +     def update_lora_model(self, lora_model_dropdown):
785 +         print("Update lora model")
786 +         if lora_model_dropdown == "none":
787 +             self.lora_model_path = "none"
788 +             return gr.update()
789 +         lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
790 +         self.lora_model_path = lora_model_dropdown
791 +         return gr.update()
792 +
793 +
794       def generate(
795           self,
796 +         diffusion_transformer_dropdown,
797 +         motion_module_dropdown,
798 +         base_model_dropdown,
799 +         lora_model_dropdown,
800 +         lora_alpha_slider,
801           prompt_textbox,
802           negative_prompt_textbox,
803           sampler_dropdown,
804           sample_step_slider,
805 +         resize_method,
806           width_slider,
807           height_slider,
808 +         base_resolution,
809 +         generation_method,
810           length_slider,
811           cfg_scale_slider,
812 +         start_image,
813 +         end_image,
814 +         seed_textbox,
815 +         is_api = False,
816       ):
817 +         gc.collect()
818 +         torch.cuda.empty_cache()
819 +         torch.cuda.ipc_collect()
820 +
821 +         if self.transformer is None:
822 +             raise gr.Error(f"Please select a pretrained model path.")
823 +
824 +         if self.lora_model_path != lora_model_dropdown:
825 +             print("Update lora model")
826 +             self.update_lora_model(lora_model_dropdown)
827 +
828 +         if resize_method == "Resize to the Start Image":
829 +             if start_image is None:
830 +                 raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")
831 +
832 +             aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
833 +             original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
834 +             closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
835 +             height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
836 +
837 +         if self.transformer.config.in_channels != 12 and start_image is not None:
838 +             raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
839 +
840 +         if start_image is None and end_image is not None:
841 +             raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
842 +
843 +         is_image = True if generation_method == "Image Generation" else False
844 +
845           if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
846
847           self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
848 +         if self.lora_model_path != "none":
849 +             # lora part
850 +             self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
851           self.pipeline.to("cuda")
852
853           if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))

855           generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
856
857           try:
858 +             if self.transformer.config.in_channels == 12:
859 +                 input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
860 +
861 +                 sample = self.pipeline(
862 +                     prompt_textbox,
863 +                     negative_prompt = negative_prompt_textbox,
864 +                     num_inference_steps = sample_step_slider,
865 +                     guidance_scale = cfg_scale_slider,
866 +                     width = width_slider,
867 +                     height = height_slider,
868 +                     video_length = length_slider if not is_image else 1,
869 +                     generator = generator,
870 +
871 +                     video = input_video,
872 +                     mask_video = input_video_mask,
873 +                     clip_image = clip_image,
874 +                 ).videos
875 +             else:
876 +                 sample = self.pipeline(
877 +                     prompt_textbox,
878 +                     negative_prompt = negative_prompt_textbox,
879 +                     num_inference_steps = sample_step_slider,
880 +                     guidance_scale = cfg_scale_slider,
881 +                     width = width_slider,
882 +                     height = height_slider,
883 +                     video_length = length_slider if not is_image else 1,
884 +                     generator = generator
885 +                 ).videos
886           except Exception as e:
887               gc.collect()
888               torch.cuda.empty_cache()
889               torch.cuda.ipc_collect()
890 +             if self.lora_model_path != "none":
891 +                 self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
892 +             if is_api:
893 +                 return "", f"Error. error information is {str(e)}"
894 +             else:
895 +                 return gr.update(), gr.update(), f"Error. error information is {str(e)}"
896 +
897 +         gc.collect()
898 +         torch.cuda.empty_cache()
899 +         torch.cuda.ipc_collect()
900 +
901 +         # lora part
902 +         if self.lora_model_path != "none":
903 +             self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
904
905           if not os.path.exists(self.savedir_sample):
906               os.makedirs(self.savedir_sample, exist_ok=True)

918               image = (image * 255).numpy().astype(np.uint8)
919               image = Image.fromarray(image)
920               image.save(save_sample_path)
921 +             if is_api:
922 +                 return save_sample_path, "Success"
923 +             else:
924 +                 if gradio_version_is_above_4:
925 +                     return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
926 +                 else:
927 +                     return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
928           else:
929               save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
930               save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
931 +             if is_api:
932 +                 return save_sample_path, "Success"
933 +             else:
934 +                 if gradio_version_is_above_4:
935 +                     return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
936 +                 else:
937 +                     return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
938
939
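The Modelscope controller above merges LoRA weights into the pipeline before sampling and unmerges them both on the error path and after a successful run. One way to keep those two unmerge sites from drifting apart is a try/finally wrapper around the sampling call; a hedged sketch using the merge_lora/unmerge_lora helpers this file already relies on (the import path and the self.pipeline reassignment done by the original code are assumptions of this sketch, not a drop-in replacement):

```python
from contextlib import contextmanager

# Assumed import path for the helpers used in ui.py.
from easyanimate.utils.lora_utils import merge_lora, unmerge_lora

@contextmanager
def lora_applied(pipeline, lora_model_path, multiplier):
    # Merge LoRA weights in, and always unmerge them afterwards,
    # whether sampling succeeds or raises.
    if lora_model_path != "none":
        pipeline = merge_lora(pipeline, lora_model_path, multiplier=multiplier)
    try:
        yield pipeline
    finally:
        if lora_model_path != "none":
            unmerge_lora(pipeline, lora_model_path, multiplier=multiplier)

# Usage (sketch):
# with lora_applied(self.pipeline, self.lora_model_path, lora_alpha_slider) as pipe:
#     sample = pipe(prompt_textbox, ...).videos
```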
940   def ui_modelscope(edition, config_path, model_name, savedir_sample):

953               """
954           )
955           with gr.Column(variant="panel"):
956 +             gr.Markdown(
957 +                 """
958 +                 ### 1. Model checkpoints (模型路径).
959 +                 """
960 +             )
961 +             with gr.Row():
962 +                 diffusion_transformer_dropdown = gr.Dropdown(
963 +                     label="Pretrained Model Path (预训练模型路径)",
964 +                     choices=[model_name],
965 +                     value=model_name,
966 +                     interactive=False,
967 +                 )
968 +             with gr.Row():
969 +                 motion_module_dropdown = gr.Dropdown(
970 +                     label="Select motion module (选择运动模块[非必需])",
971 +                     choices=["none"],
972 +                     value="none",
973 +                     interactive=False,
974 +                     visible=False
975 +                 )
976 +                 base_model_dropdown = gr.Dropdown(
977 +                     label="Select base Dreambooth model (选择基模型[非必需])",
978 +                     choices=["none"],
979 +                     value="none",
980 +                     interactive=False,
981 +                     visible=False
982 +                 )
983 +             with gr.Column(visible=False):
984 +                 gr.Markdown(
985 +                     """
986 +                     ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
987 +                     """
988 +                 )
989 +                 with gr.Row():
990 +                     lora_model_dropdown = gr.Dropdown(
991 +                         label="Select LoRA model",
992 +                         choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
993 +                         value="none",
994 +                         interactive=True,
995 +                     )
996 +
997 +                     lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
998 +
999 +         with gr.Column(variant="panel"):
1000 +            gr.Markdown(
1001 +                """
1002 +                ### 2. Configs for Generation (生成参数配置).
1003 +                """
1004 +            )
1005 +
1006 +            prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1007 +            negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." )
1008
1009              with gr.Row():
1010                  with gr.Column():
1011                      with gr.Row():
1012 +                        sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1013 +                        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=20, minimum=10, maximum=30, step=1, interactive=False)
1014
1015                      if edition == "v1":
1016 +                        width_slider     = gr.Slider(label="Width (视频宽度)", value=512, minimum=384, maximum=704, step=32)
1017 +                        height_slider    = gr.Slider(label="Height (视频高度)", value=512, minimum=384, maximum=704, step=32)
1018 +
1019 +                        with gr.Group():
1020 +                            generation_method = gr.Radio(
1021 +                                ["Video Generation", "Image Generation"],
1022 +                                value="Video Generation",
1023 +                                show_label=False,
1024 +                                visible=False,
1025 +                            )
1026 +                        length_slider = gr.Slider(label="Animation length (视频帧数)", value=80, minimum=40, maximum=96, step=1)
1027 +                        cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1028                      else:
1029 +                        resize_method = gr.Radio(
1030 +                            ["Generate by", "Resize to the Start Image"],
1031 +                            value="Generate by",
1032 +                            show_label=False,
1033 +                        )
1034                          with gr.Column():
1035                              gr.Markdown(
1036                                  """
1037 +                                We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s).
1038 +
1039 +                                If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above.
1040 +
1041 +                                If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
1042                                  """
1043                              )
1044 +                        width_slider     = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1045 +                        height_slider    = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1046 +                        base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1047 +
1048 +                        with gr.Group():
1049 +                            generation_method = gr.Radio(
1050 +                                ["Video Generation", "Image Generation"],
1051 +                                value="Video Generation",
1052 +                                show_label=False,
1053 +                                visible=True,
1054 +                            )
1055 +                        length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8)
1056 +
1057 +                        with gr.Accordion("Image to Video (图片到视频)", open=True):
1058                              with gr.Row():
1059 +                                start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1060 +
1061 +                            template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1062 +                            def select_template(evt: gr.SelectData):
1063 +                                text = {
1064 +                                    "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1065 +                                    "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1066 +                                    "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1067 +                                    "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1068 +                                    "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1069 +                                }[template_gallery_path[evt.index]]
1070 +                                return template_gallery_path[evt.index], text
1071 +
1072 +                            template_gallery = gr.Gallery(
1073 +                                template_gallery_path,
1074 +                                columns=5, rows=1,
1075 +                                height=140,
1076 +                                allow_preview=False,
1077 +                                container=False,
1078 +                                label="Template Examples",
1079 +                            )
1080 +                            template_gallery.select(select_template, None, [start_image, prompt_textbox])
1081 +
1082 +                            with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
1083 +                                end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
1084 +
1085 +
1086 +                        cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
1087
1088                      with gr.Row():
1089 +                        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
1090                          seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1091 +                        seed_button.click(
1092 +                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
1093 +                            inputs=[],
1094 +                            outputs=[seed_textbox]
1095 +                        )
1096
1097 +                    generate_button = gr.Button(value="Generate (生成)", variant='primary')
1098
1099                  with gr.Column():
1100 +                    result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
1101 +                    result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
1102                      infer_progress = gr.Textbox(
1103 +                        label="Generation Info (生成信息)",
1104                          value="No task currently",
1105                          interactive=False
1106                      )
1107
1108 +            def upload_generation_method(generation_method):
1109 +                if generation_method == "Video Generation":
1110 +                    return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
1111 +                elif generation_method == "Image Generation":
1112 +                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1113 +            generation_method.change(
1114 +                upload_generation_method, generation_method, [length_slider]
1115 +            )
1116 +
1117 +            def upload_resize_method(resize_method):
1118 +                if resize_method == "Generate by":
1119 +                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1120 +                else:
1121 +                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1122 +            resize_method.change(
1123 +                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1124              )
1125
1126              generate_button.click(
1127                  fn=controller.generate,
1128                  inputs=[
1129 +                    diffusion_transformer_dropdown,
1130 +                    motion_module_dropdown,
1131 +                    base_model_dropdown,
1132 +                    lora_model_dropdown,
1133 +                    lora_alpha_slider,
1134                      prompt_textbox,
1135                      negative_prompt_textbox,
1136                      sampler_dropdown,
1137                      sample_step_slider,
1138 +                    resize_method,
1139                      width_slider,
1140                      height_slider,
1141 +                    base_resolution,
1142 +                    generation_method,
1143                      length_slider,
1144                      cfg_scale_slider,
1145 +                    start_image,
1146 +                    end_image,
1147                      seed_textbox,
1148                  ],
1149                  outputs=[result_image, result_video, infer_progress]
1152
1153
1154  def post_eas(
1155 +    diffusion_transformer_dropdown, motion_module_dropdown,
1156 +    base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1157      prompt_textbox, negative_prompt_textbox,
1158 +    sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1159 +    base_resolution, generation_method, length_slider, cfg_scale_slider,
1160 +    start_image, end_image, seed_textbox,
1161  ):
1162 +    if start_image is not None:
1163 +        with open(start_image, 'rb') as file:
1164 +            file_content = file.read()
1165 +            start_image_encoded_content = base64.b64encode(file_content)
1166 +            start_image = start_image_encoded_content.decode('utf-8')
1167 +
1168 +    if end_image is not None:
1169 +        with open(end_image, 'rb') as file:
1170 +            file_content = file.read()
1171 +            end_image_encoded_content = base64.b64encode(file_content)
1172 +            end_image = end_image_encoded_content.decode('utf-8')
1173 +
1174      datas = {
1175 +        "base_model_path": base_model_dropdown,
1176 +        "motion_module_path": motion_module_dropdown,
1177 +        "lora_model_path": lora_model_dropdown,
1178 +        "lora_alpha_slider": lora_alpha_slider,
1179          "prompt_textbox": prompt_textbox,
1180          "negative_prompt_textbox": negative_prompt_textbox,
1181          "sampler_dropdown": sampler_dropdown,
1182          "sample_step_slider": sample_step_slider,
1183 +        "resize_method": resize_method,
1184          "width_slider": width_slider,
1185          "height_slider": height_slider,
1186 +        "base_resolution": base_resolution,
1187 +        "generation_method": generation_method,
1188          "length_slider": length_slider,
1189          "cfg_scale_slider": cfg_scale_slider,
1190 +        "start_image": start_image,
1191 +        "end_image": end_image,
1192          "seed_textbox": seed_textbox,
1193      }
1194 +
1195      session = requests.session()
1196      session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
1197
1198 +    response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas, timeout=300)
1199 +
1200      outputs = response.json()
1201      return outputs
1202
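post_eas() above ships the start/end images to the EAS endpoint as base64 strings inside a JSON body, so the receiver only has to apply the mirror operation to get image bytes back. A minimal hedged sketch of both directions (function names and the example path are illustrative; the server-side handling lives in easyanimate/api in this commit and may differ in detail):

```python
import base64
import io

from PIL import Image

def encode_image_file(path):
    # Client side, as in post_eas(): raw file bytes -> base64 -> utf-8 string for JSON.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

def decode_image_string(encoded):
    # Receiving side: base64 string -> bytes -> PIL.Image.
    return Image.open(io.BytesIO(base64.b64decode(encoded)))

# Round trip with one of the bundled template images (illustrative):
# img = decode_image_string(encode_image_file("asset/1.png"))
```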
1208
1209      def generate(
1210          self,
1211 +        diffusion_transformer_dropdown,
1212 +        motion_module_dropdown,
1213 +        base_model_dropdown,
1214 +        lora_model_dropdown,
1215 +        lora_alpha_slider,
1216          prompt_textbox,
1217          negative_prompt_textbox,
1218          sampler_dropdown,
1219          sample_step_slider,
1220 +        resize_method,
1221          width_slider,
1222          height_slider,
1223 +        base_resolution,
1224 +        generation_method,
1225          length_slider,
1226          cfg_scale_slider,
1227 +        start_image,
1228 +        end_image,
1229          seed_textbox
1230      ):
1231 +        is_image = True if generation_method == "Image Generation" else False
1232 +
1233          outputs = post_eas(
1234 +            diffusion_transformer_dropdown, motion_module_dropdown,
1235 +            base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1236              prompt_textbox, negative_prompt_textbox,
1237 +            sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1238 +            base_resolution, generation_method, length_slider, cfg_scale_slider,
1239 +            start_image, end_image,
1240 +            seed_textbox
1241          )
1242 +        try:
1243 +            base64_encoding = outputs["base64_encoding"]
1244 +        except:
1245 +            return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
1246 +
1247          decoded_data = base64.b64decode(base64_encoding)
1248
1249          if not os.path.exists(self.savedir_sample):

1285              """
1286          )
1287          with gr.Column(variant="panel"):
1288 +            gr.Markdown(
1289 +                """
1290 +                ### 1. Model checkpoints.
1291 +                """
1292 +            )
1293 +            with gr.Row():
1294 +                diffusion_transformer_dropdown = gr.Dropdown(
1295 +                    label="Pretrained Model Path",
1296 +                    choices=[model_name],
1297 +                    value=model_name,
1298 +                    interactive=False,
1299 +                )
1300 +            with gr.Row():
1301 +                motion_module_dropdown = gr.Dropdown(
1302 +                    label="Select motion module",
1303 +                    choices=["none"],
1304 +                    value="none",
1305 +                    interactive=False,
1306 +                    visible=False
1307 +                )
1308 +                base_model_dropdown = gr.Dropdown(
1309 +                    label="Select base Dreambooth model",
1310 +                    choices=["none"],
1311 +                    value="none",
1312 +                    interactive=False,
1313 +                    visible=False
1314 +                )
1315 +            with gr.Column(visible=False):
1316 +                gr.Markdown(
1317 +                    """
1318 +                    ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
1319 +                    """
1320 +                )
1321 +                with gr.Row():
1322 +                    lora_model_dropdown = gr.Dropdown(
1323 +                        label="Select LoRA model",
1324 +                        choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
1325 +                        value="none",
1326 +                        interactive=True,
1327 +                    )
1328 +
1329 +                    lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
1330 +
1331 +        with gr.Column(variant="panel"):
1332 +            gr.Markdown(
1333 +                """
1334 +                ### 2. Configs for Generation.
1335 +                """
1336 +            )
1337 +
1338 +            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1339              negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
1340
1341              with gr.Row():
1342                  with gr.Column():
1343                      with gr.Row():
1344                          sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1345 +                        sample_step_slider = gr.Slider(label="Sampling steps", value=20, minimum=10, maximum=30, step=1, interactive=False)
1346
1347                      if edition == "v1":
1348                          width_slider     = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
1349                          height_slider    = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
1350 +
1351 +                        with gr.Group():
1352 +                            generation_method = gr.Radio(
1353 +                                ["Video Generation", "Image Generation"],
1354 +                                value="Video Generation",
1355 +                                show_label=False,
1356 +                                visible=False,
1357 +                            )
1358 +                        length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
1359                          cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
1360                      else:
1361 +                        resize_method = gr.Radio(
1362 +                            ["Generate by", "Resize to the Start Image"],
1363 +                            value="Generate by",
1364 +                            show_label=False,
1365 +                        )
1366                          with gr.Column():
1367                              gr.Markdown(
1368                                  """
1369 +                                We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s).
1370 +
1371 +                                If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above.
1372 +
1373 +                                If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
1374                                  """
1375                              )
1376 +                        width_slider     = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1377 +                        height_slider    = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1378 +                        base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1379 +
1380 +                        with gr.Group():
1381 +                            generation_method = gr.Radio(
1382 +                                ["Video Generation", "Image Generation"],
1383 +                                value="Video Generation",
1384 +                                show_label=False,
1385 +                                visible=True,
1386 +                            )
1387 +                        length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8)
1388 +
1389 +                        with gr.Accordion("Image to Video", open=True):
1390 +                            start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1391 +
1392 +                            template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1393 +                            def select_template(evt: gr.SelectData):
1394 +                                text = {
1395 +                                    "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1396 +                                    "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1397 +                                    "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1398 +                                    "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1399 +                                    "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1400 +                                }[template_gallery_path[evt.index]]
1401 +                                return template_gallery_path[evt.index], text
1402 +
1403 +                            template_gallery = gr.Gallery(
1404 +                                template_gallery_path,
1405 +                                columns=5, rows=1,
1406 +                                height=140,
1407 +                                allow_preview=False,
1408 +                                container=False,
1409 +                                label="Template Examples",
1410 +                            )
1411 +                            template_gallery.select(select_template, None, [start_image, prompt_textbox])
1412 +
1413 +                            with gr.Accordion("The image at the ending of the video (Optional)", open=False):
1414 +                                end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1415 +
1416                          cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
1417
1418                      with gr.Row():

1435                          interactive=False
1436                      )
1437
1438 +            def upload_generation_method(generation_method):
1439 +                if generation_method == "Video Generation":
1440 +                    return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
1441 +                elif generation_method == "Image Generation":
1442 +                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1443 +            generation_method.change(
1444 +                upload_generation_method, generation_method, [length_slider]
1445 +            )
1446 +
1447 +            def upload_resize_method(resize_method):
1448 +                if resize_method == "Generate by":
1449 +                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1450 +                else:
1451 +                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1452 +            resize_method.change(
1453 +                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1454              )
1455
1456              generate_button.click(
1457                  fn=controller.generate,
1458                  inputs=[
1459 +                    diffusion_transformer_dropdown,
1460 +                    motion_module_dropdown,
1461 +                    base_model_dropdown,
1462 +                    lora_model_dropdown,
1463 +                    lora_alpha_slider,
1464                      prompt_textbox,
1465                      negative_prompt_textbox,
1466                      sampler_dropdown,
1467                      sample_step_slider,
1468 +                    resize_method,
1469                      width_slider,
1470                      height_slider,
1471 +                    base_resolution,
1472 +                    generation_method,
1473                      length_slider,
1474                      cfg_scale_slider,
1475 +                    start_image,
1476 +                    end_image,
1477                      seed_textbox,
1478                  ],
1479                  outputs=[result_image, result_video, infer_progress]
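On the way back, the EAS client above reads outputs["base64_encoding"] and base64-decodes it into the generated media; writing those bytes under savedir_sample is all that remains. A hedged sketch of that last step (function name and the png/mp4 extension choice are illustrative, not the exact code elided above):

```python
import base64
import os

def save_base64_payload(base64_encoding, savedir_sample, prefix, is_image):
    # Decode the base64 string returned by the EAS endpoint and write it to disk.
    decoded_data = base64.b64decode(base64_encoding)
    os.makedirs(savedir_sample, exist_ok=True)
    save_sample_path = os.path.join(savedir_sample, prefix + (".png" if is_image else ".mp4"))
    with open(save_sample_path, "wb") as f:
        f.write(decoded_data)
    return save_sample_path
```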
easyanimate/utils/utils.py  CHANGED

@@ -8,6 +8,13 @@ import cv2
 8        from einops import rearrange
 9        from PIL import Image
10
11 +      def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
12 +          target_pixels = int(base_resolution) * int(base_resolution)
13 +          original_width, original_height = Image.open(image).size
14 +          ratio = (target_pixels / (original_width * original_height)) ** 0.5
15 +          width_slider = round(original_width * ratio)
16 +          height_slider = round(original_height * ratio)
17 +          return height_slider, width_slider
18
19        def color_transfer(sc, dc):
20            """

@@ -62,3 +69,103 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
69        if path.endswith("mp4"):
70            path = path.replace('.mp4', '.gif')
71        outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
72 +
73 +  def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
74 +      if validation_image_start is not None and validation_image_end is not None:
75 +          if type(validation_image_start) is str and os.path.isfile(validation_image_start):
76 +              image_start = clip_image = Image.open(validation_image_start)
77 +          else:
78 +              image_start = clip_image = validation_image_start
79 +          if type(validation_image_end) is str and os.path.isfile(validation_image_end):
80 +              image_end = Image.open(validation_image_end)
81 +          else:
82 +              image_end = validation_image_end
83 +
84 +          if type(image_start) is list:
85 +              clip_image = clip_image[0]
86 +              start_video = torch.cat(
87 +                  [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
88 +                  dim=2
89 +              )
90 +              input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
91 +              input_video[:, :, :len(image_start)] = start_video
92 +
93 +              input_video_mask = torch.zeros_like(input_video[:, :1])
94 +              input_video_mask[:, :, len(image_start):] = 255
95 +          else:
96 +              input_video = torch.tile(
97 +                  torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
98 +                  [1, 1, video_length, 1, 1]
99 +              )
100 +             input_video_mask = torch.zeros_like(input_video[:, :1])
101 +             input_video_mask[:, :, 1:] = 255
102 +
103 +         if type(image_end) is list:
104 +             image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
105 +             end_video = torch.cat(
106 +                 [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
107 +                 dim=2
108 +             )
109 +             input_video[:, :, -len(end_video):] = end_video
110 +
111 +             input_video_mask[:, :, -len(image_end):] = 0
112 +         else:
113 +             image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
114 +             input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
115 +             input_video_mask[:, :, -1:] = 0
116 +
117 +         input_video = input_video / 255
118 +
119 +     elif validation_image_start is not None:
120 +         if type(validation_image_start) is str and os.path.isfile(validation_image_start):
121 +             image_start = clip_image = Image.open(validation_image_start).convert("RGB")
122 +         else:
123 +             image_start = clip_image = validation_image_start
124 +
125 +         if type(image_start) is list:
126 +             clip_image = clip_image[0]
127 +             start_video = torch.cat(
128 +                 [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
129 +                 dim=2
130 +             )
131 +             input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
132 +             input_video[:, :, :len(image_start)] = start_video
133 +             input_video = input_video / 255
134 +
135 +             input_video_mask = torch.zeros_like(input_video[:, :1])
136 +             input_video_mask[:, :, len(image_start):] = 255
137 +         else:
138 +             input_video = torch.tile(
139 +                 torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
140 +                 [1, 1, video_length, 1, 1]
141 +             ) / 255
142 +             input_video_mask = torch.zeros_like(input_video[:, :1])
143 +             input_video_mask[:, :, 1:, ] = 255
144 +     else:
145 +         input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
146 +         input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
147 +         clip_image = None
148 +
149 +     return input_video, input_video_mask, clip_image
150 +
151 + def video_frames(input_video_path):
152 +     cap = cv2.VideoCapture(input_video_path)
153 +     frames = []
154 +     while True:
155 +         ret, frame = cap.read()
156 +         if not ret:
157 +             break
158 +         frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
159 +     cap.release()
160 +     cv2.destroyAllWindows()
161 +     return frames
162 +
163 + def get_video_to_video_latent(validation_videos, video_length):
164 +     input_video = video_frames(validation_videos)
165 +     input_video = torch.from_numpy(np.array(input_video))[:video_length]
166 +     input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
167 +
168 +     input_video_mask = torch.zeros_like(input_video[:, :1])
169 +     input_video_mask[:, :, :] = 255
170 +
171 +     return input_video, input_video_mask, None
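The helpers added to utils.py are what the inpaint UI paths above consume: get_width_and_height_from_image_and_base_resolution() picks a size with roughly base_resolution² pixels at the image's own aspect ratio, and get_image_to_video_latent() turns a start (and optional end) image into the video/mask/clip-image triple handed to EasyAnimateInpaintPipeline. A hedged usage sketch; the image path is illustrative:

```python
import torch

from easyanimate.utils.utils import (
    get_image_to_video_latent,
    get_width_and_height_from_image_and_base_resolution,
)

start_image = "asset/1.png"  # illustrative path from the template gallery

# As a worked example of the arithmetic: a hypothetical 1920x1080 image with
# base_resolution=512 gives ratio = sqrt(512*512 / (1920*1080)) ~= 0.356,
# so the function returns roughly (height, width) = (384, 683).
# Note it does not snap to multiples of 16; the UI code above does that separately.
height, width = get_width_and_height_from_image_and_base_resolution(start_image, 512)

# Build the conditioning tensors for image-to-video:
#   input_video      - (1, 3, T, H, W) in [0, 1], the start frame tiled over T
#   input_video_mask - (1, 1, T, H, W), 0 where frames are given, 255 where to generate
#   clip_image       - the PIL image later fed to the CLIP vision encoder
input_video, input_video_mask, clip_image = get_image_to_video_latent(
    start_image, None, video_length=48, sample_size=(height, width)
)
print(input_video.shape, input_video_mask.shape)
```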