"""Warp an input video along a new camera trajectory and re-render it.

The ``TrajCrafter`` class below estimates per-frame depth, forward-warps the
frames to target camera poses, and feeds the warped video plus visibility
masks to a diffusion pipeline.
"""
import gc
import os
import torch
from extern.depthcrafter.infer import DepthCrafterDemo
# from extern.video_depth_anything.vdademo import VDADemo
import numpy as np
from transformers import T5EncoderModel
from omegaconf import OmegaConf
from PIL import Image
from models.crosstransformer3d import CrossTransformer3DModel
from models.autoencoder_magvit import AutoencoderKLCogVideoX
from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
# NOTE: Warper, read_video_frames, save_video, tqdm, F, generate_traj_specified
# and generate_traj_txt are not imported explicitly anywhere in this file, so
# they are expected to come from this star import.
from models.utils import *
from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from transformers import AutoProcessor, Blip2ForConditionalGeneration


class TrajCrafter:
    """Depth-warp a video to new camera poses and refine it with diffusion.

    Public entry points:

    * ``infer_gradual`` - every frame is warped to its own target pose.
    * ``infer_direct``  - the first frame is held while the camera moves for
      ``opts.cut`` steps, then the video plays from the final pose.
    * ``infer_bullet``  - the last frame is held while the camera moves.
    * ``run_gradio``    - ``infer_gradual`` driven by UI-supplied values; needs
      the instance to have been built with ``gradio=True``.

    All of them write ``input.mp4``, ``render.mp4``, ``mask.mp4``, ``gen.mp4``
    and ``viz.mp4`` into ``opts.save_dir``.
    """

    # Width in pixels of the white strip between the two halves of viz.mp4.
    _VIZ_GAP = 30
    # Number of frames handed to the pipeline as ``reference``.
    _NUM_REF_FRAMES = 10

    def __init__(self, opts, gradio=False):
        """Load the warper, depth estimator, captioner and diffusion pipeline.

        Args:
            opts: Options object; the attributes read here are ``device``,
                ``unet_path``, ``pre_train_path``, ``cpu_offload`` and
                ``blip_path`` (plus whatever ``setup_diffusion`` reads).
            gradio: When true, ``opts`` is kept on the instance as
                ``self.opts`` so that ``run_gradio`` can use it.
        """
        self.funwarp = Warper(device=opts.device)
        # self.depth_estimater = VDADemo(pre_train_path=opts.pre_train_path_vda,device=opts.device)
        self.depth_estimater = DepthCrafterDemo(
            unet_path=opts.unet_path,
            pre_train_path=opts.pre_train_path,
            cpu_offload=opts.cpu_offload,
            device=opts.device,
        )
        self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
        self.captioner = Blip2ForConditionalGeneration.from_pretrained(
            opts.blip_path, torch_dtype=torch.float16
        ).to(opts.device)
        self.setup_diffusion(opts)
        if gradio:
            self.opts = opts

    # ------------------------------------------------------------------
    # Public inference entry points
    # ------------------------------------------------------------------

    def infer_gradual(self, opts):
        """Warp frame ``i`` to target pose ``i`` and run the diffusion pipeline.

        Deletes the depth estimator and captioner from the instance before the
        diffusion step to free memory, so the instance cannot be used for a
        second inference afterwards.
        """
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
        warped_images, masks = self._warp_framewise(opts, frames, depths, pose_s, pose_t, K)

        frames, cond_video, cond_masks = self._build_conditions(opts, frames, warped_images, masks)
        self._save_previews(opts, frames, cond_video, cond_masks)

        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self._NUM_REF_FRAMES, :, :]
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)

        self._release_preprocessors()
        sample = self._run_pipeline(
            opts, prompt, generator, cond_video, cond_masks, frames_ref,
            opts.diffusion_inference_steps,
        )
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        self._save_comparison(opts, frames[0], sample[0], fps=opts.fps * 2)

    def infer_direct(self, opts):
        """Move the camera on a frozen first frame, then play the video.

        Sets ``opts.cut = 20`` (mutating ``opts``). For the first ``opts.cut``
        output frames the first input frame is warped to successive target
        poses; afterwards input frame ``i - opts.cut`` is warped to the last
        target pose. The saved videos drop the ``opts.cut`` camera-move frames.

        Deletes the depth estimator and captioner from the instance before the
        diffusion step to free memory.
        """
        opts.cut = 20
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.cut)

        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            if i < opts.cut:
                # Camera-move phase: first frame, i-th target pose.
                warped, mask = self._warp_frame(
                    opts, frames[0:1], depths[0:1], pose_s[0:1], pose_t[i:i + 1], K[0:1]
                )
            else:
                # Playback phase: delayed source frame, final target pose.
                src = i - opts.cut
                warped, mask = self._warp_frame(
                    opts, frames[src:src + 1], depths[src:src + 1], pose_s[0:1], pose_t[-1:], K[0:1]
                )
            warped_images.append(warped)
            masks.append(mask)

        frames, cond_video, cond_masks = self._build_conditions(opts, frames, warped_images, masks)
        kept = opts.video_length - opts.cut
        self._save_previews(opts, frames[:kept], cond_video[opts.cut:], cond_masks[opts.cut:])

        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self._NUM_REF_FRAMES, :, :]
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)

        self._release_preprocessors()
        sample = self._run_pipeline(
            opts, prompt, generator, cond_video, cond_masks, frames_ref,
            opts.diffusion_inference_steps,
        )
        save_video(
            sample[0].permute(1, 2, 3, 0)[opts.cut:],
            os.path.join(opts.save_dir, 'gen.mp4'),
            fps=opts.fps,
        )
        self._save_comparison(
            opts,
            frames[0][:, :kept, ...],
            sample[0][:, opts.cut:, ...],
            fps=opts.fps * 2,
        )

    def infer_bullet(self, opts):
        """Freeze the last input frame and sweep the camera over all poses.

        Deletes the depth estimator and captioner from the instance before the
        diffusion step to free memory.
        """
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)

        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            warped, mask = self._warp_frame(
                opts, frames[-1:], depths[-1:], pose_s[0:1], pose_t[i:i + 1], K[0:1]
            )
            warped_images.append(warped)
            masks.append(mask)

        frames, cond_video, cond_masks = self._build_conditions(opts, frames, warped_images, masks)
        self._save_previews(opts, frames, cond_video, cond_masks)

        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, -self._NUM_REF_FRAMES:, :, :]
        generator = torch.Generator(device=opts.device).manual_seed(opts.seed)

        self._release_preprocessors()
        sample = self._run_pipeline(
            opts, prompt, generator, cond_video, cond_masks, frames_ref,
            opts.diffusion_inference_steps,
        )
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)

        # Left half: the input video followed by its last frame held still.
        # Right half: the input video followed by the generated sweep (whose
        # first frame is skipped). The hold length is taken from the sample so
        # both halves have the same number of frames for any video length.
        tensor_left = frames[0].to(opts.device)
        tensor_right = sample[0].to(opts.device)
        hold = tensor_right.shape[1] - 1
        tensor_left_full = torch.cat(
            [tensor_left, tensor_left[:, -1:, :, :].repeat(1, hold, 1, 1)], dim=1
        )
        tensor_right_full = torch.cat([tensor_left, tensor_right[:, 1:, :, :]], dim=1)
        self._save_comparison(opts, tensor_left_full, tensor_right_full, fps=opts.fps * 4)

    def run_gradio(self, input_video, stride, radius_scale, pose, steps, seed):
        """Run the gradual mode with values supplied by a UI.

        Uses ``self.opts``, which only exists when the instance was created
        with ``gradio=True``. Unlike the ``infer_*`` methods this keeps the
        depth estimator and captioner alive so it can be called repeatedly.

        Args:
            input_video: Path passed to ``read_video_frames``.
            stride: Frame stride passed to ``read_video_frames``.
            radius_scale: Multiplier for the scene radius; converted with
                ``float``.
            pose: String of five ``;``-separated numbers
                ``theta;phi;r;x;y``.
            steps: Number of diffusion inference steps.
            seed: Seed for the diffusion random generator.

        Returns:
            Path of the written ``viz.mp4``.
        """
        opts = self.opts
        frames, depths, prompt = self._load_inputs(opts, input_video, stride)
        num_frames = frames.shape[0]

        radius = self._scene_radius(depths, float(radius_scale))
        K, c2w_init = self._camera_defaults(num_frames, opts.device)
        theta, phi, r, x, y = [float(i) for i in pose.split(';')]
        poses = generate_traj_specified(
            c2w_init, theta, phi, r * radius, x, y, num_frames, opts.device
        )
        pose_s, pose_t = self._anchor_poses(opts, poses, radius, num_frames)

        warped_images, masks = self._warp_framewise(opts, frames, depths, pose_s, pose_t, K)
        frames, cond_video, cond_masks = self._build_conditions(opts, frames, warped_images, masks)
        self._save_previews(opts, frames, cond_video, cond_masks)

        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self._NUM_REF_FRAMES, :, :]
        # int(): tolerate numeric values that arrive as float from a UI widget.
        generator = torch.Generator(device=opts.device).manual_seed(int(seed))

        torch.cuda.empty_cache()
        sample = self._run_pipeline(
            opts, prompt, generator, cond_video, cond_masks, frames_ref, int(steps)
        )
        save_video(sample[0].permute(1, 2, 3, 0), os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        return self._save_comparison(opts, frames[0], sample[0], fps=opts.fps * 2)

    # ------------------------------------------------------------------
    # Captioning, camera poses, model setup
    # ------------------------------------------------------------------

    def get_caption(self, opts, image):
        """Caption one frame with BLIP-2 and append ``opts.refine_prompt``.

        Args:
            opts: Options object; reads ``device`` and ``refine_prompt``.
            image: NumPy image; it is multiplied by 255 and cast to ``uint8``,
                so values are presumably expected in ``[0, 1]``.

        Returns:
            The stripped caption text concatenated with ``opts.refine_prompt``.
        """
        image_array = (image * 255).astype(np.uint8)
        pil_image = Image.fromarray(image_array)
        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(
            opts.device, torch.float16
        )
        generated_ids = self.captioner.generate(**inputs)
        generated_text = self.caption_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return generated_text + opts.refine_prompt

    def get_poses(self, opts, depths, num_frames):
        """Build source poses, target poses and intrinsics for warping.

        Args:
            opts: Options object; reads ``radius_scale``, ``device``,
                ``camera``, ``anchor_idx`` and, depending on ``camera``,
                ``target_pose`` (five values) or ``traj_txt`` (path to a text
                file whose first three lines hold whitespace-separated theta,
                phi and r values).
            depths: Depth tensor indexed as ``depths[0, 0, h, w]``.
            num_frames: Number of poses to generate.

        Returns:
            ``(pose_s, pose_t, K)``: the anchor pose repeated ``num_frames``
            times, the generated trajectory, and the repeated 3x3 intrinsics.

        Raises:
            ValueError: If ``opts.camera`` is neither ``'target'`` nor
                ``'traj'``.
        """
        radius = self._scene_radius(depths, opts.radius_scale)
        K, c2w_init = self._camera_defaults(num_frames, opts.device)

        if opts.camera == 'target':
            dtheta, dphi, dr, dx, dy = opts.target_pose
            poses = generate_traj_specified(
                c2w_init, dtheta, dphi, dr * radius, dx, dy, num_frames, opts.device
            )
        elif opts.camera == 'traj':
            with open(opts.traj_txt, 'r') as file:
                lines = file.readlines()
            theta = [float(i) for i in lines[0].split()]
            phi = [float(i) for i in lines[1].split()]
            r = [float(i) * radius for i in lines[2].split()]
            poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``poses``.
            raise ValueError(
                f"Unsupported camera mode {opts.camera!r}; expected 'target' or 'traj'"
            )

        pose_s, pose_t = self._anchor_poses(opts, poses, radius, num_frames)
        return pose_s, pose_t, K

    def setup_diffusion(self, opts):
        """Load the diffusion components and store them as ``self.pipeline``.

        Reads ``opts.transformer_path``, ``opts.model_name``,
        ``opts.weight_dtype``, ``opts.sampler_name`` and
        ``opts.low_gpu_memory_mode``.

        Raises:
            KeyError: If ``opts.sampler_name`` is not one of the names in the
                scheduler table below.
        """
        # transformer = CrossTransformer3DModel.from_pretrained_cus(opts.transformer_path).to(opts.weight_dtype)
        transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path).to(
            opts.weight_dtype
        )
        vae = AutoencoderKLCogVideoX.from_pretrained(opts.model_name, subfolder="vae").to(
            opts.weight_dtype
        )
        text_encoder = T5EncoderModel.from_pretrained(
            opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype
        )

        schedulers = {
            "Euler": EulerDiscreteScheduler,
            "Euler A": EulerAncestralDiscreteScheduler,
            "DPM++": DPMSolverMultistepScheduler,
            "PNDM": PNDMScheduler,
            "DDIM_Cog": CogVideoXDDIMScheduler,
            "DDIM_Origin": DDIMScheduler,
        }
        scheduler_cls = schedulers[opts.sampler_name]
        scheduler = scheduler_cls.from_pretrained(opts.model_name, subfolder="scheduler")

        self.pipeline = TrajCrafter_Pipeline.from_pretrained(
            opts.model_name,
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            torch_dtype=opts.weight_dtype,
        )

        if opts.low_gpu_memory_mode:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline.enable_model_cpu_offload()

    # ------------------------------------------------------------------
    # Private helpers shared by the entry points
    # ------------------------------------------------------------------

    def _load_inputs(self, opts, video_path, stride):
        """Read the video, caption its middle frame and estimate depth.

        Returns ``(frames, depths, prompt)`` where ``frames`` is a tensor on
        ``opts.device`` permuted from channel-last to channel-first and mapped
        through ``x * 2 - 1``.
        """
        frames = read_video_frames(video_path, opts.video_length, stride, opts.max_res)
        prompt = self.get_caption(opts, frames[opts.video_length // 2])
        # VDADemo variant: self.depth_estimater.infer(frames, opts.near, opts.far)
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)
        # e.g. 49 576 1024 3 -> 49 3 576 1024, [-1,1]
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2. - 1.
        assert frames.shape[0] == opts.video_length, (
            f"expected {opts.video_length} frames, got {frames.shape[0]}"
        )
        return frames, depths, prompt

    @staticmethod
    def _scene_radius(depths, radius_scale):
        """Return the centre-pixel depth of the first frame times the scale, capped at 5."""
        centre_depth = depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu()
        return min(centre_depth * radius_scale, 5)

    @staticmethod
    def _camera_defaults(num_frames, device):
        """Return the repeated intrinsics ``K`` and the initial camera-to-world pose."""
        # Fixed principal point and focal length. The original author's notes
        # suggest cx/cy stand for depths.shape[-1]//2 and depths.shape[-2]//2,
        # i.e. a 1024x576 frame - TODO confirm before generalising.
        cx = 512.
        cy = 288.
        f = 500
        K = (
            torch.tensor([[f, 0., cx], [0., f, cy], [0., 0., 1.]])
            .repeat(num_frames, 1, 1)
            .to(device)
        )
        c2w_init = (
            torch.tensor(
                [
                    [-1., 0., 0., 0.],
                    [0., 1., 0., 0.],
                    [0., 0., -1., 0.],
                    [0., 0., 0., 1.],
                ]
            )
            .to(device)
            .unsqueeze(0)
        )
        return K, c2w_init

    @staticmethod
    def _anchor_poses(opts, poses, radius, num_frames):
        """Offset the trajectory by ``radius`` (in place) and pick the anchor pose.

        Returns ``(pose_s, pose_t)``: the pose at ``opts.anchor_idx`` repeated
        ``num_frames`` times, and the full offset trajectory.
        """
        poses[:, 2, 3] = poses[:, 2, 3] + radius
        pose_s = poses[opts.anchor_idx:opts.anchor_idx + 1].repeat(num_frames, 1, 1)
        return pose_s, poses

    def _warp_frame(self, opts, frame, depth, pose_s, pose_t, K):
        """Forward-warp one frame and return ``(warped_frame, mask)``."""
        warped_frame, mask, _warped_depth, _flow = self.funwarp.forward_warp(
            frame, None, depth, pose_s, pose_t, K, None, opts.mask, twice=False
        )
        return warped_frame, mask

    def _warp_framewise(self, opts, frames, depths, pose_s, pose_t, K):
        """Warp frame ``i`` with pose/intrinsics ``i`` for every frame of the video."""
        warped_images = []
        masks = []
        for i in tqdm(range(opts.video_length)):
            warped, mask = self._warp_frame(
                opts,
                frames[i:i + 1],
                depths[i:i + 1],
                pose_s[i:i + 1],
                pose_t[i:i + 1],
                K[i:i + 1],
            )
            warped_images.append(warped)
            masks.append(mask)
        return warped_images, masks

    @staticmethod
    def _build_conditions(opts, frames, warped_images, masks):
        """Stack the warped frames/masks and resize everything to ``opts.sample_size``.

        Returns ``(frames, cond_video, cond_masks)``; ``cond_video`` has been
        mapped through ``(x + 1) / 2`` while ``frames`` keeps its input range.
        """
        cond_video = (torch.cat(warped_images) + 1.) / 2.
        cond_masks = torch.cat(masks)
        frames = F.interpolate(
            frames, size=opts.sample_size, mode='bilinear', align_corners=False
        )
        cond_video = F.interpolate(
            cond_video, size=opts.sample_size, mode='bilinear', align_corners=False
        )
        # Nearest keeps the mask values unblended.
        cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
        return frames, cond_video, cond_masks

    @staticmethod
    def _save_previews(opts, frames, cond_video, cond_masks):
        """Write ``input.mp4``, ``render.mp4`` and ``mask.mp4`` to ``opts.save_dir``."""
        save_video(
            (frames.permute(0, 2, 3, 1) + 1.) / 2.,
            os.path.join(opts.save_dir, 'input.mp4'),
            fps=opts.fps,
        )
        save_video(
            cond_video.permute(0, 2, 3, 1),
            os.path.join(opts.save_dir, 'render.mp4'),
            fps=opts.fps,
        )
        save_video(
            cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1),
            os.path.join(opts.save_dir, 'mask.mp4'),
            fps=opts.fps,
        )

    @staticmethod
    def _to_pipeline_layout(frames, cond_video, cond_masks):
        """Swap the first two axes, add a leading batch axis and rescale.

        ``frames`` is mapped through ``(x + 1) / 2``; the mask is inverted and
        scaled by 255 (``(1 - mask) * 255``).
        """
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.) / 2.
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1. - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.
        return frames, cond_video, cond_masks

    def _release_preprocessors(self):
        """Drop the depth estimator and captioner and release cached GPU memory."""
        del self.depth_estimater
        del self.caption_processor
        del self.captioner
        gc.collect()
        torch.cuda.empty_cache()

    def _run_pipeline(self, opts, prompt, generator, cond_video, cond_masks,
                      frames_ref, num_inference_steps):
        """Call the diffusion pipeline without gradients and return ``.videos``."""
        with torch.no_grad():
            return self.pipeline(
                prompt,
                num_frames=opts.video_length,
                negative_prompt=opts.negative_prompt,
                height=opts.sample_size[0],
                width=opts.sample_size[1],
                generator=generator,
                guidance_scale=opts.diffusion_guidance_scale,
                num_inference_steps=num_inference_steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos

    def _save_comparison(self, opts, left, right, fps):
        """Write ``viz.mp4``: ``left`` and ``right`` side by side, played forward then backward.

        ``left`` and ``right`` are indexed as (channels, frames, height, width)
        and must agree on every axis except width. The white separator takes
        its size from ``left`` instead of the former hard-coded 3x49x384, so
        other video lengths and sample sizes work as well.

        Returns:
            Path of the written file.
        """
        left = left.to(opts.device)
        right = right.to(opts.device)
        gap = torch.ones(
            left.shape[0], left.shape[1], left.shape[2], self._VIZ_GAP,
            device=opts.device,
        )
        forward = torch.cat((left, gap, right), dim=3)
        backward = torch.flip(forward, dims=[1])
        # Skip the duplicated turning-point frame when appending the reverse.
        looped = torch.cat((forward, backward[:, 1:, :, :]), dim=1)
        out_path = os.path.join(opts.save_dir, 'viz.mp4')
        save_video(looped.permute(1, 2, 3, 0), out_path, fps=fps)
        return out_path