import warnings
from pathlib import Path
from typing import Any, Optional, Union, Callable

import pytorch_lightning as pl
import torch
from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from transformers import CLIPTextModel, CLIPTokenizer

from t2v_enhanced.utils.video_utils import ResultProcessor, save_videos_grid, video_naming
from t2v_enhanced.model import pl_module_params_controlnet
from t2v_enhanced.model.diffusers_conditional.models.controlnet.controlnet import ControlNetModel
from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel
from t2v_enhanced.model.diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import TextToVideoSDPipeline
from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import set_use_memory_efficient_attention_xformers
from t2v_enhanced.model.diffusers_conditional.models.controlnet.mask_generator import MaskGenerator
from t2v_enhanced.utils.iimage import IImage
from t2v_enhanced.utils.object_loader import instantiate_object
from t2v_enhanced.utils.object_loader import get_class


class VideoLDM(pl.LightningModule):
    """Lightning module wrapping a ControlNet-conditioned text-to-video diffusion pipeline.

    The base ``UNet3DConditionModel`` and a ``ControlNetModel`` derived from it are
    assembled into a ``TextToVideoSDPipeline``. ``predict_step`` uses that pipeline
    to extend a given video chunk by chunk, conditioning each new chunk on frames of
    the previous one.
    """

    def __init__(self,
                 inference_params: pl_module_params_controlnet.InferenceParams,
                 opt_params: pl_module_params_controlnet.OptimizerParams = None,
                 unet_params: pl_module_params_controlnet.UNetParams = None,
                 ):
        """Load all pipeline components and build the inference pipeline.

        Args:
            inference_params: Sampling configuration (scheduler class, seed, video
                length, frame rate, result formats, ...).
            opt_params: Optimizer/loading configuration. Although it defaults to
                ``None``, its attributes are read unconditionally below, so it is
                required in practice.
            unet_params: Model configuration; ``pipeline_repo`` is the source of the
                scheduler, tokenizer, text encoder, VAE and base UNet. Also required
                in practice.
        """
        super().__init__()

        self.inference_generator = torch.Generator(device=self.device)

        self.opt_params = opt_params
        self.unet_params = unet_params

        print(f"Base pipeline from: {unet_params.pipeline_repo}")
        print(f"Pipeline class {unet_params.pipeline_class}")
        # load entire pipeline (unet, vq, text encoder,..)
        state_dict_control_model = None
        state_dict_fusion = None
        state_dict_base_model = None  # never populated in this file

        if len(opt_params.load_trained_controlnet_from_ckpt) > 0:
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict_ckpt = torch.load(opt_params.load_trained_controlnet_from_ckpt,
                                         map_location=torch.device("cpu"))
            state_dict_ckpt = state_dict_ckpt["state_dict"]
            # ControlNet weights live under the "unet." prefix of the Lightning checkpoint.
            state_dict_control_model = {
                k.split("unet.")[1]: v
                for k, v in state_dict_ckpt.items()
                if k.startswith("unet")
            }
            # Fusion (cross-attention merger) weights belong to the base model.
            state_dict_fusion = {
                k.split("base_model.")[1]: v
                for k, v in state_dict_ckpt.items()
                if "cross_attention_merger" in k
            }
            del state_dict_ckpt

        # NOTE(review): state_dict_proj is only ever assigned None in this file, so the
        # projection-loading block further down never runs.
        state_dict_proj = None
        state_dict_ckpt = None

        if hasattr(unet_params, "use_resampler") and unet_params.use_resampler:
            num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None
            if unet_params.use_image_tokens_ctrl:
                num_queries = unet_params.num_control_input_frames
                assert unet_params.frame_expansion == "none"
            image_encoder = self.unet_params.image_encoder
            embedding_dim = image_encoder.embedding_dim

            resampler = instantiate_object(self.unet_params.resampler_cls,
                                           video_length=num_queries,
                                           embedding_dim=embedding_dim,
                                           input_tokens=image_encoder.num_tokens,
                                           num_layers=self.unet_params.resampler_merging_layers,
                                           aggregation=self.unet_params.aggregation)

            state_dict_proj = None

            self.resampler = resampler
            self.image_encoder = image_encoder

        noise_scheduler = DDPMScheduler.from_pretrained(self.unet_params.pipeline_repo, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(self.unet_params.pipeline_repo, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae")
        base_model = UNet3DConditionModel.from_pretrained(
            self.unet_params.pipeline_repo,
            subfolder="unet",
            low_cpu_mem_usage=False,
            device_map=None,
            merging_mode=self.unet_params.merging_mode_base,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_main,
            use_fps_conditioning=self.opt_params.use_fps_conditioning,
            unet_params=unet_params)

        if state_dict_base_model is not None:
            miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False)
            assert len(unex) == 0
            if len(miss) > 0:
                warnings.warn(f"Missing keys when loading base_mode:{miss}")
        del state_dict_base_model
        if state_dict_fusion is not None:
            # Partial load: only the fusion layers are present, so missing keys are expected.
            miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False)
            assert len(unex) == 0
        del state_dict_fusion

        print("PIPE LOADING DONE")
        self.noise_scheduler = noise_scheduler
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae

        # The ControlNet branch is initialized from the base UNet; it gets its own
        # VAE instance loaded from the same repo.
        self.unet = ControlNetModel.from_unet(
            unet=base_model,
            conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels,
            downsample_controlnet_cond=unet_params.downsample_controlnet_cond,
            num_frames=unet_params.num_frames if (unet_params.frame_expansion != "none" or self.unet_params.use_controlnet_mask) else unet_params.num_control_input_frames,
            num_frame_conditioning=unet_params.num_control_input_frames,
            frame_expansion=unet_params.frame_expansion,
            pre_transformer_in_cond=unet_params.pre_transformer_in_cond,
            num_tranformers=unet_params.num_tranformers,
            vae=AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae"),
            zero_conv_mode=unet_params.zero_conv_mode,
            merging_mode=unet_params.merging_mode,
            condition_encoder=unet_params.condition_encoder,
            use_controlnet_mask=unet_params.use_controlnet_mask,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_ctrl,
            unet_params=unet_params,
            use_image_encoder_normalization=unet_params.use_image_encoder_normalization,
        )
        if state_dict_control_model is not None:
            miss, unex = self.unet.load_state_dict(state_dict_control_model, strict=False)
            if len(miss) > 0:
                print("WARNING: Loading checkpoint for controlnet misses states")
                print(miss)

        if unet_params.frame_expansion == "none":
            # These attention-mask options are rejected when frame_expansion is "none".
            attention_params = self.unet_params.attention_mask_params
            assert not attention_params.temporal_self_attention_only_on_conditioning and not attention_params.spatial_attend_on_condition_frames and not attention_params.temp_attend_on_neighborhood_of_condition_frames

        self.mask_generator = MaskGenerator(
            self.unet_params.attention_mask_params,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames)
        self.mask_generator_base = MaskGenerator(
            self.unet_params.attention_mask_params_base,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames)

        # The guard covers both target flags, so the ctrl branch is reachable and
        # `missing`/`unexpected` are always bound when used.
        if state_dict_proj is not None and (unet_params.use_image_tokens_main or unet_params.use_image_tokens_ctrl):
            if unet_params.use_image_tokens_main:
                missing, unexpected = base_model.load_state_dict(state_dict_proj, strict=False)
            else:
                missing, unexpected = self.unet.load_state_dict(state_dict_proj, strict=False)
            assert len(unexpected) == 0, f"Unexpected entries {unexpected}"
            print(f"Missing keys state proj = {missing}")
            del state_dict_proj

        # Freeze everything, then let the layers config re-enable what it selects.
        base_model.requires_grad_(False)
        self.base_model = base_model
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)

        layers_config = opt_params.layers_config
        layers_config.set_requires_grad(self)

        print("CUSTOM XFORMERS ATTENTION USED.")
        if is_xformers_available():
            set_use_memory_efficient_attention_xformers(
                self.unet,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params)
            set_use_memory_efficient_attention_xformers(
                self.base_model,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params_base)

        # Inference scheduler: configurable by dotted class path, DDIM by default.
        if len(inference_params.scheduler_cls) > 0:
            inf_scheduler_class = get_class(inference_params.scheduler_cls)
        else:
            inf_scheduler_class = DDIMScheduler

        inf_scheduler = inf_scheduler_class.from_pretrained(
            self.unet_params.pipeline_repo, subfolder="scheduler")
        inference_pipeline = TextToVideoSDPipeline(vae=self.vae,
                                                   text_encoder=self.text_encoder,
                                                   tokenizer=self.tokenizer,
                                                   unet=self.base_model,
                                                   controlnet=self.unet,
                                                   scheduler=inf_scheduler)

        inference_pipeline.set_noise_generator(self.opt_params.noise_generator)
        inference_pipeline.enable_vae_slicing()

        inference_pipeline.set_progress_bar_config(disable=True)

        self.inference_params = inference_params
        self.inference_pipeline = inference_pipeline

        self.result_processor = ResultProcessor(fps=self.inference_params.frame_rate,
                                                n_frames=self.inference_params.video_length)

    def on_start(self):
        """Verify that datasets agree with the model's pipeline repo, then attach the logger.

        Each of ``video_dataset``, ``image_dataset`` and ``predict_dataset`` on the
        datamodule that exposes a ``model_id`` must match ``unet_params.pipeline_repo``.
        """
        # NOTE(review): relies on private Lightning internals to reach the datamodule.
        datamodule = self.trainer._data_connector._datahook_selector.datamodule
        pipe_id_model = self.unet_params.pipeline_repo
        for dataset_key in ["video_dataset", "image_dataset", "predict_dataset"]:
            dataset = getattr(datamodule, dataset_key, None)
            if dataset is not None and hasattr(dataset, "model_id"):
                pipe_id_data = dataset.model_id
                assert pipe_id_model == pipe_id_data, f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_data}'"

        self.result_processor.set_logger(self.logger)

    def on_predict_start(self) -> None:
        """Run the shared start-up checks before prediction."""
        self.on_start()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Autoregressively extend ``cfg["video"]`` and save the concatenated result.

        The prompt from ``self.trainer.predict_cfg`` is repeated
        ``n_autoregressive_generations`` times. The first chunk is the given video.
        Each later chunk is generated by ``forward``, using as ``input_frames`` the
        frames after the first ``num_control_input_frames`` of the previous chunk.
        ``batch`` is not used.

        NOTE(review): ``predict_cfg`` is a custom attribute assumed to be attached to
        the trainer by the caller (keys: result_file_stem, predict_dir, prompt, video).
        """
        cfg = self.trainer.predict_cfg

        result_file_stem = cfg["result_file_stem"]
        storage_fol = Path(cfg['predict_dir'])
        prompts = [cfg["prompt"]]

        inference_params: pl_module_params_controlnet.InferenceParams = self.inference_params
        conditioning_type = inference_params.conditioning_type
        n_autoregressive_generations = inference_params.n_autoregressive_generations
        assert isinstance(prompts, list)

        prompts = n_autoregressive_generations * prompts

        self.inference_generator.manual_seed(self.inference_params.seed)

        n_ctrl = self.unet_params.num_control_input_frames
        assert n_ctrl == self.inference_params.video_length//2, f"currently we assume to have an equal size for the first and second half of the frame interval, e.g. 16 frames, and we condition on 8. Current setup: {n_ctrl} and {self.inference_params.video_length}"

        chunks_conditional = []
        batch_size = 1
        pipe_unet_config = self.inference_pipeline.unet.config
        shape = (batch_size, pipe_unet_config.in_channels, self.inference_params.video_length,
                 pipe_unet_config.sample_size, pipe_unet_config.sample_size)
        for idx, prompt in enumerate(prompts):
            if idx > 0:
                # Encode the previous chunk (assumed in [0, 1]) to latents; keep its
                # second half as content for the noise generator.
                content = sample*2-1
                content_latent = self.vae.encode(content).latent_dist.sample() * self.vae.config.scaling_factor
                content_latent = rearrange(content_latent, "F C W H -> 1 C F W H")
                content_latent = content_latent[:, :, n_ctrl:].detach().clone()

            if hasattr(self.inference_pipeline, "noise_generator"):
                latents = self.inference_pipeline.noise_generator.sample_noise(
                    shape=shape, device=self.device, dtype=self.dtype,
                    generator=self.inference_generator,
                    content=content_latent if idx > 0 else None)
            else:
                latents = None

            if idx == 0:
                sample = cfg["video"]
            else:
                if conditioning_type == "fixed":
                    # Image conditioning taken from the start of the first chunk.
                    context = [chunks_conditional[0][:self.unet_params.num_frame_conditioning]]
                    context = [2*frames-1 for frames in context]

                    input_frames_conditioning = torch.cat(context).detach().clone()
                    input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H")
                elif conditioning_type == "last_chunk":
                    # Last frames of the previous chunk (already in [-1, 1]).
                    input_frames_conditioning = condition_input[:, -self.unet_params.num_frame_conditioning:].detach().clone()
                elif conditioning_type == "past":
                    # The first n_ctrl frames of every chunk generated so far.
                    context = [chunk[:n_ctrl] for chunk in chunks_conditional]
                    context = [2*frames-1 for frames in context]

                    input_frames_conditioning = torch.cat(context).detach().clone()
                    input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H")
                else:
                    raise NotImplementedError()

                input_frames = condition_input[:, n_ctrl:].detach().clone()

                sample = self(prompt, input_frames=input_frames,
                              input_frames_conditioning=input_frames_conditioning, latents=latents)

            if hasattr(self.inference_pipeline, "reset_noise_generator_state"):
                self.inference_pipeline.reset_noise_generator_state()

            condition_input = rearrange(sample, "F C W H -> 1 F C W H")
            condition_input = (2*condition_input)-1  # range: [-1,1]

            # store first 16 frames, then always last 8 of a chunk
            chunks_conditional.append(sample)

        result_formats = self.inference_params.result_formats
        concat_video = self.inference_params.concat_video

        def IImage_normalized(x):
            return IImage(x, vmin=0, vmax=1)

        for result_format in result_formats:
            save_format = result_format.replace("eval_", "")

            # The first chunk is kept whole; later chunks contribute only their new frames.
            merged_video = None
            for chunk_idx, (_prompt, video) in enumerate(zip(prompts, chunks_conditional)):
                if chunk_idx == 0:
                    current_video = IImage_normalized(video)
                else:
                    current_video = IImage_normalized(video[n_ctrl:])

                if merged_video is None:
                    merged_video = current_video
                else:
                    merged_video &= current_video

            if concat_video:
                filename = video_naming(prompts[0], save_format, batch_idx, 0)
                result_file_video = (storage_fol / filename).absolute().as_posix()
                # Keep the directory and suffix chosen by video_naming, but use the requested stem.
                result_file_video = (Path(result_file_video).parent / (result_file_stem+Path(result_file_video).suffix)).as_posix()
                self.result_processor.save_to_file(video=merged_video.torch(vmin=0, vmax=1),
                                                   prompt=prompts[0],
                                                   video_filename=result_file_video,
                                                   prompt_on_vid=False)

    def forward(self, prompt, input_frames=None, input_frames_conditioning=None, latents=None):
        """Generate one chunk with the inference pipeline.

        Args:
            prompt: Text prompt.
            input_frames: Conditioning frames passed to the pipeline as ``image``.
            input_frames_conditioning: Frames for the image-conditioning path.
            latents: Optional initial latents; only forwarded when not ``None``.

        Returns:
            The pipeline output for ``return_dict=False`` / ``output_type="pt_t2v"``.
            ``predict_step`` treats it as a frames-first 4-D tensor.
        """
        call_params = self.inference_params.to_dict()
        print(f"INFERENCE PARAMS = {call_params}")
        call_params["prompt"] = prompt

        call_params["image"] = input_frames
        call_params["num_frames"] = self.inference_params.video_length
        call_params["return_dict"] = False
        call_params["output_type"] = "pt_t2v"
        call_params["mask_generator"] = self.mask_generator
        # str() so both string ("16-mixed") and integer (16) trainer precisions work.
        call_params["precision"] = "16" if str(self.trainer.precision).startswith("16") else "32"
        call_params["no_text_condition_control"] = self.opt_params.no_text_condition_control
        call_params["weight_control_sample"] = self.unet_params.weight_control_sample
        call_params["use_controlnet_mask"] = self.unet_params.use_controlnet_mask
        call_params["skip_controlnet_branch"] = self.opt_params.skip_controlnet_branch
        call_params["img_cond_resampler"] = self.resampler if self.unet_params.use_resampler else None
        call_params["img_cond_encoder"] = self.image_encoder if self.unet_params.use_resampler else None
        call_params["input_frames_conditioning"] = input_frames_conditioning
        call_params["cfg_text_image"] = self.unet_params.cfg_text_image
        call_params["use_of"] = self.unet_params.use_of
        if latents is not None:
            call_params["latents"] = latents

        sample = self.inference_pipeline(generator=self.inference_generator, **call_params)
        return sample