File size: 5,110 Bytes
b3f324b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123


import argparse
import sys
import os
import random

import imageio
import torch
from diffusers import PNDMScheduler
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from datetime import datetime
from typing import List, Union
import gradio as gr
import numpy as np
from gradio.components import Textbox, Video, Image
from transformers import T5Tokenizer, T5EncoderModel

from opensora.models.ae import ae_stride_config, getae, getae_wrapper
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.sample.pipeline_videogen import VideoGenPipeline
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION


@torch.inference_mode()
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
    """Run the text-to-video pipeline for ``prompt`` and save the result as an MP4.

    Relies on the module-level ``args``, ``transformer_model`` and
    ``videogen_pipeline`` objects created in the ``__main__`` section.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        sample_steps: Forwarded as ``num_inference_steps``.
        scale: Forwarded as ``guidance_scale``.
        seed: Seed handed to ``randomize_seed_fn``.
        randomize_seed: Flag handed to ``randomize_seed_fn`` (presumably asks
            for a random seed instead of ``seed`` -- confirm in gradio_utils).
        force_images: When true, a single frame is requested and temporal
            attention is disabled.

    Returns:
        A 4-tuple ``(video_path, prompt, info_text, seed)`` where ``seed`` is
        the integer value that was passed to ``set_env``.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)

    if force_images:
        video_length = 1
    else:
        video_length = transformer_model.config.video_length

    # ``args.version`` is read as "<frames>x<height>x<width>".
    version_fields = args.version.split('x')
    height = int(version_fields[1])
    width = int(version_fields[2])
    # The frame field is only parsed when more than one frame is generated.
    num_frames = int(version_fields[0]) if video_length != 1 else 1

    pipeline_kwargs = dict(
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=sample_steps,
        guidance_scale=scale,
        enable_temporal_attentions=not force_images,
        num_images_per_prompt=1,
        mask_feature=True,
    )
    videos = videogen_pipeline(prompt, **pipeline_kwargs).video

    torch.cuda.empty_cache()

    # Only the first entry of the pipeline output is written out.
    first_video = videos[0]
    # NOTE: the output path is a fixed name in the working directory, so each
    # call overwrites the file written by the previous one.
    tmp_save_path = 'tmp.mp4'
    # highest quality is 10, lowest is 0
    imageio.mimwrite(tmp_save_path, first_video, fps=24, quality=9)

    display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
    return tmp_save_path, prompt, display_model_info, seed

if __name__ == '__main__':
    # Fixed demo configuration. ``args``, ``transformer_model`` and
    # ``videogen_pipeline`` are read as module globals by generate_img().
    args = argparse.Namespace(
        ae='CausalVAEModel_4x8x8',
        force_images=False,
        model_path='LanguageBind/Open-Sora-Plan-v1.0.0',
        text_encoder_name='DeepFloyd/t5-v1_1-xxl',
        version='65x512x512',
    )
    device = torch.device('cuda:0')
    cache_dir = 'cache_dir'  # shared download cache for every from_pretrained call below

    # Load model:
    transformer_model = LatteT2V.from_pretrained(
        args.model_path,
        subfolder=args.version,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
    ).to(device)

    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir=cache_dir).to(device, dtype=torch.float16)
    vae.vae.enable_tiling()

    # ``args.version`` is read as "<frames>x<height>x<width>". Height and width
    # are taken from their own fields so a non-square version yields a matching
    # latent size (for the square default this equals the previous result).
    version_fields = args.version.split('x')
    image_height, image_width = int(version_fields[1]), int(version_fields[2])
    # Entries 1 and 2 of the stride config are used for the two spatial
    # dimensions (presumably height and width strides -- confirm in opensora.models.ae).
    ae_stride = ae_stride_config[args.ae]
    latent_size = (image_height // ae_stride[1], image_width // ae_stride[2])
    vae.latent_size = latent_size
    transformer_model.force_images = args.force_images

    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir=cache_dir)
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir=cache_dir,
                                                  torch_dtype=torch.float16).to(device)

    # set eval mode
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()
    scheduler = PNDMScheduler()
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         transformer=transformer_model).to(device=device)

    # The inputs are listed in the positional order of generate_img's
    # parameters; the outputs in the order of its returned tuple.
    demo = gr.Interface(
        fn=generate_img,
        inputs=[Textbox(label="",
                        placeholder="Please enter your prompt. \n"),
                # NOTE(review): with minimum=1 and step=10 the default of 50
                # may not sit on the slider's step grid -- confirm intended.
                gr.Slider(
                    label='Sample Steps',
                    minimum=1,
                    maximum=500,
                    value=50,
                    step=10
                ),
                gr.Slider(
                    label='Guidance Scale',
                    minimum=0.1,
                    maximum=30.0,
                    value=10.0,
                    step=0.1
                ),
                gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=203279,
                    step=1,
                    value=0,
                ),
                gr.Checkbox(label="Randomize seed", value=True),
                gr.Checkbox(label="Generate image (1 frame video)", value=False),
                ],
        outputs=[Video(label="Vid", width=512, height=512),
                 Textbox(label="input prompt"),
                 Textbox(label="model info"),
                 gr.Slider(label='seed')],
        title=title_markdown, description=DESCRIPTION, theme=gr.themes.Default(), css=block_css,
        examples=examples,
    )
    demo.launch()