# Open-Sora-Plan-v1-0-0 / opensora/serve/gradio_web_server.py
# Uploaded by fffiloni ("Upload 244 files"), commit b3f324b (verified), 5.11 kB.
import argparse
import sys
import os
import random
import imageio
import torch
from diffusers import PNDMScheduler
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from datetime import datetime
from typing import List, Union
import gradio as gr
import numpy as np
from gradio.components import Textbox, Video, Image
from transformers import T5Tokenizer, T5EncoderModel
from opensora.models.ae import ae_stride_config, getae, getae_wrapper
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.sample.pipeline_videogen import VideoGenPipeline
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
@torch.inference_mode()
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
    """Sample a video (or a single-frame "image" video) for ``prompt`` and save it as MP4.

    Gradio callback for the interface built under ``__main__``. It relies on the
    module-level ``args``, ``transformer_model`` and ``videogen_pipeline`` created there.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        sample_steps: Forwarded as ``num_inference_steps``.
        scale: Forwarded as ``guidance_scale``.
        seed: Seed value; passed with ``randomize_seed`` to ``randomize_seed_fn``,
            whose result is the seed actually used.
        randomize_seed: Flag forwarded to ``randomize_seed_fn``.
        force_images: If True, sample ``video_length=1`` with temporal attention disabled.

    Returns:
        Tuple ``(video_path, prompt, model_info, seed)``:
        - ``video_path`` is a newly created ``.mp4`` file for this request.
        - ``model_info`` is a human-readable summary of the sampling settings.
        - ``seed`` is the seed that was applied via ``set_env``.
    """
    import tempfile  # local import keeps this fix self-contained

    seed = int(randomize_seed_fn(seed, randomize_seed))
    set_env(seed)

    video_length = transformer_model.config.video_length if not force_images else 1
    # args.version looks like "<frames>x<height>x<width>" (e.g. "65x512x512"); parse it once.
    version_parts = args.version.split('x')
    height, width = int(version_parts[1]), int(version_parts[2])
    num_frames = 1 if video_length == 1 else int(version_parts[0])

    videos = videogen_pipeline(prompt,
                               video_length=video_length,
                               height=height,
                               width=width,
                               num_inference_steps=sample_steps,
                               guidance_scale=scale,
                               enable_temporal_attentions=not force_images,
                               num_images_per_prompt=1,
                               mask_feature=True,
                               ).video
    torch.cuda.empty_cache()
    videos = videos[0]

    # Use a unique file per request instead of a shared 'tmp.mp4', so concurrent
    # requests cannot overwrite each other's output before Gradio serves it.
    # The handle is closed before writing so the encoder can reopen the path on any OS.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        tmp_save_path = tmp_file.name
    imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9)  # highest quality is 10, lowest is 0

    display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
    return tmp_save_path, prompt, display_model_info, seed
if __name__ == '__main__':
    # Hard-coded serving configuration. argparse.Namespace gives the same attribute
    # access (args.ae, args.version, ...) as the former ad-hoc `type(...)` class.
    args = argparse.Namespace(
        ae='CausalVAEModel_4x8x8',
        force_images=False,
        model_path='LanguageBind/Open-Sora-Plan-v1.0.0',
        text_encoder_name='DeepFloyd/t5-v1_1-xxl',
        version='65x512x512',  # "<frames>x<height>x<width>"; also the transformer subfolder
    )
    device = torch.device('cuda:0')

    # Load model:
    transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version,
                                                 torch_dtype=torch.float16, cache_dir='cache_dir').to(device)
    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
    vae.vae.enable_tiling()

    # Latent spatial size = pixel size // autoencoder stride. Height and width are
    # read separately (indices 1 and 2 of the version string, matching generate_img)
    # so non-square versions get the right latent shape. Indices 1/2 of
    # ae_stride_config[...] are presumably the (height, width) strides — TODO confirm.
    version_parts = args.version.split('x')
    image_height, image_width = int(version_parts[1]), int(version_parts[2])
    latent_size = (image_height // ae_stride_config[args.ae][1], image_width // ae_stride_config[args.ae][2])
    vae.latent_size = latent_size
    transformer_model.force_images = args.force_images

    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir",
                                                  torch_dtype=torch.float16).to(device)

    # set eval mode
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = PNDMScheduler()
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         transformer=transformer_model).to(device=device)

    demo = gr.Interface(
        fn=generate_img,
        # Input order must match generate_img's positional parameters:
        # prompt, sample_steps, scale, seed, randomize_seed, force_images.
        inputs=[Textbox(label="",
                        placeholder="Please enter your prompt. \n"),
                gr.Slider(
                    label='Sample Steps',
                    minimum=1,
                    maximum=500,
                    value=50,
                    # step=1: with minimum=1 and step=10 the reachable values were
                    # 1, 11, 21, ..., so the default of 50 could not be re-selected.
                    step=1
                ),
                gr.Slider(
                    label='Guidance Scale',
                    minimum=0.1,
                    maximum=30.0,
                    value=10.0,
                    step=0.1
                ),
                gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=203279,
                    step=1,
                    value=0,
                ),
                gr.Checkbox(label="Randomize seed", value=True),
                gr.Checkbox(label="Generate image (1 frame video)", value=False),
                ],
        # Output order matches generate_img's return tuple: (video_path, prompt, model_info, seed).
        outputs=[Video(label="Vid", width=512, height=512),
                 Textbox(label="input prompt"),
                 Textbox(label="model info"),
                 gr.Slider(label='seed')],
        title=title_markdown, description=DESCRIPTION, theme=gr.themes.Default(), css=block_css,
        examples=examples,
    )
    demo.launch()