import argparse
import math
import os
import sys

import imageio
import torch
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from torchvision.utils import save_image
from transformers import T5EncoderModel, T5Tokenizer

from opensora.models.ae import ae_stride_config, getae_wrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.utils.utils import save_video_grid

# Make the parent directory importable so the local pipeline module resolves.
sys.path.append(os.path.split(sys.path[0])[0])
from pipeline_videogen import VideoGenPipeline


def main(args):
    # torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the causal VAE that decodes latents back to pixel space.
    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor

    # Load the Latte diffusion transformer.
    transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir='cache_dir',
                                                 torch_dtype=torch.float16).to(device)
    transformer_model.force_images = args.force_images

    # Load the T5 text encoder and its tokenizer.
    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir='cache_dir')
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir='cache_dir',
                                                  torch_dtype=torch.float16).to(device)

    # Derive the latent resolution from the spatial strides of the autoencoder.
    video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1])
    latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
    vae.latent_size = latent_size
    if args.force_images:
        video_length = 1
        ext = 'jpg'
    else:
        ext = 'mp4'

    # Set eval mode.
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    # Select the sampling scheduler.
    schedulers = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    if args.sample_method not in schedulers:
        raise ValueError('Unknown sample method: {}'.format(args.sample_method))
    scheduler = schedulers[args.sample_method]()

    print('videogen_pipeline', device)
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         transformer=transformer_model).to(device=device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    os.makedirs(args.save_img_path, exist_ok=True)

    video_grids = []
    if not isinstance(args.text_prompt, list):
        args.text_prompt = [args.text_prompt]
    # A single .txt argument is treated as a file with one prompt per line.
    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
        with open(args.text_prompt[0], 'r') as f:
            args.text_prompt = [line.strip() for line in f.readlines()]

    for prompt in args.text_prompt:
        print('Processing the ({}) prompt'.format(prompt))
        videos = videogen_pipeline(prompt,
                                   video_length=video_length,
                                   height=image_size,
                                   width=image_size,
                                   num_inference_steps=args.num_sampling_steps,
                                   guidance_scale=args.guidance_scale,
                                   enable_temporal_attentions=not args.force_images,
                                   num_images_per_prompt=1,
                                   mask_feature=True,
                                   ).video
        try:
            if args.force_images:
                videos = videos[:, 0].permute(0, 3, 1, 2)  # b t h w c -> b c h w
                save_image(videos / 255.0,
                           os.path.join(args.save_img_path,
                                        prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
                           nrow=1, normalize=True, value_range=(0, 1))  # t c h w
            else:
                imageio.mimwrite(
                    os.path.join(args.save_img_path,
                                 prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
                    videos[0], fps=args.fps, quality=9)  # highest quality is 10, lowest is 0
        except Exception as e:
            print('Error when saving {}: {}'.format(prompt, e))
        video_grids.append(videos)

    video_grids = torch.cat(video_grids, dim=0)
    # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
    if args.force_images:
        save_image(video_grids / 255.0,
                   os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
                   nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1))
    else:
        video_grids = save_video_grid(video_grids)
        imageio.mimwrite(os.path.join(args.save_img_path,
                                      f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
                         video_grids, fps=args.fps, quality=9)

    print('save path {}'.format(args.save_img_path))
    # save_videos_grid(video, f"./{prompt}.gif")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')
    parser.add_argument("--version", type=str, default='65x512x512', choices=['65x512x512', '65x256x256', '17x256x256'])
    parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8')
    parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl')
    parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v")
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--sample_method", type=str, default="PNDM")
    parser.add_argument("--num_sampling_steps", type=int, default=50)
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--run_time", type=int, default=0)
    parser.add_argument("--text_prompt", nargs='+')
    parser.add_argument('--force_images', action='store_true')
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    args = parser.parse_args()
    main(args)
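
# Example invocation, a minimal sketch: the script filename "sample_t2v.py" and
# the prompt text are illustrative assumptions; the checkpoint, version, and
# text-encoder values are the defaults declared above.
#
#   python sample_t2v.py \
#       --model_path LanguageBind/Open-Sora-Plan-v1.0.0 \
#       --version 65x512x512 \
#       --ae CausalVAEModel_4x8x8 \
#       --text_encoder_name DeepFloyd/t5-v1_1-xxl \
#       --text_prompt "A panda eating bamboo in a sunny forest" \
#       --sample_method PNDM \
#       --num_sampling_steps 50 \
#       --guidance_scale 7.5 \
#       --fps 24 \
#       --enable_tiling
#
# Passing a single .txt path to --text_prompt instead generates one video per
# line in the file; --force_images switches the pipeline to single-frame (jpg)
# output with temporal attention disabled.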