File size: 11,188 Bytes
5c4a11c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5f9b65
60b82a2
d8df719
60b82a2
17e381c
4af04e2
 
5c4a11c
 
 
 
17e381c
 
 
 
 
 
25a83f9
 
 
17e381c
 
 
 
25a83f9
17e381c
25a83f9
17e381c
25a83f9
17e381c
5c4a11c
54ad002
4997010
 
54ad002
1b1d216
b1693d8
 
 
 
 
54ad002
5c4a11c
d29f4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6320af
 
 
a92b0d1
b6320af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8df719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bfbba7
 
 
 
 
 
 
 
 
 
 
21c2481
 
b6320af
d8df719
25a83f9
 
 
 
 
 
 
d8df719
664012c
a92b0d1
5d778f1
664012c
b6320af
1bfbba7
5c4a11c
a2c25bf
 
 
 
 
 
 
 
 
 
c5f87c2
a2c25bf
 
 
 
 
60b82a2
a2c25bf
b19ac30
a2c25bf
60b82a2
 
 
 
 
d8df719
 
25a83f9
 
 
 
 
 
d796b5a
60b82a2
9c6d3b9
60b82a2
4e684d3
9c6d3b9
c353fff
 
a92b0d1
c1fcf29
 
9c6d3b9
 
 
60b82a2
 
d8df719
4e684d3
c1fcf29
9c6d3b9
c1fcf29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import gradio as gr
import os
import time
import argparse
import yaml, math
from tqdm import trange
import torch
import numpy as np
from omegaconf import OmegaConf
import torch.distributed as dist
from pytorch_lightning import seed_everything

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.dist_utils import setup_dist, gather_data
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
from lvdm.models.modules.lora import change_lora
from lvdm.utils.saving_utils import tensor_to_mp4

from huggingface_hub import hf_hub_download
import subprocess
import shlex 

config_path = "model_config.yaml"
config = OmegaConf.load(config_path)

# Download the model checkpoints from the Hugging Face Hub, skipping any file
# that already exists locally.
# NOTE: get_video_lora indexes into filename_list by position, so keep the
# order of these entries stable.
REPO_ID = 'VideoCrafter/t2v-version-1-1'
filename_list = ['models/base_t2v/model.ckpt',
                 'models/videolora/lora_001_Loving_Vincent_style.ckpt',
                 'models/videolora/lora_002_frozenmovie_style.ckpt',
                 'models/videolora/lora_003_MakotoShinkaiYourName_style.ckpt',
                 'models/videolora/lora_004_coco_style.ckpt',
                 'models/adapter_t2v_depth/adapter.pth']

for filename in filename_list:
    if not os.path.exists(filename):
        hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)

ckpt_path = 'models/base_t2v/model.ckpt'

# DPT-Hybrid MiDaS weights, stored next to the depth adapter checkpoint.
midas_path_url = 'https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt'
midas_path = 'models/adapter_t2v_depth/dpt_hybrid-midas.pt'

# `wget -O` truncates the target before downloading. Running it on every launch
# would destroy a good copy on a flaky network, and a failed run leaves a
# zero-byte file. So fetch only when the file is missing or empty, and fail
# loudly (check=True) instead of continuing with a broken checkpoint.
if not os.path.exists(midas_path) or os.path.getsize(midas_path) == 0:
    subprocess.run(['wget', midas_path_url, '-O', midas_path], check=True)

# Build the base text-to-video model with no LoRA injected at startup.
model, _, _ = load_model(config, ckpt_path,
                         inject_lora=False,
                         lora_scale=None,
                         )

# Load the depth-adapter weights into model.adapter (strict key matching).
# NOTE(review): torch.load unpickles the file. It comes from the project's own
# HF repo here, but never point this at untrusted checkpoints.
adapter_ckpt = 'models/adapter_t2v_depth/adapter.pth'
state_dict = torch.load(adapter_ckpt, map_location="cpu")
# Some checkpoints wrap the weights under a top-level "state_dict" key; unwrap it.
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
model.adapter.load_state_dict(state_dict, strict=True)

# Single DDIM sampler shared by all generation callbacks below.
ddim_sampler = DDIMSampler(model)

def sample_denoising_batch(model, noise_shape, condition, *args,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False,
                           **kwargs,
                           ):
    """Draw one batch of latent samples with the given (DDIM) sampler.

    Args:
        model: Not used by this function; kept for call-site compatibility.
        noise_shape: Latent shape. Element 0 is the batch size; the rest
            (``noise_shape[1:]``) is forwarded as the per-sample shape.
        condition: Conditioning forwarded to the sampler.
        *args: Ignored.
        sample_type: Ignored; only the sampler's ``sample`` method is used.
        sampler: Object exposing ``sample(...)`` returning a 2-tuple. Required.
        ddim_steps: Number of sampling steps (``S``). Required.
        eta: Sampler eta. Required.
        unconditional_guidance_scale: Classifier-free guidance scale.
        uc: Unconditional conditioning (or None).
        denoising_progress: Forwarded as ``verbose``.
        **kwargs: Extra keyword arguments forwarded to ``sampler.sample``.

    Returns:
        The first element of the tuple returned by ``sampler.sample``.
    """
    # Required arguments, checked in the same order as the signature.
    for required in (sampler, ddim_steps, eta):
        assert required is not None

    per_sample_shape = noise_shape[1:]
    latents, _intermediates = sampler.sample(
        S=ddim_steps,
        conditioning=condition,
        batch_size=noise_shape[0],
        shape=per_sample_shape,
        verbose=denoising_progress,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=uc,
        eta=eta,
        **kwargs,
    )
    return latents
                               
@torch.no_grad()
def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=7.5,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Generate videos for ``prompt`` and return them as one NumPy array.

    Runs ``ceil(n_samples / batch_size)`` sampling rounds. Each round
    denoises a latent batch, decodes it with ``model.decode_first_stage``,
    converts it with ``torch_to_np``, and concatenates along axis 0.

    Args:
        model: Model with a ``cond_stage_model`` (asserted non-None).
        prompt: Text prompt.
        n_samples: Minimum number of videos to produce.
        batch_size: Videos per sampling round.
        sample_type, sampler, ddim_steps, eta: Forwarded to
            ``sample_denoising_batch``.
        cfg_scale: Classifier-free guidance scale. An empty-prompt
            embedding is computed only when it is not 1.0.
        decode_frame_bs: ``decode_bs`` for ``model.decode_first_stage``.
        ddp, all_gather: When both are true, each round's output is
            gathered from all ranks via ``gather_data``.
        batch_progress: Show a tqdm bar over sampling rounds.
        show_denoising_progress: Forwarded as ``denoising_progress``.

    Returns:
        Array with at least ``n_samples`` entries along axis 0. It may hold
        more when ``n_samples`` is not a multiple of ``batch_size`` or
        when gathering across ranks.
    """
    assert model.cond_stage_model is not None
    text_cond = get_conditions(prompt, model, batch_size)
    null_cond = None
    if cfg_scale != 1.0:
        null_cond = get_conditions("", model, batch_size)

    num_rounds = math.ceil(n_samples / batch_size)
    if batch_progress:
        rounds = trange(num_rounds, desc="Sampling Batches (text-to-video)")
    else:
        rounds = range(num_rounds)

    collected = []
    for _ in rounds:
        latent_shape = make_model_input_shape(model, batch_size)
        latents = sample_denoising_batch(
            model, latent_shape, text_cond,
            sample_type=sample_type,
            sampler=sampler,
            ddim_steps=ddim_steps,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            uc=null_cond,
            denoising_progress=show_denoising_progress,
        )
        decoded = model.decode_first_stage(latents, decode_bs=decode_frame_bs, return_cpu=False)

        if ddp and all_gather:
            # Pull this round's videos from every rank before converting.
            collected += [torch_to_np(part) for part in gather_data(decoded, return_np=False)]
        else:
            collected.append(torch_to_np(decoded))

    videos = np.concatenate(collected, axis=0)
    assert videos.shape[0] >= n_samples
    return videos

def adapter_guided_synthesis(model, prompts, videos, noise_shape, sampler, n_samples=1, ddim_steps=50, ddim_eta=1.,
                             unconditional_guidance_scale=1.0, unconditional_guidance_scale_temporal=None, **kwargs):
    """Sample videos guided by depth features extracted from ``videos``.

    Depth conditioning is computed once from ``videos`` and reused for every
    one of the ``n_samples`` variants.

    Args:
        model: Provides ``get_learned_conditioning``, ``get_batch_depth``,
            ``get_adapter_features`` and ``decode_first_stage``.
        prompts: A prompt string or a list of prompts. Only a single prompt
            is supported: the unconditional embedding is built for
            ``noise_shape[0]`` empty prompts.
        videos: Input videos. ``videos.shape`` is unpacked as
            (b, c, t, h, w), and (h, w) is passed to ``get_batch_depth``.
        noise_shape: Latent shape. ``[0]`` is the batch size, ``[1:]`` the
            per-sample shape, and ``[2]`` is forwarded as ``temporal_length``.
        sampler: Sampler exposing ``sample(...)``. Required.
        n_samples: Number of variants to draw (must be >= 1).
        ddim_steps: Number of sampling steps.
        ddim_eta: Sampler eta.
        unconditional_guidance_scale: Classifier-free guidance scale. An
            unconditional embedding is computed only when it is not 1.0.
        unconditional_guidance_scale_temporal: Forwarded to the sampler as
            ``conditional_guidance_scale_temporal``.
        **kwargs: Extra keyword arguments forwarded to ``sampler.sample``.

    Returns:
        Tuple ``(variants, extra_cond)``. ``variants`` is the stack of decoded
        samples, permuted so axis 0 is batch and axis 1 is variant.
        ``extra_cond`` is the output of ``model.get_batch_depth``.

    Raises:
        ValueError: If ``sampler`` is None or ``n_samples`` < 1. Previously
            these surfaced as a NameError on an unbound ``samples`` or a
            RuntimeError from ``torch.stack`` on an empty list.
    """
    if sampler is None:
        raise ValueError("adapter_guided_synthesis requires a sampler")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    batch_size = noise_shape[0]
    ## get condition embeddings (support single prompt only)
    if isinstance(prompts, str):
        prompts = [prompts]
    cond = model.get_learned_conditioning(prompts)
    if unconditional_guidance_scale != 1.0:
        uc = model.get_learned_conditioning(batch_size * [""])
    else:
        uc = None

    ## adapter features: process in 2D manner
    _, _, _, h, w = videos.shape
    extra_cond = model.get_batch_depth(videos, (h, w))
    features_adapter = model.get_adapter_features(extra_cond)

    batch_variants = []
    for _ in range(n_samples):
        samples, _ = sampler.sample(S=ddim_steps,
                                    conditioning=cond,
                                    batch_size=noise_shape[0],
                                    shape=noise_shape[1:],
                                    verbose=False,
                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                    unconditional_conditioning=uc,
                                    eta=ddim_eta,
                                    temporal_length=noise_shape[2],
                                    conditional_guidance_scale_temporal=unconditional_guidance_scale_temporal,
                                    features_adapter=features_adapter,
                                    **kwargs
                                    )
        ## reconstruct from latent to pixel space
        batch_variants.append(model.decode_first_stage(samples, decode_bs=1, return_cpu=False))

    ## stacked as (variants, batch, ...) -> permuted to (batch, variants, ...)
    stacked = torch.stack(batch_variants)
    return stacked.permute(1, 0, 2, 3, 4, 5), extra_cond


def save_results(videos,
                 save_name="results", save_fps=8, save_mp4=True,
                 save_npz=False, save_mp4_sheet=False, save_jpg=False
                 ):
    """Write each video in ``videos`` to its own MP4 under ``videos/``.

    Args:
        videos: Array indexed along axis 0; each ``videos[i:i+1]`` slice is
            written with ``npz_to_video_grid``.
        save_name: File-name prefix. Files are named ``{save_name}_{i:03d}.mp4``.
        save_fps: Frame rate passed to ``npz_to_video_grid``.
        save_mp4, save_npz, save_mp4_sheet, save_jpg: Currently unused; kept
            for call-site compatibility.

    Returns:
        Path of the last MP4 written.

    Raises:
        ValueError: If ``videos`` is empty along axis 0. Previously this
            surfaced as an UnboundLocalError on the loop variable.
    """
    num_videos = videos.shape[0]
    if num_videos == 0:
        raise ValueError("save_results: no videos to save")

    save_subdir = "videos"
    os.makedirs(save_subdir, exist_ok=True)
    for i in range(num_videos):
        save_path = os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
        npz_to_video_grid(videos[i:i + 1, ...], save_path, fps=save_fps)

    # Return exactly the path that was last written rather than rebuilding it.
    return save_path

def save_results_control(batch_samples, batch_conds):
    """Save the depth-condition video and the sampled video as MP4 files.

    Both tensors are detached, moved to CPU, and written at 10 fps under
    ``videos/``. The depth video is written first.

    Returns:
        Tuple ``(depth_path, sample_path)``.
    """
    out_dir = os.path.join("videos")
    os.makedirs(out_dir, exist_ok=True)

    depth_path = os.path.join(out_dir, 'results_depth.mp4')
    sample_path = os.path.join(out_dir, 'results_sample.mp4')
    for clip, path in ((batch_conds, depth_path), (batch_samples, sample_path)):
        tensor_to_mp4(video=clip.detach().cpu(), savepath=path, fps=10)

    return depth_path, sample_path

def get_video(prompt, seed, ddim_steps):
    """Gradio callback for the basic text-to-video tab.

    Seeds all RNGs, samples one video for ``prompt`` with the shared DDIM
    sampler, and returns the path of the saved MP4.

    NOTE(review): get_video_lora calls change_lora on this same module-level
    ``model`` and nothing here resets it. Confirm whether change_lora's
    effect persists, in which case this tab reuses the last selected LoRA.
    """
    seed_everything(seed)
    videos = sample_text2video(
        model,
        prompt,
        n_samples=1,
        batch_size=1,
        sampler=ddim_sampler,
        ddim_steps=ddim_steps,
    )
    return save_results(videos)

def get_video_lora(prompt, seed, ddim_steps, model_choice):
    """Gradio callback for the VideoLoRA tab.

    Seeds all RNGs, appends the style trigger phrase for ``model_choice``
    to ``prompt``, and loads the matching LoRA checkpoint from
    ``filename_list`` into the shared model at scale 1.0. It then samples
    one video and returns the path of the saved MP4.

    Raises:
        KeyError: If ``model_choice`` is not one of the known styles.
    """
    # UI choice -> (prompt suffix, index of the LoRA checkpoint in filename_list)
    styles = {
        "Frozen": (", frozenmovie style", 2),
        "Coco": (", coco style", 4),
        "Loving Vincent": (", Loving Vincent style", 1),
        "MakotoShinkai YourName": (", MakotoShinkaiYourName style", 3),
    }

    seed_everything(seed)
    style_suffix, lora_index = styles[model_choice]
    styled_prompt = prompt + style_suffix
    print(styled_prompt)

    change_lora(model, inject_lora=True, lora_scale=1.0, lora_path=filename_list[lora_index])
    videos = sample_text2video(
        model,
        styled_prompt,
        n_samples=1,
        batch_size=1,
        sampler=ddim_sampler,
        ddim_steps=ddim_steps,
    )
    return save_results(videos)

def get_video_control(prompt, input_video, seed, ddim_steps):
    """Gradio callback for the VideoControl (depth-guided) tab.

    Seeds all RNGs and runs depth-adapter-guided sampling on ``input_video``
    over a 64x64 latent grid (512 // 8).

    NOTE(review): adapter_guided_synthesis unpacks ``input_video.shape`` as
    (b, c, t, h, w), so a 5-D tensor is assumed here. If the Gradio
    component delivers a file path instead, it must be decoded first.
    Confirm against gradio_videocontrol.

    Returns:
        ``input_video`` unchanged. The generated videos are computed but not
        yet returned.
    """
    seed_everything(seed)
    # Latent spatial size; presumably a 512x512 output with 8x VAE downsampling.
    h, w = 512 // 8, 512 // 8
    # Fix: this previously read `model_control.temporal_length`, a name defined
    # nowhere in the module, so every call raised NameError. The loaded model
    # is `model` (assumes it exposes `temporal_length` — TODO confirm).
    noise_shape = [1, model.channels, model.temporal_length, h, w]
    batch_samples, batch_conds = adapter_guided_synthesis(
        model, prompt, input_video, noise_shape,
        sampler=ddim_sampler, n_samples=1,
        ddim_steps=ddim_steps,
    )
    # TODO: the intended output appears to be
    # `save_results_control(batch_samples, batch_conds)` (depth and sample MP4
    # paths). It was disabled, and the input video is echoed back instead.
    return input_video

from gradio_t2v import create_demo as create_demo_basic
from gradio_videolora import create_demo as create_demo_videolora
from gradio_videocontrol import create_demo as create_demo_videocontrol

DESCRIPTION = '# [Latent Video Diffusion Models](https://github.com/VideoCrafter/VideoCrafter)'
# Fix: the emoji run was mojibake ("πŸ€—" is the UTF-8 bytes of 🤗 decoded
# with the wrong codec), which rendered as garbage in the UI.
DESCRIPTION += '\n<p>🤗🤗🤗 VideoCrafter is an open-source video generation and editing toolbox for crafting video content. This model can only be used for non-commercial purposes. To learn more about the model, take a look at the <a href="https://github.com/VideoCrafter/VideoCrafter" style="text-decoration: underline;" target="_blank">model card</a>.</p>'

# Three-tab UI. Each create_demo_* helper receives that tab's generation
# callback defined earlier in this module.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.TabItem('Basic Text2Video'):
            create_demo_basic(get_video)
        with gr.TabItem('VideoLoRA'):
            create_demo_videolora(get_video_lora)
        with gr.TabItem('VideoControl'):
            create_demo_videocontrol(get_video_control)

demo.queue(api_open=False).launch()