import gradio as gr
import os
import spaces
import torch
import argparse
import torchvision


from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler, 
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler, 
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer, BitsAndBytesConfig

import sys
# Make the repository root importable so that `sample` and `models` can be resolved
sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
import imageio
from torchvision.utils import save_image


    
# Parse the config path from the command line, then replace `args` with the options loaded from the YAML file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Latte transformer backbone in half precision
transformer_model = get_models(args).to(device, dtype=torch.float16)

# Choose between the temporal-decoder VAE and the standard image VAE
if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
else:
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
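# The T5 text encoder is loaded in 4-bit via bitsandbytes to reduce GPU memory; device_map="auto" places it automatically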
text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, 
                                              subfolder="text_encoder", 
                                              quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
                                              device_map="auto",
                                              )

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()
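
# `@spaces.GPU` asks Hugging Face ZeroGPU Spaces to attach a GPU for the duration of each gen_video call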

@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    torch.manual_seed(seed)
    # All schedulers are built from the same pretrained config and beta/variance settings;
    # DDIM and DDPM additionally disable sample clipping.
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    scheduler_kwargs = dict(beta_start=args.beta_start,
                            beta_end=args.beta_end,
                            beta_schedule=args.beta_schedule,
                            variance_type=args.variance_type)
    if sample_method in ('DDIM', 'DDPM'):
        scheduler_kwargs['clip_sample'] = False
    scheduler = scheduler_classes[sample_method].from_pretrained(args.pretrained_model_path,
                                                                 subfolder="scheduler",
                                                                 **scheduler_kwargs)
        
    # Encode the prompt with a transformer-free pipeline so the 4-bit T5 does the text encoding;
    # the resulting embeddings are passed to the generation pipeline below.
    pipe_tmp = LattePipeline.from_pretrained(args.pretrained_model_path,
                                             transformer=None,
                                             text_encoder=text_encoder,
                                             device_map="balanced")
    prompt_embeds, negative_prompt_embeds = pipe_tmp.encode_prompt(text_input, negative_prompt="")


    # Assemble the generation pipeline; text_encoder=None because the prompt was already encoded above
    videogen_pipeline = LattePipeline(vae=vae, 
                                    # text_encoder=text_encoder, 
                                    text_encoder=None, 
                                    tokenizer=tokenizer, 
                                    scheduler=scheduler, 
                                    transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(
                                # text_input, 
                                prompt_embeds=prompt_embeds,
                                negative_prompt=None,
                                negative_prompt_embeds=negative_prompt_embeds,
                                video_length=video_length, 
                                height=height, 
                                width=width, 
                                num_inference_steps=diffusion_step,
                                guidance_scale=scfg_scale,
                                enable_temporal_attentions=args.enable_temporal_attentions,
                                num_images_per_prompt=1,
                                mask_feature=True,
                                enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
                                ).video

    # Write the generated clip to disk and return its path to the Gradio Video component
    save_path = args.save_img_path + 'temp' + '.mp4'
    # torchvision.io.write_video(save_path, videos[0], fps=8)
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path


if not os.path.exists(args.save_img_path):
    os.makedirs(args.save_img_path)

intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
</div>
"""

with gr.Blocks() as demo:
    # gr.HTML(intro)
    # with gr.Accordion("README", open=False):
    #     gr.HTML(
    #         """
    #         <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    #             <a href="https://maxin-cn.github.io/latte_project/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2401.03048" target="_blank">paper</a>
    #         </p>

    #         We will continue to update Latte.
    #     """
    #     )
    gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        <h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
        """
    )
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
        [<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
        """
    )


    with gr.Row():
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")

            with gr.Row():
                with gr.Column(scale=0.5):
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")

                with gr.Column(scale=0.5):
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            with gr.Row():
                with gr.Column(scale=0.5):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
            # with gr.Row():
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Steps",
                    )

                         
        with gr.Column(scale=0.6, visible=True) as video_upload:
            output = gr.Video(interactive=False, include_audio=True, elem_id="output_video") #.style(height=360)

            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button(value="Generate", variant='primary')

    # Each example: [prompt, sample_method, guidance scale, seed, height, width, video_length, diffusion steps]
    EXAMPLES = [
        ["3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A wizard wearing a pointed hat and a blue robe with white stars casting a spell that shoots lightning from his hand and holding an old tome in his other hand.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A young man in his 20s is sitting on a piece of cloud in the sky, reading a book.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["Cinematic trailer for a group of samoyed puppies learning to become chefs.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A cyborg koala dj in front of a turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy, intricate, neon light, soft light smooth, sharp focus, illustration.",  "DDIM", 7.5, 100, 512, 512, 16, 50],
    ]

    examples = gr.Examples(
        examples = EXAMPLES,
        fn = gen_video,
        inputs=[text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step],
        outputs=[output],
        cache_examples=True,
        # cache_examples="lazy",
    )
                    
    run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])
    
demo.launch(debug=False, share=True)