# Hugging Face Spaces page header: "Running on Zero" (scraping residue, kept as a comment).
import gc
import os

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from omegaconf import OmegaConf
from PIL import Image
from transformers import (AutoProcessor, Blip2ForConditionalGeneration,
                          T5EncoderModel)

from extern.depthcrafter.infer import DepthCrafterDemo
# from extern.video_depth_anything.vdademo import VDADemo
from models.autoencoder_magvit import AutoencoderKLCogVideoX
from models.crosstransformer3d import CrossTransformer3DModel
from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
# Presumably supplies Warper, read_video_frames, save_video,
# generate_traj_specified, generate_traj_txt and tqdm, which TrajCrafter uses
# without importing them explicitly.
from models.utils import *
class TrajCrafter:
    """Re-render an input video along a new camera trajectory.

    The stages wired together here are:

    * per-frame depth estimation (``DepthCrafterDemo``);
    * a BLIP-2 caption of the middle frame, used as the text prompt;
    * forward warping of frames to target camera poses (``Warper``);
    * diffusion completion of the warped video (``TrajCrafter_Pipeline``).

    The ``infer_*`` entry points free the depth estimator and captioner
    before diffusion, so an instance can serve only one of those calls.
    ``run_gradio`` keeps them loaded and can be called repeatedly.
    """

    # ``infer_direct``: number of leading output frames that hold input
    # frame 0 while the camera moves from the anchor pose to the target pose.
    DIRECT_CUT = 20
    # Width in pixels of the white separator between the panels of viz.mp4.
    VIZ_GAP = 30
    # Number of input frames handed to the pipeline as ``reference``.
    NUM_REF_FRAMES = 10

    def __init__(self, opts, gradio=False):
        """Load the warper, depth estimator, captioner and diffusion pipeline.

        Args:
            opts: Configuration namespace. Fields read here are ``device``,
                ``unet_path``, ``pre_train_path``, ``cpu_offload`` and
                ``blip_path``, plus the fields read by ``setup_diffusion``.
            gradio: If True, keep ``opts`` on the instance as ``self.opts``
                so that ``run_gradio`` can use it.
        """
        self.funwarp = Warper(device=opts.device)
        # Alternative (disabled) depth backend:
        # VDADemo(pre_train_path=opts.pre_train_path_vda, device=opts.device)
        self.depth_estimater = DepthCrafterDemo(
            unet_path=opts.unet_path,
            pre_train_path=opts.pre_train_path,
            cpu_offload=opts.cpu_offload,
            device=opts.device,
        )
        self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
        self.captioner = Blip2ForConditionalGeneration.from_pretrained(
            opts.blip_path, torch_dtype=torch.float16
        ).to(opts.device)
        self.setup_diffusion(opts)
        if gradio:
            self.opts = opts

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def infer_gradual(self, opts):
        """Warp frame ``i`` from the anchor pose to trajectory pose ``i``.

        Writes input.mp4, render.mp4, mask.mp4, gen.mp4 and viz.mp4 into
        ``opts.save_dir``. Frees the depth estimator and captioner.
        """
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
        warped = [
            self._warp_frame(frames[i:i + 1], depths[i:i + 1], pose_s[i:i + 1],
                             pose_t[i:i + 1], K[i:i + 1], opts.mask)
            for i in tqdm(range(opts.video_length))
        ]
        frames, cond_video, cond_masks = self._resize_conditions(opts, frames, warped)
        self._save_conditions(opts, frames, cond_video, cond_masks)
        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self.NUM_REF_FRAMES]
        self._release_preprocessors()
        sample = self._generate(opts, prompt, cond_video, cond_masks, frames_ref,
                                opts.diffusion_inference_steps, opts.seed)
        save_video(sample[0].permute(1, 2, 3, 0),
                   os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        self._save_viz(frames[0].to(opts.device), sample[0].to(opts.device),
                       os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 2)

    def infer_direct(self, opts):
        """Jump straight to the final pose after a short camera move.

        Sets ``opts.cut`` to ``DIRECT_CUT``. The first ``cut`` output frames
        warp input frame 0 along a ``cut``-step trajectory. Each later output
        frame ``i`` warps input frame ``i - cut`` with the trajectory's last
        pose. The saved render/mask/gen videos drop the first ``cut`` frames,
        and input.mp4 holds the first ``video_length - cut`` input frames.
        Frees the depth estimator and captioner.
        """
        opts.cut = self.DIRECT_CUT
        cut = opts.cut
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=cut)
        warped = []
        for i in tqdm(range(opts.video_length)):
            if i < cut:
                # Opening segment: hold frame 0 and sweep the camera.
                src, target_pose = 0, pose_t[i:i + 1]
            else:
                # Remainder: replay the input from its start at the final pose.
                src, target_pose = i - cut, pose_t[-1:]
            warped.append(self._warp_frame(
                frames[src:src + 1], depths[src:src + 1], pose_s[0:1],
                target_pose, K[0:1], opts.mask))
        frames, cond_video, cond_masks = self._resize_conditions(opts, frames, warped)
        kept = opts.video_length - cut
        self._save_conditions(opts, frames[:kept], cond_video[cut:], cond_masks[cut:])
        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self.NUM_REF_FRAMES]
        self._release_preprocessors()
        sample = self._generate(opts, prompt, cond_video, cond_masks, frames_ref,
                                opts.diffusion_inference_steps, opts.seed)
        save_video(sample[0].permute(1, 2, 3, 0)[cut:],
                   os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        self._save_viz(frames[0][:, :kept].to(opts.device),
                       sample[0][:, cut:].to(opts.device),
                       os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 2)

    def infer_bullet(self, opts):
        """Warp the last input frame along the whole trajectory (bullet time).

        Uses the last ``NUM_REF_FRAMES`` input frames as the reference.
        viz.mp4 shows the input followed by a freeze on its last frame
        (left) next to the input followed by the generated orbit (right).
        Frees the depth estimator and captioner.
        """
        frames, depths, prompt = self._load_inputs(opts, opts.video_path, opts.stride)
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
        warped = [
            self._warp_frame(frames[-1:], depths[-1:], pose_s[0:1],
                             pose_t[i:i + 1], K[0:1], opts.mask)
            for i in tqdm(range(opts.video_length))
        ]
        frames, cond_video, cond_masks = self._resize_conditions(opts, frames, warped)
        self._save_conditions(opts, frames, cond_video, cond_masks)
        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, -self.NUM_REF_FRAMES:]
        self._release_preprocessors()
        sample = self._generate(opts, prompt, cond_video, cond_masks, frames_ref,
                                opts.diffusion_inference_steps, opts.seed)
        save_video(sample[0].permute(1, 2, 3, 0),
                   os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        left = frames[0].to(opts.device)
        right = sample[0].to(opts.device)
        num_frames = left.shape[1]
        # Both panels get num_frames + (num_frames - 1) frames. The sizes used
        # to be hard-coded for 49 frames.
        left_full = torch.cat([left, left[:, -1:].repeat(1, num_frames - 1, 1, 1)], dim=1)
        right_full = torch.cat([left, right[:, 1:]], dim=1)
        self._save_viz(left_full, right_full,
                       os.path.join(opts.save_dir, 'viz.mp4'), fps=opts.fps * 4)

    def get_caption(self, opts, image):
        """Caption one frame with BLIP-2 and append ``opts.refine_prompt``.

        ``image`` is scaled by 255 and cast to uint8. It is presumably an
        (H, W, 3) float array in [0, 1]; TODO confirm against read_video_frames.
        """
        pil_image = Image.fromarray((image * 255).astype(np.uint8))
        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(
            opts.device, torch.float16)
        generated_ids = self.captioner.generate(**inputs)
        caption = self.caption_processor.batch_decode(
            generated_ids, skip_special_tokens=True)[0].strip()
        return caption + opts.refine_prompt

    def get_poses(self, opts, depths, num_frames):
        """Build source poses, target poses and intrinsics for warping.

        Args:
            opts: Reads ``radius_scale``, ``device``, ``camera`` and
                ``anchor_idx``. With ``camera == 'target'`` it also reads
                ``target_pose`` (dtheta, dphi, dr, dx, dy). With
                ``camera == 'traj'`` it reads ``traj_txt``, a text file whose
                first three lines hold theta, phi and r values.
            depths: Depth tensor indexed as ``depths[t, 0, y, x]``. The
                centre depth of frame 0 sets the orbit radius, capped at 5.
            num_frames: Number of poses to generate.

        Returns:
            ``(pose_s, pose_t, K)``. ``pose_s`` repeats the anchor pose,
            ``pose_t`` is the trajectory, and ``K`` holds ``num_frames``
            copies of the 3x3 intrinsics.

        Raises:
            ValueError: If ``opts.camera`` is not ``'target'`` or ``'traj'``.
        """
        radius, K, c2w_init = self._camera_setup(
            depths, opts.radius_scale, num_frames, opts.device)
        if opts.camera == 'target':
            dtheta, dphi, dr, dx, dy = opts.target_pose
            poses = generate_traj_specified(
                c2w_init, dtheta, dphi, dr * radius, dx, dy, num_frames, opts.device)
        elif opts.camera == 'traj':
            with open(opts.traj_txt, 'r') as fh:
                lines = fh.readlines()
            theta = [float(v) for v in lines[0].split()]
            phi = [float(v) for v in lines[1].split()]
            r = [float(v) * radius for v in lines[2].split()]
            poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
        else:
            raise ValueError(
                f"Unknown camera mode {opts.camera!r}; expected 'target' or 'traj'")
        pose_s, pose_t = self._anchor_poses(poses, radius, opts.anchor_idx, num_frames)
        return pose_s, pose_t, K

    def setup_diffusion(self, opts):
        """Load the transformer, VAE, T5 encoder and scheduler into ``self.pipeline``.

        Reads ``transformer_path``, ``model_name``, ``weight_dtype``,
        ``sampler_name`` (a key of the scheduler table below; other values
        raise KeyError) and ``low_gpu_memory_mode``.
        """
        transformer = CrossTransformer3DModel.from_pretrained(
            opts.transformer_path).to(opts.weight_dtype)
        vae = AutoencoderKLCogVideoX.from_pretrained(
            opts.model_name, subfolder="vae").to(opts.weight_dtype)
        text_encoder = T5EncoderModel.from_pretrained(
            opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype)
        scheduler_cls = {
            "Euler": EulerDiscreteScheduler,
            "Euler A": EulerAncestralDiscreteScheduler,
            "DPM++": DPMSolverMultistepScheduler,
            "PNDM": PNDMScheduler,
            "DDIM_Cog": CogVideoXDDIMScheduler,
            "DDIM_Origin": DDIMScheduler,
        }[opts.sampler_name]
        scheduler = scheduler_cls.from_pretrained(opts.model_name, subfolder="scheduler")
        self.pipeline = TrajCrafter_Pipeline.from_pretrained(
            opts.model_name,
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            torch_dtype=opts.weight_dtype,
        )
        if opts.low_gpu_memory_mode:
            self.pipeline.enable_sequential_cpu_offload()
        else:
            self.pipeline.enable_model_cpu_offload()

    def run_gradio(self, input_video, stride, radius_scale, pose, steps, seed):
        """Gradio callback: render ``input_video`` along a 'target' camera move.

        Requires the instance to have been built with ``gradio=True``.

        Args:
            input_video: Path passed to ``read_video_frames``.
            stride: Frame stride passed to ``read_video_frames``.
            radius_scale: Multiplier on the centre depth (cast to float).
            pose: String ``"theta;phi;r;x;y"``.
            steps: Number of diffusion steps (cast to int).
            seed: Generator seed (cast to int, because
                ``torch.Generator.manual_seed`` rejects floats).

        Returns:
            Path of the written viz.mp4.
        """
        opts = self.opts
        # Parse first so that a malformed pose fails before the depth pass.
        theta, phi, r, x, y = [float(v) for v in pose.split(';')]
        frames, depths, prompt = self._load_inputs(opts, input_video, stride)
        num_frames = frames.shape[0]
        radius, K, c2w_init = self._camera_setup(
            depths, float(radius_scale), num_frames, opts.device)
        poses = generate_traj_specified(
            c2w_init, theta, phi, r * radius, x, y, num_frames, opts.device)
        pose_s, pose_t = self._anchor_poses(poses, radius, opts.anchor_idx, num_frames)
        warped = [
            self._warp_frame(frames[i:i + 1], depths[i:i + 1], pose_s[i:i + 1],
                             pose_t[i:i + 1], K[i:i + 1], opts.mask)
            for i in tqdm(range(opts.video_length))
        ]
        frames, cond_video, cond_masks = self._resize_conditions(opts, frames, warped)
        self._save_conditions(opts, frames, cond_video, cond_masks)
        frames, cond_video, cond_masks = self._to_pipeline_layout(frames, cond_video, cond_masks)
        frames_ref = frames[:, :, :self.NUM_REF_FRAMES]
        # Preprocessors stay loaded so that the callback can run again.
        torch.cuda.empty_cache()
        sample = self._generate(opts, prompt, cond_video, cond_masks, frames_ref,
                                int(steps), int(seed))
        save_video(sample[0].permute(1, 2, 3, 0),
                   os.path.join(opts.save_dir, 'gen.mp4'), fps=opts.fps)
        viz_path = os.path.join(opts.save_dir, 'viz.mp4')
        self._save_viz(frames[0].to(opts.device), sample[0].to(opts.device),
                       viz_path, fps=opts.fps * 2)
        return viz_path

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _load_inputs(self, opts, video_path, stride):
        """Read frames, caption the middle frame, and estimate depth.

        Returns ``(frames, depths, prompt)``. ``frames`` is permuted from
        (T, H, W, 3) to (T, 3, H, W), mapped x -> 2x - 1 (presumably
        [0, 1] -> [-1, 1]) and moved to ``opts.device``.
        """
        frames = read_video_frames(video_path, opts.video_length, stride, opts.max_res)
        prompt = self.get_caption(opts, frames[opts.video_length // 2])
        depths = self.depth_estimater.infer(
            frames, opts.near, opts.far, opts.depth_inference_steps,
            opts.depth_guidance_scale, window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2. - 1.
        assert frames.shape[0] == opts.video_length
        return frames, depths, prompt

    def _warp_frame(self, frame, depth, pose_src, pose_tgt, intrinsics, mask_opt):
        """Forward-warp one frame and return ``(warped_frame, mask)``."""
        warped, mask, _, _ = self.funwarp.forward_warp(
            frame, None, depth, pose_src, pose_tgt, intrinsics, None, mask_opt,
            twice=False)
        return warped, mask

    @staticmethod
    def _resize_conditions(opts, frames, warped):
        """Stack ``(frame, mask)`` pairs and resize everything to ``opts.sample_size``.

        The warped video is mapped x -> (x + 1) / 2. Masks are resized with
        nearest-neighbour interpolation so their values are not blended.
        """
        cond_video = (torch.cat([w for w, _ in warped]) + 1.) / 2.
        cond_masks = torch.cat([m for _, m in warped])
        frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear',
                               align_corners=False)
        cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear',
                                   align_corners=False)
        cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
        return frames, cond_video, cond_masks

    @staticmethod
    def _save_conditions(opts, frames, cond_video, cond_masks):
        """Write input.mp4, render.mp4 and mask.mp4 from (T, C, H, W) tensors."""
        save_video((frames.permute(0, 2, 3, 1) + 1.) / 2.,
                   os.path.join(opts.save_dir, 'input.mp4'), fps=opts.fps)
        save_video(cond_video.permute(0, 2, 3, 1),
                   os.path.join(opts.save_dir, 'render.mp4'), fps=opts.fps)
        save_video(cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1),
                   os.path.join(opts.save_dir, 'mask.mp4'), fps=opts.fps)

    @staticmethod
    def _to_pipeline_layout(frames, cond_video, cond_masks):
        """Reorder (T, C, H, W) tensors to (1, C, T, H, W) for the pipeline.

        Frames are mapped x -> (x + 1) / 2. Masks become (1 - m) * 255, so
        the mask polarity is inverted.
        """
        frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.) / 2.
        cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
        cond_masks = (1. - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.
        return frames, cond_video, cond_masks

    def _release_preprocessors(self):
        """Drop the depth estimator and captioner to free memory before diffusion."""
        del self.depth_estimater
        del self.caption_processor
        del self.captioner
        gc.collect()
        torch.cuda.empty_cache()

    def _generate(self, opts, prompt, cond_video, cond_masks, frames_ref,
                  num_inference_steps, seed):
        """Run the diffusion pipeline with a freshly seeded generator."""
        generator = torch.Generator(device=opts.device).manual_seed(seed)
        with torch.no_grad():
            return self.pipeline(
                prompt,
                num_frames=opts.video_length,
                negative_prompt=opts.negative_prompt,
                height=opts.sample_size[0],
                width=opts.sample_size[1],
                generator=generator,
                guidance_scale=opts.diffusion_guidance_scale,
                num_inference_steps=num_inference_steps,
                video=cond_video,
                mask_video=cond_masks,
                reference=frames_ref,
            ).videos

    @staticmethod
    def _camera_setup(depths, radius_scale, num_frames, device):
        """Return ``(radius, K, c2w_init)``, shared by get_poses and run_gradio."""
        radius = depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu() * radius_scale
        radius = min(radius, 5)
        # NOTE(review): the principal point is fixed at (512, 288), i.e. the
        # centre of a 1024x576 frame, whatever the actual depth-map size.
        cx, cy, f = 512., 288., 500
        K = torch.tensor([[f, 0., cx], [0., f, cy], [0., 0., 1.]]).repeat(
            num_frames, 1, 1).to(device)
        c2w_init = torch.tensor([[-1., 0., 0., 0.],
                                 [0., 1., 0., 0.],
                                 [0., 0., -1., 0.],
                                 [0., 0., 0., 1.]]).to(device).unsqueeze(0)
        return radius, K, c2w_init

    @staticmethod
    def _anchor_poses(poses, radius, anchor_idx, num_frames):
        """Shift poses by ``radius`` along z (in place) and repeat the anchor pose.

        Returns ``(pose_s, pose_t)``. ``pose_t`` is the shifted ``poses``
        object itself.
        """
        poses[:, 2, 3] = poses[:, 2, 3] + radius
        pose_s = poses[anchor_idx:anchor_idx + 1].repeat(num_frames, 1, 1)
        return pose_s, poses

    @staticmethod
    def _side_by_side(left, right, gap=VIZ_GAP):
        """Concatenate two (C, T, H, W) videos along width with a white gap.

        The gap is sized from ``left``, so any frame count or height works.
        """
        channels, num_frames, height, _ = left.shape
        separator = torch.ones(channels, num_frames, height, gap).to(left.device)
        return torch.cat((left, separator, right), dim=3)

    @staticmethod
    def _boomerang(video):
        """Append the time-reversed video without repeating the last frame."""
        return torch.cat((video, torch.flip(video, dims=[1])[:, 1:]), dim=1)

    def _save_viz(self, left, right, path, fps):
        """Write a side-by-side, forward-then-backward comparison video."""
        result = self._boomerang(self._side_by_side(left, right))
        save_video(result.permute(1, 2, 3, 0), path, fps=fps)