import os
import sys
import argparse
from copy import deepcopy

import torch
import torchvision
import imageio
import decord
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from torchvision import transforms

from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel

from pipeline_videogen import VideoGenPipeline
from pipelines.pipeline_inversion import VideoGenInversionPipeline

# Make the project root importable so the shared modules below resolve.
sys.path.append(os.path.split(sys.path[0])[0])
from utils import find_model
from models import get_models
from datasets import video_transforms


def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    """Load a single image and encode it into a VAE latent with a singleton frame dimension."""
    with open(path, 'rb') as f:
        image = Image.open(f).convert('RGB')
    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
    image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
    image = image.unsqueeze(2)  # add a singleton frame dimension: [B, C, 1, H, W]
    return image


def separation_content_motion(video_clip):
    """
    Separate content and motion in a given video.

    Args:
        video_clip: A given video clip, shape [B, C, F, H, W]

    Returns:
        base_frame: Base frame, shape [B, C, 1, H, W]
        motions: Motions relative to the base frame, shape [B, C, F-1, H, W]
    """
    # Select the first frame of each clip in the batch as the base frame.
    base_frame = video_clip[:, :, :1, :, :]

    # The motion is the difference between each remaining frame and the base frame.
    motions = video_clip[:, :, 1:, :, :] - base_frame

    return base_frame, motions

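# Illustrative example (not part of the pipeline): for a latent clip of shape
# [1, 4, 16, 32, 32], separation_content_motion returns a base frame of shape
# [1, 4, 1, 32, 32] and motion residuals of shape [1, 4, 15, 32, 32]; the
# original frames can be recovered exactly as base_frame + motions, with
# frame 0 being the base frame itself.
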
""" reader = decord.VideoReader(filename, ctx=self.ctx, num_threads=self.num_threads) return reader def __repr__(self): repr_str = (f'{self.__class__.__name__}(' f'sr={self.sr},' f'num_threads={self.num_threads})') return repr_str def main(args): # torch.manual_seed(args.seed) torch.set_grad_enabled(False) device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 # torch.float16 unet = get_models(args).to(device, dtype=torch.float16) state_dict = find_model(args.ckpt) unet.load_state_dict(state_dict) if args.enable_vae_temporal_decoder: if args.use_dct: vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device) else: vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device) vae = deepcopy(vae_for_base_content).to(dtype=dtype) else: vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64) vae = deepcopy(vae_for_base_content).to(dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # set eval mode unet.eval() vae.eval() text_encoder.eval() scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule,) scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, # beta_end=0.017, beta_schedule=args.beta_schedule,) videogen_pipeline = VideoGenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler_inversion, unet=unet).to(device) videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, unet=unet).to(device) # videogen_pipeline.enable_xformers_memory_efficient_attention() # videogen_pipeline.enable_vae_slicing() transform_video = video_transforms.Compose([ video_transforms.ToTensorVideo(), video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])), # center crop using shor edge, then resize transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) # video_path = './video_editing/A_man_walking_on_the_beach.mp4' video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4' video_reader = DecordInit() video = video_reader(video_path) frame_indice = np.linspace(0, 15, 16, dtype=int) video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous() video = video / 255.0 video = video * 2.0 - 1.0 latents = vae.encode(video.to(dtype=torch.float16, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor).unsqueeze(0).permute(0, 2, 1, 3, 4) base_content, motion_latents = separation_content_motion(latents) # image_path = "./video_editing/a_man_walking_in_the_park.png" image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png" edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device) if not os.path.exists(args.save_img_path): os.makedirs(args.save_img_path) # prompt_inversion = 'a man walking on the beach' prompt_inversion = 'a corgi 
    # prompt_inversion = 'a man walking on the beach'
    prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
    latents = videogen_pipeline_inversion(prompt_inversion,
                                          latents=motion_latents,
                                          base_content=base_content,
                                          video_length=args.video_length,
                                          height=args.image_size[0],
                                          width=args.image_size[1],
                                          num_inference_steps=args.num_sampling_steps,
                                          guidance_scale=1.0,  # guidance_scale=args.guidance_scale,
                                          motion_bucket_id=args.motion_bucket_id,
                                          output_type="latent").video

    # prompt = 'a man walking in the park'
    prompt = 'a corgi walking in the park at sunrise, oil painting style'
    videos = videogen_pipeline(prompt,
                               latents=latents,
                               base_content=edit_content,
                               video_length=args.video_length,
                               height=args.image_size[0],
                               width=args.image_size[1],
                               num_inference_steps=args.num_sampling_steps,
                               guidance_scale=1.0,  # guidance_scale=args.guidance_scale,
                               motion_bucket_id=args.motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video

    imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4',
                     videos[0], fps=8, quality=8)  # highest quality is 10, lowest is 0
    print('save path {}'.format(args.save_img_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/sample.yaml")
    args = parser.parse_args()
    main(OmegaConf.load(args.config))
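
# Example invocation (a sketch; the exact schema depends on this repo's configs).
# The OmegaConf config passed via --config must at least define the fields read
# above: ckpt, pretrained_model_path, enable_vae_temporal_decoder, use_dct,
# beta_start, beta_end, beta_schedule, image_size, video_length,
# num_sampling_steps, motion_bucket_id, run_time and save_img_path, plus
# whatever get_models() expects.
#
#   python <this_script>.py --config ./configs/sample.yaml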