import spaces
import argparse
import torch
import tempfile
import os
import cv2
import numpy as np
import gradio as gr
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import matplotlib as mpl
from omegaconf import OmegaConf
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from inference_cameractrl import get_relative_pose, ray_condition, get_pipeline
from cameractrl.utils.util import save_videos_grid

cv2.setNumThreads(1)
mpl.use('agg')

#### Description ####
title = r"""

CameraCtrl: Enabling Camera Control for Video Diffusion Models

""" subtitle = r"""

CameraCtrl Image2Video with Stable Video Diffusion (SVD) model

""" description = r""" Official Gradio demo for CameraCtrl: Enabling Camera Control for Video Diffusion Models.
CameraCtrl is capable of precisely controlling the camera trajectory during the video generation process.
Note that, with SVD, CameraCtrl currently only supports Image2Video.
""" closing_words = r""" --- If you are interested in this demo or CameraCtrl is helpful for you, please give us a ⭐ of the CameraCtrl Github Repo ! [![GitHub Stars](https://img.shields.io/github/stars/hehao13/CameraCtrl )](https://github.com/hehao13/CameraCtrl) --- 📝 **Citation**
If you find our paper or code is useful for your research, please consider citing: ```bibtex @article{he2024cameractrl, title={CameraCtrl: Enabling Camera Control for Text-to-Video Generation}, author={Hao He and Yinghao Xu and Yuwei Guo and Gordon Wetzstein and Bo Dai and Hongsheng Li and Ceyuan Yang}, journal={arXiv preprint arXiv:2404.02101}, year={2024} } ``` 📧 **Contact**
If you have any questions, please feel free to contact me at haohe@link.cuhk.edu.hk. **Acknowledgement**
We thank MotionCtrl and IC-Light for their gradio codes.
""" RESIZE_MODES = ['Resize then Center Crop', 'Directly resize'] CAMERA_TRAJECTORY_MODES = ["Provided Camera Trajectories", "Custom Camera Trajectories"] height = 320 width = 576 num_frames = 14 device = "cuda" if torch.cuda.is_available() else "cpu" config = "configs/train_cameractrl/svd_320_576_cameractrl.yaml" model_id = "stabilityai/stable-video-diffusion-img2vid" ckpt = "checkpoints/CameraCtrl_svdxt.ckpt" if not os.path.exists(ckpt): os.makedirs("checkpoints", exist_ok=True) os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true") os.system("mv CameraCtrl_svd.ckpt?download=true checkpoints/CameraCtrl_svdxt.ckpt") model_config = OmegaConf.load(config) pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'], model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'], ckpt, True, device) examples = [ [ "assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png", "assets/pose_files/0bf152ef84195293.txt", "Trajectory 1" ], [ "assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png", "assets/pose_files/0c9b371cc6225682.txt", "Trajectory 2" ], [ "assets/example_condition_images/Rocky_coastline_with_crashing_waves..png", "assets/pose_files/0c11dbe781b1c11c.txt", "Trajectory 3" ], [ "assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png", "assets/pose_files/0f47577ab3441480.txt", "Trajectory 4" ], [ "assets/example_condition_images/An_exploding_cheese_house..png", "assets/pose_files/0f47577ab3441480.txt", "Trajectory 4" ], [ "assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png", "assets/pose_files/0f68374b76390082.txt", "Trajectory 5" ], [ "assets/example_condition_images/Leaves_are_falling_from_trees..png", "assets/pose_files/2c80f9eb0d3b2bb4.txt", "Trajectory 6" ], [ "assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png", "assets/pose_files/2f25826f0d0ef09a.txt", "Trajectory 7" ], [ "assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png", "assets/pose_files/3f79dc32d575bcdc.txt", "Trajectory 8" ], [ "assets/example_condition_images/A_car_running_on_Mars..png", "assets/pose_files/4a2d6753676df096.txt", "Trajectory 9" ], ] class Camera(object): def __init__(self, entry): fx, fy, cx, cy = entry[1:5] self.fx = fx self.fy = fy self.cx = cx self.cy = cy w2c_mat = np.array(entry[7:]).reshape(3, 4) w2c_mat_4x4 = np.eye(4) w2c_mat_4x4[:3, :] = w2c_mat self.w2c_mat = w2c_mat_4x4 self.c2w_mat = np.linalg.inv(w2c_mat_4x4) class CameraPoseVisualizer: def __init__(self, xlim, ylim, zlim): self.fig = plt.figure(figsize=(18, 7)) self.ax = self.fig.add_subplot(projection='3d') self.plotly_data = None # plotly data traces self.ax.set_aspect("auto") self.ax.set_xlim(xlim) self.ax.set_ylim(ylim) self.ax.set_zlim(zlim) self.ax.set_xlabel('x') self.ax.set_ylabel('y') self.ax.set_zlabel('z') def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3): vertex_std = np.array([[0, 0, 0, 1], [base_xval, -base_xval * hw_ratio, zval, 1], [base_xval, base_xval * hw_ratio, zval, 1], [-base_xval, base_xval * hw_ratio, zval, 1], [-base_xval, -base_xval * hw_ratio, zval, 1]]) vertex_transformed = vertex_std @ extrinsic.T meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, 
class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)


class CameraPoseVisualizer:
    def __init__(self, xlim, ylim, zlim):
        self.fig = plt.figure(figsize=(18, 7))
        self.ax = self.fig.add_subplot(projection='3d')
        self.plotly_data = None  # plotly data traces
        self.ax.set_aspect("auto")
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        self.ax.set_zlim(zlim)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')

    def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3):
        # Draw one camera as a view-frustum pyramid: apex at the camera center,
        # base rectangle zval units along the optical axis.
        vertex_std = np.array([[0, 0, 0, 1],
                               [base_xval, -base_xval * hw_ratio, zval, 1],
                               [base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, -base_xval * hw_ratio, zval, 1]])
        vertex_transformed = vertex_std @ extrinsic.T
        meshes = [[vertex_transformed[0, :-1], vertex_transformed[1, :-1], vertex_transformed[2, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
                  [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
                   vertex_transformed[4, :-1]]]
        color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))

    def colorbar(self, max_frame_length):
        cmap = mpl.cm.rainbow
        norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
        self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical',
                          label='Frame Indexes')

    def show(self):
        plt.title('Camera Trajectory')
        plt.show()


def get_c2w(w2cs):
    # Convert a list of 4x4 world-to-camera matrices into camera-to-world poses relative to the
    # first frame, rescale the translations so the largest axis range is 1, and rotate the
    # coordinate frame for plotting.
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
    camera_positions = np.asarray([c2w[:3, 3] for c2w in ret_poses])  # [n_frame, 3]
    position_distances = [camera_positions[i] - camera_positions[i - 1] for i in range(1, len(camera_positions))]
    xyz_max = np.max(camera_positions, axis=0)
    xyz_min = np.min(camera_positions, axis=0)
    xyz_ranges = xyz_max - xyz_min  # [3, ]
    max_range = np.max(xyz_ranges)
    expected_xyz_ranges = 1
    scale_ratio = expected_xyz_ranges / max_range
    scaled_position_distances = [dis * scale_ratio for dis in position_distances]  # [n_frame - 1]
    scaled_camera_positions = [camera_positions[0], ]
    scaled_camera_positions.extend([camera_positions[0] + np.sum(np.asarray(scaled_position_distances[:i]), axis=0)
                                    for i in range(1, len(camera_positions))])
    ret_poses = [np.concatenate(
        (np.concatenate((ori_pose[:3, :3], cam_position[:, None]), axis=1), np.asarray([0, 0, 0, 1])[None]), axis=0)
        for ori_pose, cam_position in zip(ret_poses, scaled_camera_positions)]
    transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
    ret_poses = [transform_matrix @ x for x in ret_poses]
    return np.array(ret_poses, dtype=np.float32)


def visualize_trajectory(trajectory_file):
    # Plot every camera pose in the trajectory file as a small pyramid, colored by frame index.
    with open(trajectory_file, 'r') as f:
        poses = f.readlines()
    w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]]
    num_frames = len(w2cs)
    last_row = np.zeros((1, 4))
    last_row[0, -1] = 1.0
    w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
    c2ws = get_c2w(w2cs)
    visualizer = CameraPoseVisualizer([-1.2, 1.2], [-1.2, 1.2], [-1.2, 1.2])
    for frame_idx, c2w in enumerate(c2ws):
        visualizer.extrinsic2pyramid(c2w, frame_idx / num_frames, hw_ratio=9 / 16, base_xval=0.02, zval=0.1)
    visualizer.colorbar(num_frames)
    return visualizer.fig


vis_traj = visualize_trajectory('assets/pose_files/0bf152ef84195293.txt')
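# Step-1 helper: resize the uploaded image so it covers the 320 x 576 target resolution while
# keeping its aspect ratio; in 'Resize then Center Crop' mode the result is additionally
# center-cropped to exactly 320 x 576.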
@torch.inference_mode()
def process_input_image(input_image, resize_mode):
    global height, width
    expected_hw_ratio = height / width
    inp_w, inp_h = input_image.size
    inp_hw_ratio = inp_h / inp_w
    if inp_hw_ratio > expected_hw_ratio:
        resized_height = int(round(inp_hw_ratio * width))  # sizes must be ints for F.resize
        resized_width = width
    else:
        resized_height = height
        resized_width = int(round(height / inp_hw_ratio))
    resized_image = F.resize(input_image, size=[resized_height, resized_width])
    if resize_mode == RESIZE_MODES[0]:
        return_image = F.center_crop(resized_image, output_size=[height, width])
    else:
        return_image = resized_image
    return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


def update_camera_trajectories(trajectory_mode):
    if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
            gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
    elif trajectory_mode == CAMERA_TRAJECTORY_MODES[1]:
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
            gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
            gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


def update_camera_args(trajectory_mode, provided_camera_trajectory, customized_trajectory_file):
    if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
        res = "Provided " + str(provided_camera_trajectory)
    else:
        if customized_trajectory_file is None:
            res = " "
        else:
            res = f"Customized trajectory file {customized_trajectory_file.name.split('/')[-1]}"
    return res


def update_camera_args_reset():
    return " "


def update_trajectory_vis_plot(camera_trajectory_args, provided_camera_trajectory, customized_trajectory_file):
    if 'Provided' in camera_trajectory_args:
        if provided_camera_trajectory == "Trajectory 1":
            trajectory_file_path = "assets/pose_files/0bf152ef84195293.txt"
        elif provided_camera_trajectory == "Trajectory 2":
            trajectory_file_path = "assets/pose_files/0c9b371cc6225682.txt"
        elif provided_camera_trajectory == "Trajectory 3":
            trajectory_file_path = "assets/pose_files/0c11dbe781b1c11c.txt"
        elif provided_camera_trajectory == "Trajectory 4":
            trajectory_file_path = "assets/pose_files/0f47577ab3441480.txt"
        elif provided_camera_trajectory == "Trajectory 5":
            trajectory_file_path = "assets/pose_files/0f68374b76390082.txt"
        elif provided_camera_trajectory == "Trajectory 6":
            trajectory_file_path = "assets/pose_files/2c80f9eb0d3b2bb4.txt"
        elif provided_camera_trajectory == "Trajectory 7":
            trajectory_file_path = "assets/pose_files/2f25826f0d0ef09a.txt"
        elif provided_camera_trajectory == "Trajectory 8":
            trajectory_file_path = "assets/pose_files/3f79dc32d575bcdc.txt"
        else:
            trajectory_file_path = "assets/pose_files/4a2d6753676df096.txt"
    else:
        trajectory_file_path = customized_trajectory_file.name
    vis_traj = visualize_trajectory(trajectory_file_path)
    return gr.update(visible=True), vis_traj, gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), trajectory_file_path


def update_set_button():
    return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


def update_buttons_for_example(example_image, example_traj_path, provided_traj_name):
    global height, width
    return_image = example_image
    return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), \
        gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True)
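# The callbacks above each return a tuple of gr.update(...) objects; the tuple order must match
# the corresponding `outputs=[...]` list wired up in main() below, and most of the updates simply
# toggle the visibility of the next step's widgets.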
# @torch.inference_mode()
# @spaces.GPU(duration=150)
# def sample(condition_image, plucker_embedding, height, width, num_frames, num_inference_step, min_guidance_scale,
#            max_guidance_scale, fps_id, generator):
#     res = pipeline(
#         image=condition_image,
#         pose_embedding=plucker_embedding,
#         height=height,
#         width=width,
#         num_frames=num_frames,
#         num_inference_steps=num_inference_step,
#         min_guidance_scale=min_guidance_scale,
#         max_guidance_scale=max_guidance_scale,
#         fps=fps_id,
#         do_image_process=True,
#         generator=generator,
#         output_type='pt'
#     ).frames[0].transpose(0, 1).cpu()
#
#     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
#     save_videos_grid(res[None], temporal_video_path, rescale=False)
#     return temporal_video_path


@spaces.GPU(duration=80)
def sample_video(condition_image, trajectory_file, num_inference_step, min_guidance_scale, max_guidance_scale,
                 fps_id, seed):
    global height, width, num_frames, device, pipeline
    # Parse the trajectory file into Camera objects (one per frame).
    with open(trajectory_file, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [[float(x) for x in pose] for pose in poses]
    cam_params = [Camera(cam_param) for cam_param in cam_params]
    # Rescale the intrinsics so the pose aspect ratio matches the 320 x 576 sampling resolution.
    sample_wh_ratio = width / height
    pose_wh_ratio = cam_params[0].fy / cam_params[0].fx
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height
    intrinsic = np.asarray([[cam_param.fx * width,
                             cam_param.fy * height,
                             cam_param.cx * width,
                             cam_param.cy * height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device='cpu')  # b f h w 6
    plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)

    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))

    with torch.no_grad():
        sample = pipeline(
            image=condition_image,
            pose_embedding=plucker_embedding,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_step,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            fps=fps_id,
            do_image_process=True,
            generator=generator,
            output_type='pt'
        ).frames[0].transpose(0, 1).cpu()

    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    save_videos_grid(sample[None], temporal_video_path, rescale=False)
    return temporal_video_path
    # return sample(condition_image, plucker_embedding, height, width, num_frames, num_inference_step,
    #               min_guidance_scale, max_guidance_scale, fps_id, generator)
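# Minimal usage sketch for calling the sampler outside the UI (assumption: run from the repo
# root so the demo assets exist; any PIL image and a RealEstate10K-style pose txt file work):
#
#   from PIL import Image
#   cond = Image.open("assets/example_condition_images/Rocky_coastline_with_crashing_waves..png")
#   video_path = sample_video(cond, "assets/pose_files/0c11dbe781b1c11c.txt",
#                             num_inference_step=25, min_guidance_scale=1.0,
#                             max_guidance_scale=3.0, fps_id=7, seed=42)
#   print(video_path)  # path to a temporary .mp4 file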
def main(args):
    demo = gr.Blocks().queue()
    with demo:
        gr.Markdown(title)
        gr.Markdown(subtitle)
        gr.Markdown(description)

        with gr.Column():
            # step 1: Input condition image
            step1_title = gr.Markdown("---\n## Step 1: Input an Image", show_label=False, visible=True)
            step1_dec = gr.Markdown(f"\n 1. Upload an image by `Drag` or click `Upload Image`; \
                                    \n 2. Click `{RESIZE_MODES[0]}` or `{RESIZE_MODES[1]}` to select the image resize mode. \
                                    \n - `{RESIZE_MODES[0]}`: First resize the input image, then center crop it to the resolution of 320 x 576. \
                                    \n - `{RESIZE_MODES[1]}`: Only resize the input image, keeping its original aspect ratio.",
                                    show_label=False, visible=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                input_image = gr.Image(type='pil', interactive=True, elem_id='condition_image',
                                       elem_classes='image', visible=True)
                with gr.Row():
                    resize_crop_button = gr.Button(RESIZE_MODES[0], visible=True)
                    directly_resize_button = gr.Button(RESIZE_MODES[1], visible=True)
            with gr.Column(scale=2):
                processed_image = gr.Image(type='pil', interactive=False, elem_id='processed_image',
                                           elem_classes='image', visible=False)

        # step 2: Select camera trajectory
        step2_camera_trajectory = gr.Markdown("---\n## Step 2: Select the camera trajectory", show_label=False,
                                              visible=False)
        step2_camera_trajectory_des = gr.Markdown(f"\n - `{CAMERA_TRAJECTORY_MODES[0]}`: Nine camera trajectories extracted from the test set of the RealEstate10K dataset, each with 25 frames. \
                                                  \n - `{CAMERA_TRAJECTORY_MODES[1]}`: Provide your own camera trajectory in a txt file.",
                                                  show_label=False, visible=False)
        with gr.Row(equal_height=True):
            provide_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[0], visible=False)
            customized_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[1], visible=False)
        with gr.Row():
            with gr.Column():
                provided_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[0]}", show_label=False,
                                                         visible=False)
                provided_camera_trajectory_des = gr.Markdown(f"\n 1. Click one of the provided camera trajectories, such as `Trajectory 1`; \
                                                             \n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
                                                             \n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
                                                             show_label=False, visible=False)
                customized_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[1]}", show_label=False,
                                                           visible=False)
                customized_run_status = gr.Markdown(f"\n 1. Upload a txt file containing the camera trajectory; \
                                                    \n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
                                                    \n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
                                                    show_label=False, visible=False)
                with gr.Row():
                    provided_trajectories = gr.Dropdown(
                        ["Trajectory 1", "Trajectory 2", "Trajectory 3", "Trajectory 4", "Trajectory 5",
                         "Trajectory 6", "Trajectory 7", "Trajectory 8", "Trajectory 9"],
                        label="Provided Trajectories", interactive=True, visible=False)
                with gr.Row():
                    customized_camera_trajectory_file = gr.File(
                        label="Upload customized camera trajectory (in .txt format).", visible=False,
                        interactive=True)
                with gr.Row():
                    camera_args = gr.Textbox(value=" ", label="Camera Trajectory Name", visible=False)
                    camera_trajectory_path = gr.Textbox(value=" ", visible=False)
                with gr.Row():
                    camera_trajectory_vis = gr.Button(value="Visualize Camera Trajectory", visible=False)
                    camera_trajectory_reset = gr.Button(value="Reset Camera Trajectory", visible=False)
            with gr.Column():
                vis_camera_trajectory = gr.Plot(vis_traj, label='Camera Trajectory', visible=False)
        # step 3: Set inference parameters
        with gr.Row():
            with gr.Column():
                step3_title = gr.Markdown(f"---\n## Step 3: Set the inference hyper-parameters", visible=False)
                step3_des = gr.Markdown(
                    f"\n 1. Set the number of inference steps; \
                    \n 2. Set the seed; \
                    \n 3. Set the minimum guidance scale and the maximum guidance scale; \
                    \n 4. Set the fps; \
                    \n - Please refer to the SVD paper for the meaning of the last three parameters.",
                    visible=False)
        with gr.Row():
            with gr.Column():
                num_inference_steps = gr.Number(value=25, label='Number of Inference Steps', step=1,
                                                interactive=True, visible=False)
            with gr.Column():
                seed = gr.Number(value=42, label='Seed', minimum=1, interactive=True, visible=False, step=1)
            with gr.Column():
                min_guidance_scale = gr.Number(value=1.0, label='Minimum Guidance Scale', minimum=1.0, step=0.5,
                                               interactive=True, visible=False)
            with gr.Column():
                max_guidance_scale = gr.Number(value=3.0, label='Maximum Guidance Scale', minimum=1.0, step=0.5,
                                               interactive=True, visible=False)
            with gr.Column():
                fps = gr.Number(value=7, label='FPS', minimum=1, step=1, interactive=True, visible=False)
            # Invisible placeholder buttons used only to balance the row layout.
            with gr.Column():
                _ = gr.Button("Seed", visible=False)
            with gr.Column():
                _ = gr.Button("Seed", visible=False)
            with gr.Column():
                _ = gr.Button("Seed", visible=False)
        with gr.Row():
            with gr.Column():
                _ = gr.Button("Set", visible=False)
            with gr.Column():
                set_button = gr.Button("Set", visible=False)
            with gr.Column():
                _ = gr.Button("Set", visible=False)

        # step 4: Generate video
        with gr.Row():
            with gr.Column():
                step4_title = gr.Markdown("---\n## Step 4: Generate the video", show_label=False, visible=False)
                step4_des = gr.Markdown(f"\n - Click the `Start generation !` button to generate the video; \
                                        \n - If the content of the generated video is not well aligned with the condition image, try increasing the `Minimum Guidance Scale` and `Maximum Guidance Scale`. \
                                        \n - If the generated videos are distorted, try increasing the `FPS`.",
                                        visible=False)
                start_button = gr.Button(value="Start generation !", visible=False)
            with gr.Column():
                generate_video = gr.Video(value=None, label="Generate Video", visible=False)

        resize_crop_button.click(fn=process_input_image, inputs=[input_image, resize_crop_button],
                                 outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                          provide_trajectory_button, customized_trajectory_button])
        directly_resize_button.click(fn=process_input_image, inputs=[input_image, directly_resize_button],
                                     outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                              provide_trajectory_button, customized_trajectory_button])
        provide_trajectory_button.click(fn=update_camera_trajectories, inputs=[provide_trajectory_button],
                                        outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                 provided_trajectories, customized_camera_trajectory,
                                                 customized_run_status, customized_camera_trajectory_file,
                                                 camera_args, camera_trajectory_vis, camera_trajectory_reset])
        customized_trajectory_button.click(fn=update_camera_trajectories, inputs=[customized_trajectory_button],
                                           outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                    provided_trajectories, customized_camera_trajectory,
                                                    customized_run_status, customized_camera_trajectory_file,
                                                    camera_args, camera_trajectory_vis, camera_trajectory_reset])
        provided_trajectories.change(fn=update_camera_args,
                                     inputs=[provide_trajectory_button, provided_trajectories,
                                             customized_camera_trajectory_file],
                                     outputs=[camera_args])
        customized_camera_trajectory_file.change(fn=update_camera_args,
                                                 inputs=[customized_trajectory_button, provided_trajectories,
                                                         customized_camera_trajectory_file],
                                                 outputs=[camera_args])
        camera_trajectory_reset.click(fn=update_camera_args_reset, inputs=None, outputs=[camera_args])
        camera_trajectory_vis.click(fn=update_trajectory_vis_plot,
                                    inputs=[camera_args, provided_trajectories, customized_camera_trajectory_file],
                                    outputs=[vis_camera_trajectory, vis_camera_trajectory, step3_title, step3_des,
                                             num_inference_steps, min_guidance_scale, max_guidance_scale, fps, seed,
                                             set_button, camera_trajectory_path])
        set_button.click(fn=update_set_button, inputs=None,
                         outputs=[step4_title, step4_des, start_button, generate_video])
        start_button.click(fn=sample_video,
                           inputs=[processed_image, camera_trajectory_path, num_inference_steps, min_guidance_scale,
                                   max_guidance_scale, fps, seed],
                           outputs=[generate_video])

        # set example
        gr.Markdown("## Examples")
        gr.Markdown("\n Choose one of the following examples for a quick start. Selecting an example "
                    "sets the condition image and camera trajectory automatically. "
                    "Then, you can click the `Visualize Camera Trajectory` button to visualize the camera trajectory.")
        gr.Examples(
            fn=update_buttons_for_example,
            run_on_click=True,
            cache_examples=False,
            examples=examples,
            inputs=[input_image, camera_args, provided_trajectories],
            outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des, provide_trajectory_button,
                     customized_trajectory_button, provided_camera_trajectory, provided_camera_trajectory_des,
                     provided_trajectories, customized_camera_trajectory, customized_run_status,
                     customized_camera_trajectory_file, camera_args, camera_trajectory_vis, camera_trajectory_reset]
        )

        with gr.Row():
            gr.Markdown(closing_words)

    demo.launch(**args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--listen', default='0.0.0.0')
    parser.add_argument('--browser', action='store_true')
    parser.add_argument('--share', action='store_true')
    args = parser.parse_args()
    launch_kwargs = {'server_name': args.listen, 'inbrowser': args.browser, 'share': args.share}

    main(launch_kwargs)
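# Usage sketch (assumption: this file is saved as app.py at the repo root):
#   python app.py                      # serve on 0.0.0.0
#   python app.py --browser --share    # also open a browser tab and create a public link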