import torch
# torch.jit.script = lambda f: f

# General
import os
from os.path import join as opj
import argparse
import datetime
from pathlib import Path
# import spaces
import gradio as gr
import tempfile
import yaml
from t2v_enhanced.model.video_ldm import VideoLDM

# Utilities
from t2v_enhanced.inference_utils import *
from t2v_enhanced.model_init import *
from t2v_enhanced.model_func import *

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

parser = argparse.ArgumentParser()
parser.add_argument('--public_access', action='store_true', default=True)
parser.add_argument('--where_to_log', type=str, default="gradio_output")
parser.add_argument('--device', type=str, default="cuda")
args = parser.parse_args()

default_prompt = "A man with yellow balloon head is riding a bike on the street of New York City"

Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
result_fol = Path(args.where_to_log).absolute()
device = args.device

# Spread the models across up to four GPUs; otherwise fall back to a single "cuda" device.
n_devices = int(os.environ.get('NDEVICES', 4))
if n_devices == 4:
    devices = [f"cuda:{idx}" for idx in range(4)]
else:
    devices = ["cuda"] * 4

# --------------------------
# ----- Configurations -----
# --------------------------
cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}

# --------------------------
# ----- Initialization -----
# --------------------------
ms_model = init_modelscope(devices[1])
# zs_model = init_zeroscope(device)
ad_model = init_animatediff(devices[1])
svd_model = init_svd(devices[2])
sdxl_model = init_sdxl(devices[2])
ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
msxl_model = init_v2v_model(cfg_v2v, devices[3])

# -------------------------
# ----- Functionality -----
# -------------------------
# @spaces.GPU(duration=120)
def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol):
    # Generate a short clip with the selected stage-1 model, then extend it
    # autoregressively with the StreamingT2V model.
    now = datetime.datetime.now()
    name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")

    # The UI passes values such as "56 - frames"; keep only the leading integer and cap at 56.
    if num_frames == [] or num_frames is None:
        num_frames = 24
    else:
        num_frames = int(num_frames.split(" ")[0])
    if num_frames > 56:
        num_frames = 56

    if prompt == "" or prompt is None:
        prompt = default_prompt

    n_autoreg_gen = (num_frames - 8) // 8

    if model_name_stage1 == "ModelScopeT2V (text to video)":
        inference_generator = torch.Generator(device=ms_model.device).manual_seed(seed)
        short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
    elif model_name_stage1 == "AnimateDiff (text to video)":
        inference_generator = torch.Generator(device=ad_model.device).manual_seed(seed)
        short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
    elif model_name_stage1 == "SVD (image to video)":
        # For cached examples
        if isinstance(image, dict):
            image = image["path"]
        inference_generator = torch.Generator(device=svd_model.device).manual_seed(seed)
        short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)

    stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
    video_path = opj(where_to_log, name + ".mp4")
    return video_path


# @spaces.GPU(duration=400)
def enhance(prompt, input_to_enhance, num_frames=None, image=None, model_name_stage1=None, model_name_stage2=None, seed=33, t=50, image_guidance=9.5, result_fol=result_fol):
    if prompt == "" or prompt is None:
        prompt = default_prompt
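    # `enhance` runs the full pipeline: if no stage-1/stage-2 video is supplied, it first
    # calls `generate`, then upscales the result to 1280x720 with the video-to-video
    # enhancement model ("damo/Video-to-Video", held in `msxl_model`) via `video2video`.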
    if input_to_enhance is None:
        input_to_enhance = generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance)

    encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
    # for idx in range(4):
    #     print(f">>> cuda:{idx}", torch.cuda.max_memory_allocated(f"cuda:{idx}"))
    return encoded_video


def change_visibility(value):
    if value == "SVD (image to video)":
        return gr.Image(label='Image Prompt (if not attached then SDXL will be used to generate the starting image)',
                        show_label=True, scale=1, show_download_button=False, interactive=True, value=None)
    else:
        return gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)',
                        show_label=True, scale=1, show_download_button=False, interactive=False, value=None)


# [prompt_stage1, video_stage2, num_frames, image_stage1, model_name_stage1, seed, t, image_guidance]
examples_1 = [
    ["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
     "__assets__/examples/t2v/1.mp4", "56 - frames", None, "ModelScopeT2V (text to video)", 33, 50, 9.0],
    ["People dancing in room filled with fog and colorful lights.",
     "__assets__/examples/t2v/2.mp4", "56 - frames", None, "ModelScopeT2V (text to video)", 33, 50, 9.0],
    ["Discover the secret language of bees: delve into the complex communication system that allows bees to coordinate their actions and navigate the world.",
     "__assets__/examples/t2v/3.mp4", "56 - frames", None, "AnimateDiff (text to video)", 33, 50, 9.0],
    ["sunset, orange sky, warm lighting, fishing boats, ocean waves seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, coastal landscape, seaside scenery.",
     "__assets__/examples/t2v/4.mp4", "56 - frames", None, "AnimateDiff (text to video)", 33, 50, 9.0],
    ["Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.",
     "__assets__/examples/t2v/5.mp4", "56 - frames", None, "SVD (image to video)", 33, 50, 9.0],
    ["Ants, beetles and centipede nest.",
     "__assets__/examples/t2v/6.mp4", "56 - frames", None, "SVD (image to video)", 33, 50, 9.0],
]

examples_2 = [
    ["Fishes swimming in ocean camera moving, cinematic.",
     "__assets__/examples/i2v/1.mp4", "56 - frames", "__assets__/fish.jpg", "SVD (image to video)", 33, 50, 9.0],
    ["A squirrel on a table full of big nuts.",
     "__assets__/examples/i2v/2.mp4", "56 - frames", "__assets__/squirrel.jpg", "SVD (image to video)", 33, 50, 9.0],
]

# --------------------------
# ----- Gradio-Demo UI -----
# --------------------------
with gr.Blocks() as demo:
    gr.HTML(
        """
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.