import gradio as gr
from text_to_video import model_t2v_fun, setup_seed
from omegaconf import OmegaConf
import torch
import imageio
import os
import cv2
import pandas as pd
import torchvision
import random
import base64
from models import get_models
from pipelines.pipeline_videogen import VideoGenPipeline
from download import find_model
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
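# Model components are loaded once at import time so every request reuses the same
# fp16 weights on the available device (GPU if present, otherwise CPU).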
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
config_path = "./base/configs/sample.yaml"
args = OmegaConf.load(config_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
sd_path = args.pretrained_path
unet = get_models(args, sd_path).to(device, dtype=torch.float16)
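# Load the LaVie base checkpoint into the UNet (resolved via download.find_model).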
state_dict = find_model("./pretrained_models/lavie_base.pt")
unet.load_state_dict(state_dict)
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
unet.eval()
vae.eval()
text_encoder_one.eval()
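# generate_video is exposed through the Gradio /run endpoint below; it returns the
# rendered clip as a base64 data URI so clients can consume it without a file download.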
#def infer(secret_token, prompt, seed_inp, ddim_steps,cfg, infer_type):
def generate_video(secret_token, prompt):
    # Generation parameters are fixed here; the extra UI controls below are not wired in yet.
    seed_inp = -1
    ddim_steps = 50
    cfg = 7.5
    infer_type = "ddim"
    if secret_token != SECRET_TOKEN:
        raise gr.Error('Invalid secret token. Please fork the original space if you want to use it for yourself.')
    # Use the provided seed, or draw a random one when seed_inp is -1.
    if seed_inp != -1:
        setup_seed(seed_inp)
    else:
        seed_inp = random.choice(range(10000000))
        setup_seed(seed_inp)
    # Pick the noise scheduler that matches the requested sampler type.
    if infer_type == 'ddim':
        scheduler = DDIMScheduler.from_pretrained(sd_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule)
    elif infer_type == 'eulerdiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
                                                           subfolder="scheduler",
                                                           beta_start=args.beta_start,
                                                           beta_end=args.beta_end,
                                                           beta_schedule=args.beta_schedule)
    elif infer_type == 'ddpm':
        scheduler = DDPMScheduler.from_pretrained(sd_path,
                                                  subfolder="scheduler",
                                                  beta_start=args.beta_start,
                                                  beta_end=args.beta_end,
                                                  beta_schedule=args.beta_schedule)
    else:
        raise gr.Error(f"Unsupported infer_type: {infer_type}")
    # Assemble the LaVie video generation pipeline from the preloaded components.
    model = VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
    model.to(device)
    if device == "cuda":
        model.enable_xformers_memory_efficient_attention()

    # Generate a 16-frame clip at 512x320 and write it out at 8 fps.
    videos = model(prompt, video_length=16, height=320, width=512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)
    video_path = os.path.join(args.output_folder, prompt[0:30].replace(' ', '_') + '-' + str(seed_inp) + '-' + str(ddim_steps) + '-' + str(cfg) + '-.mp4')
    torchvision.io.write_video(video_path, videos[0], fps=8)

    # Read the content of the video file and encode it to base64
    with open(video_path, "rb") as video_file:
        video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

    # Prepend the appropriate data URI header with MIME type
    video_data_uri = 'data:video/mp4;base64,' + video_base64

    # Clean up the video file to avoid filling up storage
    # os.remove(video_path)

    return video_data_uri
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML("""
        <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
            <div style="text-align: center; color: black;">
                <p style="color: black;">This space is a REST API to programmatically generate MP4 videos.</p>
                <p style="color: black;">Interested in using it? Look no further than the <a href="https://huggingface.co/spaces/Vchitect/LaVie" target="_blank">original space</a>!</p>
            </div>
        </div>""")
        secret_token = gr.Textbox(label="Secret token")
        prompt = gr.Textbox(value="", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
        # The controls below mirror the full infer() signature but are not passed to generate_video yet.
        infer_type = gr.Dropdown(['ddpm', 'ddim', 'eulerdiscrete'], label='infer_type', value='ddim')
        ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
        seed_inp = gr.Slider(value=-1, label="seed (for random generation, use -1)", show_label=True, minimum=-1, maximum=2147483647)
        cfg = gr.Number(label="guidance_scale", value=7.5)
        submit_btn = gr.Button("Generate video")
        base64_out = gr.Textbox(label="Base64 Video")

    submit_btn.click(
        fn=generate_video,
        inputs=[secret_token, prompt],  # seed_inp, ddim_steps, cfg, infer_type],
        outputs=base64_out,
        api_name='run',
    )

demo.queue(max_size=12).launch()
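# Example client call (a minimal sketch, not part of this app): assuming the space is
# reachable at some SPACE_URL and gradio_client is installed, something like the
# following should return the base64 data URI produced by generate_video above.
#
#   from gradio_client import Client
#   client = Client(SPACE_URL)  # SPACE_URL is a placeholder for the deployed space URL
#   data_uri = client.predict("my-secret-token", "a corgi running on the beach", api_name="/run")
#   # strip the "data:video/mp4;base64," prefix and base64-decode to recover the MP4 bytes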