import torch # torch.jit.script = lambda f: f # General import os from os.path import join as opj import argparse import datetime from pathlib import Path import spaces import gradio as gr import tempfile import yaml from t2v_enhanced.model.video_ldm import VideoLDM # Utilities from t2v_enhanced.inference_utils import * from t2v_enhanced.model_init import * from t2v_enhanced.model_func import * on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR" parser = argparse.ArgumentParser() parser.add_argument('--public_access', action='store_true', default=True) parser.add_argument('--where_to_log', type=str, default="gradio_output") parser.add_argument('--device', type=str, default="cuda") args = parser.parse_args() Path(args.where_to_log).mkdir(parents=True, exist_ok=True) result_fol = Path(args.where_to_log).absolute() device = args.device # -------------------------- # ----- Configurations ----- # -------------------------- cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True} # -------------------------- # ----- Initialization ----- # -------------------------- ms_model = init_modelscope(device) # # zs_model = init_zeroscope(device) ad_model = init_animatediff(device) svd_model = init_svd(device) sdxl_model = init_sdxl(device) ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute() stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol) msxl_model = init_v2v_model(cfg_v2v) # ------------------------- # ----- Functionality ----- # ------------------------- @spaces.GPU(duration=120) def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol): now = datetime.datetime.now() name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_") if num_frames == [] or num_frames is None: num_frames = 24 else: num_frames = int(num_frames.split(" ")[0]) if num_frames > 56: num_frames = 56 n_autoreg_gen = (num_frames-8)//8 inference_generator = torch.Generator(device="cuda").manual_seed(seed) if model_name_stage1 == "ModelScopeT2V (text to video)": short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device) elif model_name_stage1 == "AnimateDiff (text to video)": short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device) elif model_name_stage1 == "SVD (image to video)": # For cached examples if isinstance(image, dict): image = image["path"] short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device) stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model) video_path = opj(where_to_log, name+".mp4") return video_path @spaces.GPU(duration=400) def enhance(prompt, input_to_enhance, num_frames=None, image=None, model_name_stage1=None, model_name_stage2=None, seed=33, t=50, image_guidance=9.5, result_fol=result_fol): if input_to_enhance is None: input_to_enhance = generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance) encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model) return encoded_video def change_visibility(value): if value == "SVD (image to video)": return gr.Image(label='Image Prompt (if not attached then SDXL will be used to generate the starting image)', show_label=True, scale=1, show_download_button=False, interactive=True, value=None) else: return gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False, value=None) examples_1 = [ ["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.", None, "56 - frames", None, "ModelScopeT2V (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["People dancing in room filled with fog and colorful lights.", None, "56 - frames", None, "ModelScopeT2V (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["Discover the secret language of bees: delve into the complex communication system that allows bees to coordinate their actions and navigate the world.", None, "56 - frames", None, "AnimateDiff (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["sunset, orange sky, warm lighting, fishing boats, ocean waves seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, coastal landscape, seaside scenery.", None, "56 - frames", None, "AnimateDiff (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.", None, "56 - frames", None, "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["Ants, beetles and centipede nest.", None, "56 - frames", None, "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ] examples_2 = [ ["Fishes swimming in ocean camera moving, cinematic.", None, "56 - frames", "__assets__/fish.jpg", "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ["A squirrel on a table full of big nuts.", None, "56 - frames", "__assets__/squirrel.jpg", "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0], ] # -------------------------- # ----- Gradio-Demo UI ----- # -------------------------- with gr.Blocks() as demo: gr.HTML( """
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.