import gradio as gr
from text_to_video import setup_seed
from omegaconf import OmegaConf
import torch
import os
import torchvision
import random
import base64
from models import get_models

from pipelines.pipeline_videogen import VideoGenPipeline
from download import find_model
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection

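# Shared secret that clients must supply to use the generation endpoint.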
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')

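# Load the sampling configuration shipped with the LaVie base model.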
config_path = "./base/configs/sample.yaml"
args = OmegaConf.load(config_path)
device = "cuda" if torch.cuda.is_available() else "cpu"

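# Load the frozen Stable Diffusion components (VAE, tokenizer, text encoder) and the LaVie UNet in fp16, then put everything in eval mode.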
sd_path = args.pretrained_path
unet = get_models(args, sd_path).to(device, dtype=torch.float16)
state_dict = find_model("./pretrained_models/lavie_base.pt")
unet.load_state_dict(state_dict)
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)  # largest component; keep in fp16
unet.eval()
vae.eval()
text_encoder_one.eval()

def generate_video(secret_token, prompt):
    """Generate a short MP4 from a text prompt and return it as a base64 data URI."""
    # Sampling parameters are fixed here for the API; the extra UI controls
    # (seed_inp, ddim_steps, cfg, infer_type) are not wired in yet.
    seed_inp = -1
    ddim_steps = 50
    cfg = 7.5
    infer_type = "ddim"
    
    if secret_token != SECRET_TOKEN:
        raise gr.Error('Invalid secret token. Please fork the original space if you want to use it for yourself.')
       
    # Use the provided seed if set; otherwise draw a random one so the run is reproducible from the filename.
    if seed_inp == -1:
        seed_inp = random.randint(0, 9999999)
    setup_seed(seed_inp)
    # Build the noise scheduler requested by infer_type (hard-coded to DDIM above).
    if infer_type == 'ddim':
        scheduler = DDIMScheduler.from_pretrained(sd_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule)
    elif infer_type == 'eulerdiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
                                                           subfolder="scheduler",
                                                           beta_start=args.beta_start,
                                                           beta_end=args.beta_end,
                                                           beta_schedule=args.beta_schedule)
    elif infer_type == 'ddpm':
        scheduler = DDPMScheduler.from_pretrained(sd_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule)
    else:
        raise gr.Error(f'Unknown infer_type: {infer_type}')
    # Assemble the generation pipeline from the preloaded modules and sample the video.
    model = VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
    model.to(device)
    if device == "cuda":
        model.enable_xformers_memory_efficient_attention()
    videos = model(prompt, video_length=16, height=320, width=512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
    os.makedirs(args.output_folder, exist_ok=True)

    # Encode the prompt, seed, steps and guidance scale into the filename for easier debugging.
    video_name = prompt[0:30].replace(' ', '_') + '-' + str(seed_inp) + '-' + str(ddim_steps) + '-' + str(cfg) + '.mp4'
    video_path = os.path.join(args.output_folder, video_name)
    
    torchvision.io.write_video(video_path, videos[0], fps=8)

    # Read the content of the video file and encode it to base64
    with open(video_path, "rb") as video_file:
        video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

    # Prepend the appropriate data URI header with MIME type
    video_data_uri = 'data:video/mp4;base64,' + video_base64
    
    # Deleting the file here would avoid filling up storage; it is kept on disk for now.
    # os.remove(video_path)

    return video_data_uri


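# Minimal Gradio UI: the visible page is covered by a notice, and the app is meant to be used through its API.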
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML("""
            <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
              <div style="text-align: center; color: black;">
                <p style="color: black;">This space is a REST API to programmatically generate MP4 videos.</p>
                <p style="color: black;">Interested in using it? Look no further than the <a href="https://huggingface.co/spaces/Vchitect/LaVie" target="_blank">original space</a>!</p>
              </div>
        </div>""")
        secret_token = gr.Textbox(label="Secret token")
 
        prompt = gr.Textbox(value="", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
        infer_type = gr.Dropdown(['ddpm', 'ddim', 'eulerdiscrete'], label='infer_type', value='ddim')
        ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
        seed_inp = gr.Slider(value=-1, label="seed (for random generation, use -1)", show_label=True, minimum=-1, maximum=2147483647)
        cfg = gr.Number(label="guidance_scale", value=7.5)

        submit_btn = gr.Button("Generate video")
        base64_out = gr.Textbox(label="Base64 Video")

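    # Expose generate_video under api_name='run' so it can be called programmatically;
    # the extra UI controls above are not passed through yet.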
    submit_btn.click(
        fn=generate_video,
        inputs=[secret_token, prompt], # seed_inp, ddim_steps, cfg, infer_type],
        outputs=base64_out,
        api_name='run',
    )

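# Queue incoming requests (at most 12 waiting) before launching the app.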
demo.queue(max_size=12).launch()
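
# Example client call: a minimal sketch, assuming the space id and secret below
# (both hypothetical) and that gradio_client is installed.
#
#   from gradio_client import Client
#   client = Client("user/space-name")  # hypothetical space id
#   data_uri = client.predict("my-secret-token", "a panda surfing a wave", api_name="/run")
#   # data_uri is a 'data:video/mp4;base64,...' string that can be decoded and saved as an MP4.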