Spaces:
Runtime error
Runtime error
chenyangqi
committed on
Commit
·
4dff355
1
Parent(s):
8214cae
cache the ckpt; fix bugs when input new video
Browse files- .gitignore +2 -1
- FateZero/test_fatezero.py +24 -18
- FateZero/video_diffusion/common/util.py +8 -2
- FateZero/video_diffusion/data/dataset.py +11 -5
- app_fatezero.py +4 -4
- inference_fatezero.py +84 -51
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
trash/*
|
|
|
|
1 |
+
trash/*
|
2 |
+
tmp
|
FateZero/test_fatezero.py
CHANGED
@@ -48,6 +48,10 @@ def test(
|
|
48 |
config: str,
|
49 |
pretrained_model_path: str,
|
50 |
train_dataset: Dict,
|
|
|
|
|
|
|
|
|
51 |
logdir: str = None,
|
52 |
validation_sample_logger_config: Optional[Dict] = None,
|
53 |
test_pipeline_config: Optional[Dict] = None,
|
@@ -79,26 +83,28 @@ def test(
|
|
79 |
set_seed(seed)
|
80 |
|
81 |
# Load the tokenizer
|
82 |
-
tokenizer
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
87 |
|
88 |
# Load models and create wrapper for stable diffusion
|
89 |
-
text_encoder
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
vae
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
unet
|
100 |
-
|
101 |
-
|
|
|
102 |
|
103 |
if 'target' not in test_pipeline_config:
|
104 |
test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
|
|
|
48 |
config: str,
|
49 |
pretrained_model_path: str,
|
50 |
train_dataset: Dict,
|
51 |
+
tokenizer = None,
|
52 |
+
text_encoder = None,
|
53 |
+
vae = None,
|
54 |
+
unet = None,
|
55 |
logdir: str = None,
|
56 |
validation_sample_logger_config: Optional[Dict] = None,
|
57 |
test_pipeline_config: Optional[Dict] = None,
|
|
|
83 |
set_seed(seed)
|
84 |
|
85 |
# Load the tokenizer
|
86 |
+
if tokenizer is None:
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
88 |
+
pretrained_model_path,
|
89 |
+
subfolder="tokenizer",
|
90 |
+
use_fast=False,
|
91 |
+
)
|
92 |
|
93 |
# Load models and create wrapper for stable diffusion
|
94 |
+
if text_encoder is None:
|
95 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
96 |
+
pretrained_model_path,
|
97 |
+
subfolder="text_encoder",
|
98 |
+
)
|
99 |
+
if vae is None:
|
100 |
+
vae = AutoencoderKL.from_pretrained(
|
101 |
+
pretrained_model_path,
|
102 |
+
subfolder="vae",
|
103 |
+
)
|
104 |
+
if unet is None:
|
105 |
+
unet = UNetPseudo3DConditionModel.from_2d_model(
|
106 |
+
os.path.join(pretrained_model_path, "unet"), model_config=model_config
|
107 |
+
)
|
108 |
|
109 |
if 'target' not in test_pipeline_config:
|
110 |
test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
|
FateZero/video_diffusion/common/util.py
CHANGED
@@ -4,7 +4,7 @@ import copy
|
|
4 |
import inspect
|
5 |
import datetime
|
6 |
from typing import List, Tuple, Optional, Dict
|
7 |
-
|
8 |
|
9 |
def glob_files(
|
10 |
root_path: str,
|
@@ -68,6 +68,12 @@ def get_time_string() -> str:
|
|
68 |
def get_function_args() -> Dict:
|
69 |
frame = sys._getframe(1)
|
70 |
args, _, _, values = inspect.getargvalues(frame)
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
return args_dict
|
|
|
4 |
import inspect
|
5 |
import datetime
|
6 |
from typing import List, Tuple, Optional, Dict
|
7 |
+
import torch
|
8 |
|
9 |
def glob_files(
|
10 |
root_path: str,
|
|
|
68 |
def get_function_args(exclude_names: Tuple[str, ...] = ("tokenizer",)) -> Dict:
    """Return a deep-copied snapshot of the calling function's arguments.

    The caller's frame is inspected, so this must be invoked directly from
    the function whose arguments are wanted.  Arguments whose value is a
    ``torch.nn.Module`` and arguments whose *name* appears in
    ``exclude_names`` are left out of the snapshot and are therefore never
    handed to ``copy.deepcopy``.

    Args:
        exclude_names: Argument names to omit regardless of their value.
            Defaults to ``("tokenizer",)``, the name that used to be
            hard-coded in the filter.

    Returns:
        Dict mapping every remaining argument name to a deep copy of the
        value it currently has in the caller's frame.
    """
    frame = sys._getframe(1)
    try:
        args, _, _, values = inspect.getargvalues(frame)
        kept = {
            name: values[name]
            for name in args
            if name not in exclude_names
            and not isinstance(values[name], torch.nn.Module)
        }
    finally:
        # Release the caller's frame promptly instead of waiting for this
        # function's own locals to be collected.
        del frame

    # Deep copy so later mutation by the caller cannot alter the snapshot.
    return copy.deepcopy(kept)
|
FateZero/video_diffusion/data/dataset.py
CHANGED
@@ -6,6 +6,7 @@ from einops import rearrange
|
|
6 |
from pathlib import Path
|
7 |
import imageio
|
8 |
import cv2
|
|
|
9 |
|
10 |
import torch
|
11 |
from torch.utils.data import Dataset
|
@@ -156,7 +157,7 @@ class ImageSequenceDataset(Dataset):
|
|
156 |
images = []
|
157 |
if path[-4:] == '.mp4':
|
158 |
path = self.mp4_to_png(path)
|
159 |
-
|
160 |
|
161 |
for file in sorted(os.listdir(path)):
|
162 |
if file.endswith(IMAGE_EXTENSION):
|
@@ -164,14 +165,19 @@ class ImageSequenceDataset(Dataset):
|
|
164 |
return images
|
165 |
|
166 |
# @staticmethod
|
|
|
167 |
def mp4_to_png(self, video_source=None):
|
168 |
reader = imageio.get_reader(video_source)
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
171 |
for i, im in enumerate(reader):
|
172 |
# use :05d to add zero, no space before the 05d
|
173 |
# if (i+1)%10 == 0:
|
174 |
-
path = os.path.join(
|
175 |
# print(path)
|
176 |
cv2.imwrite(path, im[:, :, ::-1])
|
177 |
-
|
|
|
|
6 |
from pathlib import Path
|
7 |
import imageio
|
8 |
import cv2
|
9 |
+
import shutil
|
10 |
|
11 |
import torch
|
12 |
from torch.utils.data import Dataset
|
|
|
157 |
images = []
|
158 |
if path[-4:] == '.mp4':
|
159 |
path = self.mp4_to_png(path)
|
160 |
+
|
161 |
|
162 |
for file in sorted(os.listdir(path)):
|
163 |
if file.endswith(IMAGE_EXTENSION):
|
|
|
165 |
return images
|
166 |
|
167 |
# @staticmethod
|
168 |
+
|
169 |
def mp4_to_png(self, video_source=None):
    """Dump every frame of a video into PNG files and return their folder.

    The fixed scratch directory ``./tmp/fatezero_user_video`` is removed and
    re-created on each call, so frames written for a previous video are not
    mixed into the new sequence.

    Args:
        video_source: Video location handed to ``imageio.get_reader``.

    Returns:
        The directory that now holds the frames, written as ``00000.png``,
        ``00001.png``, ...  The same value is also stored on ``self.path``.
    """
    dir_path = './tmp/fatezero_user_video'

    # The context manager closes the reader even when decoding or writing
    # fails part-way; previously the reader was never closed.
    with imageio.get_reader(video_source) as reader:
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(dir_path, exist_ok=True)

        for i, im in enumerate(reader):
            # Zero-padded index so a lexicographic sort of the file names
            # matches frame order.
            path = os.path.join(dir_path, f"{i:05d}.png")
            # Channel axis is reversed before writing -- presumably
            # RGB -> BGR for OpenCV; confirm against the reader's output.
            cv2.imwrite(path, im[:, :, ::-1])

    self.path = dir_path
    return self.path
|
app_fatezero.py
CHANGED
@@ -28,7 +28,7 @@ from inference_fatezero import merge_config_then_run
|
|
28 |
# TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'
|
29 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
30 |
# pipe = InferencePipeline(HF_TOKEN)
|
31 |
-
|
32 |
# app = InferenceUtil(HF_TOKEN)
|
33 |
|
34 |
with gr.Blocks(css='style.css') as demo:
|
@@ -288,7 +288,7 @@ with gr.Blocks(css='style.css') as demo:
|
|
288 |
*ImageSequenceDataset_list
|
289 |
],
|
290 |
outputs=result,
|
291 |
-
fn=
|
292 |
cache_examples=os.getenv('SYSTEM') == 'spaces')
|
293 |
|
294 |
# model_id.change(fn=app.load_model_info,
|
@@ -312,8 +312,8 @@ with gr.Blocks(css='style.css') as demo:
|
|
312 |
*ImageSequenceDataset_list
|
313 |
]
|
314 |
# prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
|
315 |
-
target_prompt.submit(fn=
|
316 |
# run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
|
317 |
-
run_button.click(fn=
|
318 |
|
319 |
demo.queue().launch()
|
|
|
28 |
# TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'
|
29 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
30 |
# pipe = InferencePipeline(HF_TOKEN)
|
31 |
+
pipe = merge_config_then_run()
|
32 |
# app = InferenceUtil(HF_TOKEN)
|
33 |
|
34 |
with gr.Blocks(css='style.css') as demo:
|
|
|
288 |
*ImageSequenceDataset_list
|
289 |
],
|
290 |
outputs=result,
|
291 |
+
fn=pipe.run,
|
292 |
cache_examples=os.getenv('SYSTEM') == 'spaces')
|
293 |
|
294 |
# model_id.change(fn=app.load_model_info,
|
|
|
312 |
*ImageSequenceDataset_list
|
313 |
]
|
314 |
# prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
|
315 |
+
target_prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
|
316 |
# run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
|
317 |
+
run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
|
318 |
|
319 |
demo.queue().launch()
|
inference_fatezero.py
CHANGED
@@ -4,8 +4,40 @@ from FateZero.test_fatezero import *
|
|
4 |
import copy
|
5 |
import gradio as gr
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
model_id,
|
10 |
data_path,
|
11 |
source_prompt,
|
@@ -27,58 +59,59 @@ def merge_config_then_run(
|
|
27 |
top_crop=0,
|
28 |
bottom_crop=0,
|
29 |
):
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
# fatezero config
|
62 |
-
p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
|
63 |
-
p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
|
64 |
-
p2p_config_now['self_replace_steps'] = self_replace_steps
|
65 |
-
p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
|
66 |
-
p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
|
67 |
-
config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
# ddim config
|
71 |
-
config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
|
72 |
-
config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
|
73 |
-
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
|
80 |
-
return mp4_path
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
if __name__ == "__main__":
|
84 |
-
run()
|
|
|
4 |
import copy
|
5 |
import gradio as gr
|
6 |
|
7 |
+
class merge_config_then_run():
|
8 |
+
def __init__(self) -> None:
|
9 |
+
# Load the tokenizer
|
10 |
+
pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
|
11 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
12 |
+
pretrained_model_path,
|
13 |
+
# 'FateZero/ckpt/stable-diffusion-v1-4',
|
14 |
+
subfolder="tokenizer",
|
15 |
+
use_fast=False,
|
16 |
+
)
|
17 |
|
18 |
+
# Load models and create wrapper for stable diffusion
|
19 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
20 |
+
pretrained_model_path,
|
21 |
+
subfolder="text_encoder",
|
22 |
+
)
|
23 |
+
|
24 |
+
self.vae = AutoencoderKL.from_pretrained(
|
25 |
+
pretrained_model_path,
|
26 |
+
subfolder="vae",
|
27 |
+
)
|
28 |
+
model_config = {
|
29 |
+
"lora": 160,
|
30 |
+
# temporal_downsample_time: 4
|
31 |
+
"SparseCausalAttention_index": ['mid'],
|
32 |
+
"least_sc_channel": 640
|
33 |
+
}
|
34 |
+
self.unet = UNetPseudo3DConditionModel.from_2d_model(
|
35 |
+
os.path.join(pretrained_model_path, "unet"), model_config=model_config
|
36 |
+
)
|
37 |
+
|
38 |
+
def run(
|
39 |
+
self,
|
40 |
+
# def merge_config_then_run(
|
41 |
model_id,
|
42 |
data_path,
|
43 |
source_prompt,
|
|
|
59 |
top_crop=0,
|
60 |
bottom_crop=0,
|
61 |
):
|
62 |
+
# , ] = inputs
|
63 |
+
default_edit_config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'
|
64 |
+
Omegadict_default_edit_config = OmegaConf.load(default_edit_config)
|
65 |
+
|
66 |
+
dataset_time_string = get_time_string()
|
67 |
+
config_now = copy.deepcopy(Omegadict_default_edit_config)
|
68 |
+
print(f"config_now['pretrained_model_path'] = model_id {model_id}")
|
69 |
+
# config_now['pretrained_model_path'] = model_id
|
70 |
+
config_now['train_dataset']['prompt'] = source_prompt
|
71 |
+
config_now['train_dataset']['path'] = data_path
|
72 |
+
# ImageSequenceDataset_dict = { }
|
73 |
+
offset_dict = {
|
74 |
+
"left": left_crop,
|
75 |
+
"right": right_crop,
|
76 |
+
"top": top_crop,
|
77 |
+
"bottom": bottom_crop,
|
78 |
+
}
|
79 |
+
ImageSequenceDataset_dict = {
|
80 |
+
"start_sample_frame" : start_sample_frame,
|
81 |
+
"n_sample_frame" : n_sample_frame,
|
82 |
+
"stride" : stride,
|
83 |
+
"offset": offset_dict,
|
84 |
+
}
|
85 |
+
config_now['train_dataset'].update(ImageSequenceDataset_dict)
|
86 |
+
if user_input_video and data_path is None:
|
87 |
+
raise gr.Error('You need to upload a video or choose a provided video')
|
88 |
+
if user_input_video is not None and user_input_video.name is not None:
|
89 |
+
config_now['train_dataset']['path'] = user_input_video.name
|
90 |
+
config_now['validation_sample_logger_config']['prompts'] = [target_prompt]
|
91 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
+
# fatezero config
|
94 |
+
p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
|
95 |
+
p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
|
96 |
+
p2p_config_now['self_replace_steps'] = self_replace_steps
|
97 |
+
p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
|
98 |
+
p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
|
99 |
+
config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
|
100 |
|
|
|
|
|
|
|
|
|
101 |
|
102 |
+
# ddim config
|
103 |
+
config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
|
104 |
+
config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
|
105 |
+
|
|
|
|
|
106 |
|
107 |
+
logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_{dataset_time_string}'
|
108 |
+
config_now['logdir'] = logdir
|
109 |
+
print(f'Saving at {logdir}')
|
110 |
+
save_path = test(tokenizer = self.tokenizer,
|
111 |
+
text_encoder = self.text_encoder,
|
112 |
+
vae = self.vae,
|
113 |
+
unet = self.unet,
|
114 |
+
config=default_edit_config, **config_now)
|
115 |
+
mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
|
116 |
+
return mp4_path
|
117 |
|
|
|
|