import argparse
import os
import sys

import gradio as gr
import imageio
import spaces
import torch
import torchvision
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from omegaconf import OmegaConf
from torchvision.utils import save_image
from transformers import T5EncoderModel, T5Tokenizer, BitsAndBytesConfig

# Make the parent directory of this script importable so `sample` and `models` resolve.
sys.path.append(os.path.split(sys.path[0])[0])

from sample.pipeline_latte import LattePipeline
from models import get_models

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

transformer_model = get_models(args).to(device, dtype=torch.float16)

if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_path, subfolder="vae_temporal_decoder",
        torch_dtype=torch.float16).to(device)
else:
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_path, subfolder="vae",
        torch_dtype=torch.float16).to(device)

tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
# Load the T5 text encoder in 4-bit to reduce GPU memory usage.
text_encoder = T5EncoderModel.from_pretrained(
    args.pretrained_model_path,
    subfolder="text_encoder",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    device_map="auto",
)

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()


@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    torch.manual_seed(seed)

    # Instantiate the scheduler selected in the UI; all variants share the diffusion
    # hyper-parameters from the config, and DDIM/DDPM additionally disable sample clipping.
    if sample_method == 'DDIM':
        scheduler = DDIMScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type,
            clip_sample=False)
    elif sample_method == 'EulerDiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'DDPM':
        scheduler = DDPMScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type,
            clip_sample=False)
    elif sample_method == 'DPMSolverMultistep':
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'DPMSolverSinglestep':
        scheduler = DPMSolverSinglestepScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'PNDM':
        scheduler = PNDMScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'HeunDiscrete':
        scheduler = HeunDiscreteScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'EulerAncestralDiscrete':
        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'DEISMultistep':
        scheduler = DEISMultistepScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)
    elif sample_method == 'KDPM2AncestralDiscrete':
        scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(
            args.pretrained_model_path, subfolder="scheduler",
            beta_start=args.beta_start, beta_end=args.beta_end,
            beta_schedule=args.beta_schedule, variance_type=args.variance_type)

    # Encode the prompt with a transformer-free pipeline so the quantized text encoder
    # (placed by `device_map`) can run independently of the video DiT.
    pipe_tmp = LattePipeline.from_pretrained(
        args.pretrained_model_path,
        transformer=None,
        text_encoder=text_encoder,
        device_map="balanced")
    prompt_embeds, negative_prompt_embeds = pipe_tmp.encode_prompt(text_input, negative_prompt="")

    videogen_pipeline = LattePipeline(
        vae=vae,
        # text_encoder=text_encoder,
        text_encoder=None,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(
        # text_input,
        prompt_embeds=prompt_embeds,
        negative_prompt=None,
        negative_prompt_embeds=negative_prompt_embeds,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=diffusion_step,
        guidance_scale=scfg_scale,
        enable_temporal_attentions=args.enable_temporal_attentions,
        num_images_per_prompt=1,
        mask_feature=True,
        enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
    ).video

    save_path = args.save_img_path + 'temp' + '.mp4'
    # torchvision.io.write_video(save_path, videos[0], fps=8)
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path


if not os.path.exists(args.save_img_path):
    os.makedirs(args.save_img_path)

intro = """

Latte: Latent Diffusion Transformer for Video Generation

""" with gr.Blocks() as demo: # gr.HTML(intro) # with gr.Accordion("README", open=False): # gr.HTML( # """ #

# project page | paper #

# We will continue update Latte. # """ # ) gr.Markdown("
Latte: Latent Diffusion Transformer for Video Generation
") gr.Markdown( """

Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!

""" ) gr.Markdown( """
[Arxiv Report] | [Project Page] | [Github]
""" ) with gr.Row(): with gr.Column(visible=True) as input_raws: with gr.Row(): with gr.Column(scale=1.0): text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt") with gr.Row(): with gr.Column(scale=0.5): sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM") with gr.Column(scale=0.5): video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16) with gr.Row(): with gr.Column(scale=1.0): scfg_scale = gr.Slider( minimum=1, maximum=50, value=7.5, step=0.1, interactive=True, label="Guidence Scale", ) with gr.Row(): with gr.Column(scale=1.0): seed = gr.Slider( minimum=1, maximum=2147483647, value=100, step=1, interactive=True, label="Seed", ) with gr.Row(): with gr.Column(scale=0.5): height = gr.Slider( minimum=256, maximum=768, value=512, step=16, interactive=False, label="Height", ) # with gr.Row(): with gr.Column(scale=0.5): width = gr.Slider( minimum=256, maximum=768, value=512, step=16, interactive=False, label="Width", ) with gr.Row(): with gr.Column(scale=1.0): diffusion_step = gr.Slider( minimum=20, maximum=250, value=50, step=1, interactive=True, label="Sampling Step", ) with gr.Column(scale=0.6, visible=True) as video_upload: output = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频") #.style(height=360) with gr.Row(): with gr.Column(scale=1.0, min_width=0): run = gr.Button(value="Generate", variant='primary') EXAMPLES = [ ["3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["A wizard wearing a pointed hat and a blue robe with white stars casting a spell that shoots lightning from his hand and holding an old tome in his other hand.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["A young man at his 20s is sitting on a piece of cloud in the sky, reading a book.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["Cinematic trailer for a group of samoyed puppies learning to become chefs.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. 
A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", "DDIM", 7.5, 100, 512, 512, 16, 50], ["A cyborg koala dj in front of aturntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-f, fantasy, intricate, neon light, soft light smooth, sharp focus, illustration.", "DDIM", 7.5, 100, 512, 512, 16, 50], ] examples = gr.Examples( examples = EXAMPLES, fn = gen_video, inputs=[text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], outputs=[output], cache_examples=True, # cache_examples="lazy", ) run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output]) demo.launch(debug=False, share=True)
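
# Usage sketch (assumptions: the script name `demo.py` is hypothetical, and the file lives one
# level below the repository root so its parent directory contains `sample/` and `models/`;
# the YAML config must define pretrained_model_path, beta_start, beta_end, beta_schedule,
# variance_type, enable_temporal_attentions, enable_vae_temporal_decoder and save_img_path,
# as referenced above):
#
#   python demo.py --config ./configs/t2x/t2v_sample.yaml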