maxin-cn committed on
Commit
9e8c2c6
1 Parent(s): 95a07ef

Upload folder using huggingface_hub

app.py ADDED
@@ -0,0 +1,312 @@
+import gradio as gr
+import os
+import torch
+import argparse
+import torchvision
+
+
+from pipelines.pipeline_videogen import VideoGenPipeline
+from diffusers.schedulers import DDIMScheduler
+from diffusers.models import AutoencoderKL
+from diffusers.models import AutoencoderKLTemporalDecoder
+from transformers import CLIPTokenizer, CLIPTextModel
+from omegaconf import OmegaConf
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from models import get_models
+import imageio
+from PIL import Image
+import numpy as np
+from datasets import video_transforms
+from torchvision import transforms
+from einops import rearrange, repeat
+from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
+from copy import deepcopy
+import spaces
+import requests
+from datetime import datetime
+import random
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, default="./configs/sample.yaml")
+args = parser.parse_args()
+args = OmegaConf.load(args.config)
+
+torch.set_grad_enabled(False)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 # torch.float16
+
+unet = get_models(args).to(device, dtype=dtype)
+
+if args.enable_vae_temporal_decoder:
+    if args.use_dct:
+        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
+    else:
+        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+    vae = deepcopy(vae_for_base_content).to(dtype=dtype)
+else:
+    vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
+    vae = deepcopy(vae_for_base_content).to(dtype=dtype)
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
+
+# set eval mode
+unet.eval()
+vae.eval()
+text_encoder.eval()
+
+basedir = os.getcwd()
+savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+savedir_sample = os.path.join(savedir, "sample")
+os.makedirs(savedir, exist_ok=True)
+
+def update_and_resize_image(input_image_path, height_slider, width_slider):
+    if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
+        pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
+    else:
+        pil_image = Image.open(input_image_path).convert('RGB')
+
+    original_width, original_height = pil_image.size
+
+    if original_height == height_slider and original_width == width_slider:
+        return gr.Image(value=np.array(pil_image))
+
+    ratio1 = height_slider / original_height
+    ratio2 = width_slider / original_width
+
+    if ratio1 > ratio2:
+        new_width = int(original_width * ratio1)
+        new_height = int(original_height * ratio1)
+    else:
+        new_width = int(original_width * ratio2)
+        new_height = int(original_height * ratio2)
+
+    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
+
+    left = (new_width - width_slider) / 2
+    top = (new_height - height_slider) / 2
+    right = left + width_slider
+    bottom = top + height_slider
+
+    pil_image = pil_image.crop((left, top, right, bottom))
+
+    return gr.Image(value=np.array(pil_image))
+
+
+def update_textbox_and_save_image(input_image, height_slider, width_slider):
+    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
+
+    original_width, original_height = pil_image.size
+
+    ratio1 = height_slider / original_height
+    ratio2 = width_slider / original_width
+
+    if ratio1 > ratio2:
+        new_width = int(original_width * ratio1)
+        new_height = int(original_height * ratio1)
+    else:
+        new_width = int(original_width * ratio2)
+        new_height = int(original_height * ratio2)
+
+    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
+
+    left = (new_width - width_slider) / 2
+    top = (new_height - height_slider) / 2
+    right = left + width_slider
+    bottom = top + height_slider
+
+    pil_image = pil_image.crop((left, top, right, bottom))
+
+    img_path = os.path.join(savedir, "input_image.png")
+    pil_image.save(img_path)
+
+    return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
+
+def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
+    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
+    image = transform_video(image)
+    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
+    image = image.unsqueeze(2)
+    return image
+
+
+@spaces.GPU
+def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
+
+    torch.manual_seed(seed)
+
+    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
+                                              subfolder="scheduler",
+                                              beta_start=args.beta_start,
+                                              beta_end=args.beta_end,
+                                              beta_schedule=args.beta_schedule)
+
+    videogen_pipeline = VideoGenPipeline(vae=vae,
+                                         text_encoder=text_encoder,
+                                         tokenizer=tokenizer,
+                                         scheduler=scheduler,
+                                         unet=unet).to(device)
+    # videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+    transform_video = transforms.Compose([
+        video_transforms.ToTensorVideo(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+    ])
+
+    if args.use_dct:
+        base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
+    else:
+        base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
+
+    if use_dctinit:
+        # filter params
+        print("Using DCT!")
+        base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
+
+        # define filter
+        freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
+
+        noise = torch.randn(1, 4, 15, 40, 64).to(device)
+
+        # add noise to base_content
+        diffuse_timesteps = torch.full((1,),int(noise_level))
+        diffuse_timesteps = diffuse_timesteps.long()
+
+        # 3d content
+        base_content_noise = scheduler.add_noise(
+            original_samples=base_content_repeat.to(device),
+            noise=noise,
+            timesteps=diffuse_timesteps.to(device))
+
+        # 3d content
+        latents = exchanged_mixed_dct_freq(noise=noise,
+                                           base_content=base_content_noise,
+                                           LPF_3d=freq_filter).to(dtype=torch.float16)
+
+    base_content = base_content.to(dtype=torch.float16)
+
+    videos = videogen_pipeline(prompt,
+                               negative_prompt=negative_prompt,
+                               latents=latents if use_dctinit else None,
+                               base_content=base_content,
+                               video_length=15,
+                               height=height,
+                               width=width,
+                               num_inference_steps=diffusion_step,
+                               guidance_scale=scfg_scale,
+                               motion_bucket_id=100-motion_bucket_id,
+                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
+
+    save_path = args.save_img_path + 'temp' + '.mp4'
+    # torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
+    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
+    return save_path
+
+
+if not os.path.exists(args.save_img_path):
+    os.makedirs(args.save_img_path)
+
+
+with gr.Blocks() as demo:
+
+    gr.Markdown("<font color=red size=6.5><center>Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models</center></font>")
+    gr.Markdown(
+        """<div style="display: flex;align-items: center;justify-content: center">
+        [<a href="https://arxiv.org/abs/2407.15642">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/cinemo_project/">Project Page</a>] | [<a href="https://github.com/maxin-cn/Cinemo">Github</a>]</div>
+        """
+    )
+
+
+    with gr.Column(variant="panel"):
+        with gr.Row():
+            prompt_textbox = gr.Textbox(label="Prompt", lines=1)
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
+
+        with gr.Row(equal_height=False):
+            with gr.Column():
+                with gr.Row():
+                    input_image = gr.Image(label="Input Image", interactive=True)
+                    result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
+
+                generate_button = gr.Button(value="Generate", variant='primary')
+
+                with gr.Accordion("Advanced options", open=False):
+                    gr.Markdown(
+                        """
+                        - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
+                        - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
+                        - After setting the input image path, press the "Preview" button to visualize the resized input image.
+                        """
+                    )
+                    with gr.Column():
+                        with gr.Row():
+                            input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
+                            preview_button = gr.Button(value="Preview")
+
+                        with gr.Row():
+                            sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
+
+                        with gr.Row():
+                            seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
+                            # seed_textbox = gr.Textbox(label="Seed", value=100)
+                            # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                            # seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+
+                        with gr.Row():
+                            height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
+                            width = gr.Slider(label="Width", value=512, minimum=0, maximum=512, step=16, interactive=False)
+                        with gr.Row():
+                            txt_cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1, interactive=True)
+                            motion_bucket_id = gr.Slider(label="Motion Intensity", value=10, minimum=1, maximum=20, step=1, interactive=True)
+
+                        with gr.Row():
+                            use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
+                            dct_coefficients = gr.Slider(label="DCT Coefficients", value=0.23, minimum=0, maximum=1, step=0.01, interactive=True)
+                            noise_level = gr.Slider(label="Noise Level", value=985, minimum=1, maximum=999, step=1, interactive=True)
+
+        input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height, width], outputs=[input_image_path, input_image])
+        preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
+        input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
+
+        EXAMPLES = [
+            ["./example/aircrafts_flying/0.jpg", "aircrafts flying" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+            ["./example/fireworks/0.jpg", "fireworks" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+            ["./example/flowers_swaying/0.jpg", "flowers swaying" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+            ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
+            ["./example/house_rotating/0.jpg", "house rotating" , "low quality", 50, 320, 512, 7.5, True, 0.23, 985, 10, 46640174],
+            ["./example/people_runing/0.jpg", "people runing" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+            ["./example/shark_swimming/0.jpg", "shark swimming" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 32947978],
+            ["./example/car_moving/0.jpg", "car moving" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 75469653],
+            ["./example/windmill_turning/0.jpg", "windmill turning" , "background changing", 50, 320, 512, 7.5, True, 0.21, 975, 10, 89378613],
+        ]
+
+        examples = gr.Examples(
+            examples = EXAMPLES,
+            fn = gen_video,
+            inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
+            outputs=[result_video],
+            # cache_examples=True,
+            cache_examples="lazy",
+        )
+
+    generate_button.click(
+        fn=gen_video,
+        inputs=[
+            input_image,
+            prompt_textbox,
+            negative_prompt_textbox,
+            sample_step_slider,
+            height,
+            width,
+            txt_cfg_scale,
+            use_dctinit,
+            dct_coefficients,
+            noise_level,
+            motion_bucket_id,
+            seed_textbox,
+        ],
+        outputs=[result_video]
+    )
+
+demo.launch(debug=False, share=True, server_name="127.0.0.1")
configs/sample.yaml CHANGED
@@ -22,7 +22,7 @@ use_fp16: True
 # sample config:
 seed:
 run_time: 0
-use_dct: True
+use_dct: False
 guidance_scale: 7.5 #
 motion_bucket_id: 95 # [0-19] The larger the value, the stronger the motion intensity
 sample_method: 'DDIM'
demo.py CHANGED
@@ -270,7 +270,7 @@ with gr.Blocks() as demo:
         input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
 
         EXAMPLES = [
-            ["./example/aircrafts_flying/0.jpg", "aircrafts flying" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+            ["./example/red_panda_eating_bamboo/0.jpg", "red panda eating bamboo" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
             ["./example/fireworks/0.jpg", "fireworks" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
             ["./example/flowers_swaying/0.jpg", "flowers swaying" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
             ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
example/red_panda_eating_bamboo/0.jpg ADDED
example/red_panda_eating_bamboo/a_red_panda_eating_bamboo.mp4 ADDED
Binary file (781 kB).
 
gradio_cached_examples/40/Generated Animation/9cd847472e0becf5b842/.nfsab7733d2089f861000009449 ADDED
Binary file (257 kB).
 
pipelines/__pycache__/pipeline_inversion.cpython-312.pyc ADDED
Binary file (35.6 kB).
 
pipelines/video_editing.sh CHANGED
@@ -1,2 +1,2 @@
-export CUDA_VISIBLE_DEVICES=6
+export CUDA_VISIBLE_DEVICES=0
 python pipelines/video_editting.py --config configs/sample.yaml
pipelines/video_editting.py CHANGED
@@ -23,6 +23,9 @@ from copy import deepcopy
 from PIL import Image
 from datasets import video_transforms
 from torchvision import transforms
+from models.unet import UNet3DConditionModel
+from einops import repeat
+from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
 
 def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
     with open(path, 'rb') as f:
@@ -83,9 +86,10 @@ def main(args):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.float16 # torch.float16
 
-    unet = get_models(args).to(device, dtype=torch.float16)
-    state_dict = find_model(args.ckpt)
-    unet.load_state_dict(state_dict)
+    # unet = get_models(args).to(device, dtype=torch.float16)
+    # state_dict = find_model(args.ckpt)
+    # unet.load_state_dict(state_dict)
+    unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet").to(device, dtype=torch.float16)
 
     if args.enable_vae_temporal_decoder:
         if args.use_dct:
@@ -140,7 +144,8 @@ def main(args):
 
 
     # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
-    video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    # video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    video_path = './video_editing/test_03.mp4'
 
 
     video_reader = DecordInit()
@@ -154,14 +159,21 @@ def main(args):
     base_content, motion_latents = separation_content_motion(latents)
 
     # image_path = "./video_editing/a_man_walking_in_the_park.png"
-    image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
-    edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device)
+    # image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
+    image_path = "./video_editing/test_03.png"
+
+    if args.use_dct:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
+    else:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
 
     if not os.path.exists(args.save_img_path):
         os.makedirs(args.save_img_path)
 
     # prompt_inversion = 'a man walking on the beach'
-    prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt_inversion = 'A girl is playing the guitar in her room'
+    prompt_inversion = 'A man is walking inside the church'
     latents = videogen_pipeline_inversion(prompt_inversion,
                                           latents=motion_latents,
                                           base_content=base_content,
@@ -175,7 +187,39 @@ def main(args):
                                           output_type="latent").video
 
     # prompt = 'a man walking in the park'
-    prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt = 'A girl is playing the guitar in her room'
+    prompt = 'A man is walking inside the church'
+
+    if args.use_dct:
+        # filter params
+        print("Using DCT!")
+        edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
+
+        # define filter
+        freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
+                                          percentage=0.23)
+
+        noise = latents.to(dtype=torch.float64)
+
+        # add noise to base_content
+        diffuse_timesteps = torch.full((1,),int(985))
+        diffuse_timesteps = diffuse_timesteps.long()
+
+        # 3d content
+        edit_content_noise = scheduler.add_noise(
+            original_samples=edit_content_repeat.to(device),
+            noise=noise,
+            timesteps=diffuse_timesteps.to(device))
+
+        # 3d content
+        latents = exchanged_mixed_dct_freq(noise=noise,
+                                           base_content=edit_content_noise,
+                                           LPF_3d=freq_filter).to(dtype=torch.float16)
+
+    latents = latents.to(dtype=torch.float16)
+    edit_content = edit_content.to(dtype=torch.float16)
+
     videos = videogen_pipeline(prompt,
                                latents=latents,
                                base_content=edit_content,
@@ -183,8 +227,8 @@ def main(args):
                                height=args.image_size[0],
                                width=args.image_size[1],
                                num_inference_steps=args.num_sampling_steps,
-                               guidance_scale=1.0,
-                               # guidance_scale=args.guidance_scale,
+                               # guidance_scale=1.0,
+                               guidance_scale=args.guidance_scale,
                                motion_bucket_id=args.motion_bucket_id,
                                enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
     imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8) # highest quality is 10, lowest is 0
sample_videos/A_girl_is_playing_the_guitar_in_her_room_0000-imageio.mp4 ADDED
Binary file (246 kB).
 
sample_videos/temp.mp4 CHANGED
Binary files a/sample_videos/temp.mp4 and b/sample_videos/temp.mp4 differ
 
samples/Gradio/2024-08-05T11-02-12/input_image.png ADDED
video_editing/test_03.mp4 ADDED
Binary file (59.6 kB).
 
video_editing/test_03.png ADDED
video_editing/test_04.mp4 ADDED
Binary file (137 kB).
 
video_editing/test_04.png ADDED
video_editing/test_1.mp4 ADDED
Binary file (64.2 kB).
 
video_editing/test_1.png ADDED