# Latte-1 / demo.py
# Last change: "Update demo.py" by maxin-cn (commit ed933a7, verified).
# Captured from the Hugging Face file viewer (file size 17.3 kB).
import gradio as gr
import os
import spaces
import torch
import argparse
import torchvision
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
EulerDiscreteScheduler, DPMSolverMultistepScheduler,
HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer, BitsAndBytesConfig
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
import imageio
from torchvision.utils import save_image
# Configuration: the only CLI flag names an OmegaConf YAML file, and from here
# on `args` holds the loaded config (not the argparse namespace).
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = OmegaConf.load(parser.parse_args().config)

# Inference only: disable autograd process-wide and pick the target device.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Latte transformer, cast to fp16 on the target device, in eval mode.
transformer_model = get_models(args).to(device, dtype=torch.float16)
transformer_model.eval()

# VAE: the temporal-decoder variant when the config enables it, otherwise the
# plain AutoencoderKL; either one is loaded in fp16 from the pretrained path.
_vae_cls, _vae_subfolder = (
    (AutoencoderKLTemporalDecoder, "vae_temporal_decoder")
    if args.enable_vae_temporal_decoder
    else (AutoencoderKL, "vae")
)
vae = _vae_cls.from_pretrained(
    args.pretrained_model_path, subfolder=_vae_subfolder, torch_dtype=torch.float16
).to(device)
vae.eval()

# T5 tokenizer plus a 4-bit (bitsandbytes) T5 encoder computing in fp16;
# placement of the encoder is left to device_map="auto".
tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
_t5_quantization = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
text_encoder = T5EncoderModel.from_pretrained(
    args.pretrained_model_path,
    subfolder="text_encoder",
    quantization_config=_t5_quantization,
    device_map="auto",
)
text_encoder.eval()
# Scheduler classes selectable by name through gen_video's `sample_method`.
_SCHEDULERS = {
    'DDIM': DDIMScheduler,
    'EulerDiscrete': EulerDiscreteScheduler,
    'DDPM': DDPMScheduler,
    'DPMSolverMultistep': DPMSolverMultistepScheduler,
    'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
    'PNDM': PNDMScheduler,
    'HeunDiscrete': HeunDiscreteScheduler,
    'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
    'DEISMultistep': DEISMultistepScheduler,
    'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
}
# Methods that are additionally loaded with clip_sample=False.
_NO_CLIP_SAMPLE = frozenset({'DDIM', 'DDPM'})


def _build_scheduler(sample_method):
    """Load the scheduler named `sample_method` with the config's beta/variance settings.

    The scheduler is read from the "scheduler" subfolder of
    `args.pretrained_model_path`, overriding beta_start, beta_end,
    beta_schedule and variance_type with the values from the config.

    Raises:
        ValueError: if `sample_method` is not a key of `_SCHEDULERS`.
    """
    try:
        scheduler_cls = _SCHEDULERS[sample_method]
    except KeyError:
        raise ValueError(
            f"Unknown sample_method {sample_method!r}; expected one of {sorted(_SCHEDULERS)}"
        ) from None
    extra_kwargs = {'clip_sample': False} if sample_method in _NO_CLIP_SAMPLE else {}
    return scheduler_cls.from_pretrained(
        args.pretrained_model_path,
        subfolder="scheduler",
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
        **extra_kwargs,
    )


@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    """Generate a video (or a single image when video_length == 1) from a text prompt.

    Args:
        text_input: the prompt text.
        sample_method: scheduler name, one of the keys of `_SCHEDULERS`.
        scfg_scale: guidance scale passed to the pipeline as `guidance_scale`.
        seed: seed for torch's global RNG.
        height, width: output resolution.
        video_length: number of frames to generate.
        diffusion_step: number of inference (denoising) steps.

    Returns:
        Path of the .mp4 file written under `args.save_img_path`.

    Raises:
        ValueError: if `sample_method` is not a known scheduler name.
    """
    # UI sliders may deliver floats; seeds, shapes and step counts must be ints.
    torch.manual_seed(int(seed))
    height, width = int(height), int(width)
    video_length, diffusion_step = int(video_length), int(diffusion_step)

    scheduler = _build_scheduler(sample_method)

    # Prompts are encoded by a separately loaded pipeline that is handed the
    # 4-bit text encoder; the generation pipeline below gets text_encoder=None
    # and consumes the precomputed embeddings instead.
    # NOTE(review): this from_pretrained call runs on every request; consider
    # caching it if load time matters (check GPU-context constraints first).
    pipe_tmp = LattePipeline.from_pretrained(
        args.pretrained_model_path,
        transformer=None,
        text_encoder=text_encoder,
        device_map="balanced",
    )
    prompt_embeds, negative_prompt_embeds = pipe_tmp.encode_prompt(text_input, negative_prompt="")

    videogen_pipeline = LattePipeline(
        vae=vae,
        text_encoder=None,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device)
    videos = videogen_pipeline(
        prompt_embeds=prompt_embeds,
        negative_prompt=None,
        negative_prompt_embeds=negative_prompt_embeds,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=diffusion_step,
        guidance_scale=scfg_scale,
        enable_temporal_attentions=args.enable_temporal_attentions,
        num_images_per_prompt=1,
        mask_feature=True,
        enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
    ).video

    # Every call writes to (and overwrites) the same file name.
    save_path = args.save_img_path + 'temp.mp4'
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path
# Ensure the output location that gen_video writes into exists before any
# generation runs. exist_ok=True avoids the check-then-create race.
os.makedirs(args.save_img_path, exist_ok=True)
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
</div>
"""
with gr.Blocks() as demo:
    # Page header: title, tagline and links.
    gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
<h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
"""
    )
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
[<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
"""
    )

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
            with gr.Row():
                with gr.Column(scale=0.5):
                    # Only a subset of the scheduler names gen_video accepts is offered here.
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
                with gr.Column(scale=0.5):
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            # Height and width share one row; both are shown but not editable.
            with gr.Row():
                with gr.Column(scale=0.5):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )

        # Right column: the generated video and the Generate button.
        with gr.Column(scale=0.6, visible=True) as video_upload:
            # elem_id text means "output video".
            output = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")
            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button(value="Generate", variant='primary')

    # Components feeding gen_video, in the order of its positional parameters;
    # shared by the example table and the Generate button.
    gen_inputs = [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step]

    # Every example uses the same settings (sample method, guidance scale, seed,
    # height, width, video length, sampling steps); only the prompt differs.
    _example_settings = ["DDIM", 7.5, 100, 512, 512, 16, 50]
    _example_prompts = [
        "3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.",
        "A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood.",
        "A wizard wearing a pointed hat and a blue robe with white stars casting a spell that shoots lightning from his hand and holding an old tome in his other hand.",
        "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book.",
        "Cinematic trailer for a group of samoyed puppies learning to become chefs.",
        "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
        "A cyborg koala dj in front of aturntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-f, fantasy, intricate, neon light, soft light smooth, sharp focus, illustration.",
    ]
    EXAMPLES = [[prompt, *_example_settings] for prompt in _example_prompts]

    examples = gr.Examples(
        examples=EXAMPLES,
        fn=gen_video,
        inputs=gen_inputs,
        outputs=[output],
        # True precomputes every example when the app starts (one generation per
        # row); Gradio also accepts "lazy" to cache each example on first use.
        cache_examples=True,
    )
    run.click(gen_video, gen_inputs, [output])
# Start the Gradio server for the `demo` app; share=True asks Gradio for a
# public share link in addition to the local URL.
demo.launch(share=True, debug=False)