# NOTE(recovery): this file was reconstructed from a diff view of the commit
# "Delete t2v_enhanced/inference.py" (+0 -82); the Space showed a runtime error.
"""Command-line inference for StreamingT2V.

Generates a short seed clip with the chosen base model (ModelscopeT2V /
AnimateDiff / SVD), autoregressively extends it with the StreamingT2V model,
then upscales the result with a video-to-video enhancement pass.
"""

# General
import os
from os.path import join as opj
import argparse
import datetime
from pathlib import Path
import torch
import gradio as gr
import tempfile
import yaml
from t2v_enhanced.model.video_ldm import VideoLDM

# Utilities
from t2v_enhanced.inference_utils import *
from t2v_enhanced.model_init import *
from t2v_enhanced.model_func import *


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="A cat running on the street", help="The prompt to guide video generation.")
    parser.add_argument('--image', type=str, default="", help="Path to image conditioning.")
    # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.")
    parser.add_argument('--base_model', type=str, default="ModelscopeT2V", help="Base model to generate first chunk from", choices=["ModelscopeT2V", "AnimateDiff", "SVD"])
    parser.add_argument('--num_frames', type=int, default=24, help="The number of video frames to generate.")
    parser.add_argument('--negative_prompt', type=str, default="", help="The prompt to guide what to not include in video generation.")
    parser.add_argument('--num_steps', type=int, default=50, help="The number of denoising steps.")
    parser.add_argument('--image_guidance', type=float, default=9.0, help="The guidance scale.")

    parser.add_argument('--output_dir', type=str, default="results", help="Path where to save the generated videos.")
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--seed', type=int, default=33, help="Random seed")
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    result_fol = Path(args.output_dir).absolute()
    device = args.device

    # --------------------------
    # ----- Configurations -----
    # --------------------------
    ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
    cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}

    # --------------------------
    # ----- Initialization -----
    # --------------------------
    stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
    sdxl_model = None  # only needed for the SVD image path
    if args.base_model == "ModelscopeT2V":
        model = init_modelscope(device)
    elif args.base_model == "AnimateDiff":
        model = init_animatediff(device)
    elif args.base_model == "SVD":
        model = init_svd(device)
        sdxl_model = init_sdxl(device)
    # FIX: the original script referenced `msxl_model` in the final video2video()
    # call without ever defining it (NameError). Initialize the enhancement model
    # here; init_v2v_model is presumably exported by t2v_enhanced.model_init via
    # the wildcard import above — TODO confirm against that module.
    msxl_model = init_v2v_model(cfg_v2v)

    # ------------------
    # ----- Inputs -----
    # ------------------
    now = datetime.datetime.now()
    name = args.prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")

    # FIX: the generator was created twice and hard-coded to "cuda", ignoring
    # --device; create it once on the requested device and seed it.
    inference_generator = torch.Generator(device=device)
    inference_generator.manual_seed(args.seed)

    if args.base_model == "ModelscopeT2V":
        short_video = ms_short_gen(args.prompt, model, inference_generator)
    elif args.base_model == "AnimateDiff":
        short_video = ad_short_gen(args.prompt, model, inference_generator)
    elif args.base_model == "SVD":
        short_video = svd_short_gen(args.image, args.prompt, model, sdxl_model, inference_generator)

    # StreamingT2V extends the seed clip in 8-frame autoregressive chunks;
    # the first 8 frames come from the seed video itself.
    n_autoreg_gen = args.num_frames // 8 - 8
    stream_long_gen(args.prompt, short_video, n_autoreg_gen, args.negative_prompt, args.seed, args.num_steps, args.image_guidance, name, stream_cli, stream_model)
    video2video(args.prompt, opj(result_fol, name+".mp4"), result_fol, cfg_v2v, msxl_model)