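"""
Video editing via motion transfer: the source clip is encoded with the VAE and split into a
base (first-frame) latent plus motion residuals; the motion latents are recovered with DDIM
inversion under the source prompt and then re-rendered on top of an edited base image's latent.
"""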
import os
import torch
import argparse
import torchvision
from pipeline_videogen import VideoGenPipeline
from pipelines.pipeline_inversion import VideoGenInversionPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import sys
sys.path.append(os.path.split(sys.path[0])[0])
from utils import find_model
from models import get_models
import imageio
import decord
import numpy as np
from copy import deepcopy
from PIL import Image
from datasets import video_transforms
from torchvision import transforms
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    with open(path, 'rb') as f:
        image = Image.open(f).convert('RGB')
    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
    image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
    image = image.unsqueeze(2)  # add a frame axis: [B, C, H, W] -> [B, C, 1, H, W]
    return image
def separation_content_motion(video_clip):
    """
    Separate content and motion in a given video.

    Args:
        video_clip: A given video clip, shape [B, C, F, H, W]

    Returns:
        base_frame: Base frame, shape [B, C, 1, H, W]
        motions: Motions relative to the base frame, shape [B, C, F-1, H, W]
    """
    # Select the first frame of each video in the batch as the base frame
    base_frame = video_clip[:, :, :1, :, :]
    # Motion is the difference between each remaining frame and the base frame
    motions = video_clip[:, :, 1:, :, :] - base_frame
    return base_frame, motions
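# Note: because the motions are plain residuals against the first frame, the original clip
# can be reassembled with torch.cat([base_frame, base_frame + motions], dim=2) -> [B, C, F, H, W].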
class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename (str): Path of the video file to load.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
def main(args):
    # torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16

    unet = get_models(args).to(device, dtype=torch.float16)
    state_dict = find_model(args.ckpt)
    unet.load_state_dict(state_dict)

    if args.enable_vae_temporal_decoder:
        if args.use_dct:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
        else:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)
    else:
        vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device, dtype=torch.float64)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

    # set eval mode
    unet.eval()
    vae.eval()
    text_encoder.eval()
    scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                                        subfolder="scheduler",
                                                        beta_start=args.beta_start,
                                                        beta_end=args.beta_end,
                                                        beta_schedule=args.beta_schedule)

    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              # beta_end=0.017,
                                              beta_schedule=args.beta_schedule)

    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)
    videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae,
                                                            text_encoder=text_encoder,
                                                            tokenizer=tokenizer,
                                                            scheduler=scheduler_inversion,
                                                            unet=unet).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()
    # videogen_pipeline.enable_vae_slicing()

    transform_video = video_transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])),  # center crop on the short edge, then resize
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # Load the source video and take its first 16 frames, scaled to [-1, 1]
    # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
    video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
    video_reader = DecordInit()
    video = video_reader(video_path)
    frame_indice = np.linspace(0, 15, 16, dtype=int)
    video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
    video = video / 255.0
    video = video * 2.0 - 1.0

    # Encode the clip and split it into a base (first-frame) latent and motion residuals
    latents = vae.encode(video.to(dtype=torch.float16, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor).unsqueeze(0).permute(0, 2, 1, 3, 4)
    base_content, motion_latents = separation_content_motion(latents)

    # Encode the edited base image that will replace the original content
    # image_path = "./video_editing/a_man_walking_in_the_park.png"
    image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
    edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device)

    if not os.path.exists(args.save_img_path):
        os.makedirs(args.save_img_path)

    # prompt_inversion = 'a man walking on the beach'
    prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
    # DDIM inversion: recover the noise latents of the motion under the source prompt
    latents = videogen_pipeline_inversion(prompt_inversion,
                                          latents=motion_latents,
                                          base_content=base_content,
                                          video_length=args.video_length,
                                          height=args.image_size[0],
                                          width=args.image_size[1],
                                          num_inference_steps=args.num_sampling_steps,
                                          guidance_scale=1.0,
                                          # guidance_scale=args.guidance_scale,
                                          motion_bucket_id=args.motion_bucket_id,
                                          output_type="latent").video

    # Generation: re-render the inverted motion on top of the edited base content
    # prompt = 'a man walking in the park'
    prompt = 'a corgi walking in the park at sunrise, oil painting style'
    videos = videogen_pipeline(prompt,
                               latents=latents,
                               base_content=edit_content,
                               video_length=args.video_length,
                               height=args.image_size[0],
                               width=args.image_size[1],
                               num_inference_steps=args.num_sampling_steps,
                               guidance_scale=1.0,
                               # guidance_scale=args.guidance_scale,
                               motion_bucket_id=args.motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video

    imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8)  # highest quality is 10, lowest is 0
    print('save path {}'.format(args.save_img_path))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/sample.yaml")
    args = parser.parse_args()
    main(OmegaConf.load(args.config))
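# Example invocation (the script filename below is illustrative; adjust it to the repository layout):
#   python sample_video_edit.py --config ./configs/sample.yaml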