lev1 committed on
Commit ff9dd1f · verified · 1 Parent(s): 9cd26c4

Delete t2v_enhanced/inference.py

Files changed (1)
  1. t2v_enhanced/inference.py +0 -82
t2v_enhanced/inference.py DELETED
@@ -1,82 +0,0 @@
- # General
- import os
- from os.path import join as opj
- import argparse
- import datetime
- from pathlib import Path
- import torch
- import gradio as gr
- import tempfile
- import yaml
- from t2v_enhanced.model.video_ldm import VideoLDM
-
- # Utilities
- from t2v_enhanced.inference_utils import *
- from t2v_enhanced.model_init import *
- from t2v_enhanced.model_func import *
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--prompt', type=str, default="A cat running on the street", help="The prompt to guide video generation.")
-     parser.add_argument('--image', type=str, default="", help="Path to image conditioning.")
-     # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.")
-     parser.add_argument('--base_model', type=str, default="ModelscopeT2V", help="Base model to generate first chunk from", choices=["ModelscopeT2V", "AnimateDiff", "SVD"])
-     parser.add_argument('--num_frames', type=int, default=24, help="The number of video frames to generate.")
-     parser.add_argument('--negative_prompt', type=str, default="", help="The prompt to guide what to not include in video generation.")
-     parser.add_argument('--num_steps', type=int, default=50, help="The number of denoising steps.")
-     parser.add_argument('--image_guidance', type=float, default=9.0, help="The guidance scale.")
-
-     parser.add_argument('--output_dir', type=str, default="results", help="Path where to save the generated videos.")
-     parser.add_argument('--device', type=str, default="cuda")
-     parser.add_argument('--seed', type=int, default=33, help="Random seed")
-     args = parser.parse_args()
-
-
-     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
-     result_fol = Path(args.output_dir).absolute()
-     device = args.device
-
-
-     # --------------------------
-     # ----- Configurations -----
-     # --------------------------
-     ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
-     cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
-
-
-     # --------------------------
-     # ----- Initialization -----
-     # --------------------------
-     stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
-     if args.base_model == "ModelscopeT2V":
-         model = init_modelscope(device)
-     elif args.base_model == "AnimateDiff":
-         model = init_animatediff(device)
-     elif args.base_model == "SVD":
-         model = init_svd(device)
-         sdxl_model = init_sdxl(device)
-
-
-     inference_generator = torch.Generator(device="cuda")
-
-
-     # ------------------
-     # ----- Inputs -----
-     # ------------------
-     now = datetime.datetime.now()
-     name = args.prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
-
-     inference_generator = torch.Generator(device="cuda")
-     inference_generator.manual_seed(args.seed)
-
-     if args.base_model == "ModelscopeT2V":
-         short_video = ms_short_gen(args.prompt, model, inference_generator)
-     elif args.base_model == "AnimateDiff":
-         short_video = ad_short_gen(args.prompt, model, inference_generator)
-     elif args.base_model == "SVD":
-         short_video = svd_short_gen(args.image, args.prompt, model, sdxl_model, inference_generator)
-
-     n_autoreg_gen = args.num_frames // 8 - 8
-     stream_long_gen(args.prompt, short_video, n_autoreg_gen, args.negative_prompt, args.seed, args.num_steps, args.image_guidance, name, stream_cli, stream_model)
-     video2video(args.prompt, opj(result_fol, name+".mp4"), result_fol, cfg_v2v, msxl_model)
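
For reference, a minimal sketch of how the deleted entry point could be invoked, reconstructed from its argparse definitions above; the flag values shown are the script's own defaults, and running from the repository root with a CUDA device available are assumptions:

```python
# Hypothetical invocation of the removed script, using its argparse defaults.
# Assumes the repository root as the working directory (not confirmed by the diff).
import subprocess

subprocess.run(
    [
        "python", "t2v_enhanced/inference.py",
        "--prompt", "A cat running on the street",  # default prompt
        "--base_model", "ModelscopeT2V",            # choices: ModelscopeT2V, AnimateDiff, SVD
        "--num_frames", "24",                       # frames to generate
        "--num_steps", "50",                        # denoising steps
        "--seed", "33",                             # random seed
        "--output_dir", "results",                  # folder for generated videos
    ],
    check=True,
)
```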