Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from typing import Any, Optional, Union, Callable | |
import pytorch_lightning as pl | |
import torch | |
from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler | |
from diffusers.utils.import_utils import is_xformers_available | |
from einops import rearrange, repeat | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from t2v_enhanced.utils.video_utils import ResultProcessor, save_videos_grid, video_naming | |
from t2v_enhanced.model import pl_module_params_controlnet | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.controlnet import ControlNetModel | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import TextToVideoSDPipeline | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import set_use_memory_efficient_attention_xformers | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.mask_generator import MaskGenerator | |
import warnings | |
# from warnings import warn | |
from t2v_enhanced.utils.iimage import IImage | |
from t2v_enhanced.utils.object_loader import instantiate_object | |
from t2v_enhanced.utils.object_loader import get_class | |
class VideoLDM(pl.LightningModule):
    """LightningModule pairing a 3D text-to-video UNet with a ControlNet branch.

    ``__init__`` loads scheduler, tokenizer, text encoder, VAE and base UNet from
    ``unet_params.pipeline_repo``, derives a ``ControlNetModel`` from the base UNet,
    freezes the sub-models and wires everything into a ``TextToVideoSDPipeline``.
    ``predict_step`` produces a long video chunk by chunk, conditioning every new
    chunk on frames of previously produced chunks.
    """

    def __init__(self,
                 inference_params: pl_module_params_controlnet.InferenceParams,
                 opt_params: pl_module_params_controlnet.OptimizerParams = None,
                 unet_params: pl_module_params_controlnet.UNetParams = None,
                 ):
        """Build all sub-models and the inference pipeline.

        Args:
            inference_params: Sampling settings (scheduler class, seed, video length,
                frame rate, result formats, ...).
            opt_params: Model/optimizer settings. Despite the ``None`` default it is
                dereferenced unconditionally (checkpoint path, fps conditioning,
                layer config, noise generator), so it is effectively required.
            unet_params: UNet/ControlNet settings. Also effectively required.
        """
        super().__init__()

        self.inference_generator = torch.Generator(device=self.device)

        self.opt_params = opt_params
        self.unet_params = unet_params

        print(f"Base pipeline from: {unet_params.pipeline_repo}")
        print(f"Pipeline class {unet_params.pipeline_class}")
        # load entire pipeline (unet, vq, text encoder,..)
        state_dict_control_model = None
        state_dict_fusion = None
        state_dict_base_model = None  # never populated in this constructor

        if len(opt_params.load_trained_controlnet_from_ckpt) > 0:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict_ckpt = torch.load(opt_params.load_trained_controlnet_from_ckpt,
                                         map_location=torch.device("cpu"))
            state_dict_ckpt = state_dict_ckpt["state_dict"]
            # ControlNet weights are stored under the leading "unet." prefix; strip
            # exactly that prefix (split(...)[1] truncated keys containing it twice
            # and raised IndexError for keys like "unet_x").
            state_dict_control_model = {
                key[len("unet."):]: value
                for key, value in state_dict_ckpt.items()
                if key.startswith("unet.")
            }
            # Fusion (cross_attention_merger) weights are loaded into the base model:
            # keep everything after the first "base_model." and skip merger keys
            # that do not carry that prefix instead of raising IndexError.
            state_dict_fusion = {
                key.partition("base_model.")[2]: value
                for key, value in state_dict_ckpt.items()
                if "cross_attention_merger" in key and "base_model." in key
            }
            del state_dict_ckpt

        # Never populated in this constructor; the guarded load further down is
        # kept so a projection state dict can be plugged in.
        state_dict_proj = None

        if hasattr(unet_params, "use_resampler") and unet_params.use_resampler:
            num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None
            if unet_params.use_image_tokens_ctrl:
                num_queries = unet_params.num_control_input_frames
                assert unet_params.frame_expansion == "none"
            image_encoder = self.unet_params.image_encoder
            embedding_dim = image_encoder.embedding_dim

            resampler = instantiate_object(
                self.unet_params.resampler_cls,
                video_length=num_queries,
                embedding_dim=embedding_dim,
                input_tokens=image_encoder.num_tokens,
                num_layers=self.unet_params.resampler_merging_layers,
                aggregation=self.unet_params.aggregation)

            # Only set when the resampler is enabled; forward() checks
            # unet_params.use_resampler before touching these attributes.
            self.resampler = resampler
            self.image_encoder = image_encoder

        noise_scheduler = DDPMScheduler.from_pretrained(self.unet_params.pipeline_repo, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(self.unet_params.pipeline_repo, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae")
        base_model = UNet3DConditionModel.from_pretrained(
            self.unet_params.pipeline_repo,
            subfolder="unet",
            low_cpu_mem_usage=False,
            device_map=None,
            merging_mode=self.unet_params.merging_mode_base,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_main,
            use_fps_conditioning=self.opt_params.use_fps_conditioning,
            unet_params=unet_params)

        if state_dict_base_model is not None:
            miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False)
            assert len(unex) == 0
            if len(miss) > 0:
                warnings.warn(f"Missing keys when loading base_mode:{miss}")
        del state_dict_base_model
        if state_dict_fusion is not None:
            miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False)
            assert len(unex) == 0
        del state_dict_fusion

        print("PIPE LOADING DONE")
        self.noise_scheduler = noise_scheduler
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae

        self.unet = ControlNetModel.from_unet(
            unet=base_model,
            conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels,
            downsample_controlnet_cond=unet_params.downsample_controlnet_cond,
            num_frames=unet_params.num_frames if (unet_params.frame_expansion != "none" or self.unet_params.use_controlnet_mask) else unet_params.num_control_input_frames,
            num_frame_conditioning=unet_params.num_control_input_frames,
            frame_expansion=unet_params.frame_expansion,
            pre_transformer_in_cond=unet_params.pre_transformer_in_cond,
            num_tranformers=unet_params.num_tranformers,
            # The ControlNet gets its own VAE instance instead of sharing self.vae.
            vae=AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae"),
            zero_conv_mode=unet_params.zero_conv_mode,
            merging_mode=unet_params.merging_mode,
            condition_encoder=unet_params.condition_encoder,
            use_controlnet_mask=unet_params.use_controlnet_mask,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_ctrl,
            unet_params=unet_params,
            use_image_encoder_normalization=unet_params.use_image_encoder_normalization,
        )

        if state_dict_control_model is not None:
            miss, unex = self.unet.load_state_dict(
                state_dict_control_model, strict=False)
            if len(miss) > 0:
                print("WARNING: Loading checkpoint for controlnet misses states")
                print(miss)

        if unet_params.frame_expansion == "none":
            attention_params = self.unet_params.attention_mask_params
            assert not attention_params.temporal_self_attention_only_on_conditioning and not attention_params.spatial_attend_on_condition_frames and not attention_params.temp_attend_on_neighborhood_of_condition_frames

        self.mask_generator = MaskGenerator(
            self.unet_params.attention_mask_params,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames)
        self.mask_generator_base = MaskGenerator(
            self.unet_params.attention_mask_params_base,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames)

        if state_dict_proj is not None:
            # Image-token projection weights go to whichever network consumes the
            # image tokens. (Previously the ControlNet branch was unreachable and
            # referenced an undefined name `unet`.)
            if unet_params.use_image_tokens_main:
                proj_target = base_model
            elif unet_params.use_image_tokens_ctrl:
                proj_target = self.unet
            else:
                proj_target = None
            if proj_target is not None:
                missing, unexpected = proj_target.load_state_dict(
                    state_dict_proj, strict=False)
                assert len(unexpected) == 0, f"Unexpected entries {unexpected}"
                print(f"Missing keys state proj = {missing}")
        del state_dict_proj

        # Freeze everything, then let the layer config re-enable trainable parts.
        base_model.requires_grad_(False)
        self.base_model = base_model
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)

        layers_config = opt_params.layers_config
        layers_config.set_requires_grad(self)

        if is_xformers_available():
            # Only announce the custom attention when it is actually installed.
            print("CUSTOM XFORMERS ATTENTION USED.")
            set_use_memory_efficient_attention_xformers(
                self.unet,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params)
            set_use_memory_efficient_attention_xformers(
                self.base_model,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params_base)

        if len(inference_params.scheduler_cls) > 0:
            inf_scheduler_class = get_class(inference_params.scheduler_cls)
        else:
            inf_scheduler_class = DDIMScheduler

        inf_scheduler = inf_scheduler_class.from_pretrained(
            self.unet_params.pipeline_repo, subfolder="scheduler")
        inference_pipeline = TextToVideoSDPipeline(vae=self.vae,
                                                   text_encoder=self.text_encoder,
                                                   tokenizer=self.tokenizer,
                                                   unet=self.base_model,
                                                   controlnet=self.unet,
                                                   scheduler=inf_scheduler
                                                   )

        inference_pipeline.set_noise_generator(self.opt_params.noise_generator)
        inference_pipeline.enable_vae_slicing()
        inference_pipeline.set_progress_bar_config(disable=True)

        self.inference_params = inference_params
        self.inference_pipeline = inference_pipeline

        self.result_processor = ResultProcessor(fps=self.inference_params.frame_rate,
                                                n_frames=self.inference_params.video_length)

    def on_start(self):
        """Validate dataset/model pipeline consistency and attach the logger.

        Every dataset on the trainer's datamodule (``video_dataset``,
        ``image_dataset``, ``predict_dataset``) that exposes ``model_id`` must
        point at the same pipeline repo as this model.
        """
        # Public Trainer property instead of the private data-connector internals.
        datamodule = self.trainer.datamodule
        pipe_id_model = self.unet_params.pipeline_repo
        for dataset_key in ["video_dataset", "image_dataset", "predict_dataset"]:
            dataset = getattr(datamodule, dataset_key, None)
            if dataset is not None and hasattr(dataset, "model_id"):
                pipe_id_data = dataset.model_id
                assert pipe_id_model == pipe_id_data, f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_data}'"
        self.result_processor.set_logger(self.logger)

    def on_predict_start(self) -> None:
        """Lightning hook: run the shared start-up checks before prediction."""
        self.on_start()

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Generate a long video autoregressively and save it to disk.

        Inputs come from ``self.trainer.predict_cfg``: ``prompt``, ``video``
        (used as chunk 0), ``n_autoregressive_generations`` (total number of
        chunks, including chunk 0), ``predict_dir`` and ``result_file_stem``.
        Every later chunk is produced by ``forward`` conditioned on frames of
        earlier chunks according to ``inference_params.conditioning_type``.

        Args:
            batch: Unused; inputs are read from ``self.trainer.predict_cfg``.
            batch_idx: Forwarded to ``video_naming`` when naming the output.
            dataloader_idx: Unused.

        Raises:
            ValueError: If ``n_autoregressive_generations`` is less than 1.
            NotImplementedError: For an unknown ``conditioning_type``.
        """
        cfg = self.trainer.predict_cfg

        result_file_stem = cfg["result_file_stem"]
        storage_fol = Path(cfg['predict_dir'])
        n_autoregressive_generations = cfg["n_autoregressive_generations"]
        if n_autoregressive_generations < 1:
            # Zero chunks would leave nothing to merge and crash when saving.
            raise ValueError(
                f"n_autoregressive_generations must be >= 1, got {n_autoregressive_generations}")
        # The same prompt is reused for every chunk.
        prompts = n_autoregressive_generations * [cfg["prompt"]]

        inference_params: pl_module_params_controlnet.InferenceParams = self.inference_params
        conditioning_type = inference_params.conditioning_type
        num_ctrl_frames = self.unet_params.num_control_input_frames

        self.inference_generator.manual_seed(inference_params.seed)

        assert num_ctrl_frames == inference_params.video_length // 2, (
            "currently we assume to have an equal size for first and second half of the frame interval, "
            "e.g. 16 frames, and we condition on 8. "
            f"Current setup: {num_ctrl_frames} and {inference_params.video_length}")

        chunks_conditional = []
        batch_size = 1
        unet_config = self.inference_pipeline.unet.config
        shape = (batch_size, unet_config.in_channels, inference_params.video_length,
                 unet_config.sample_size, unet_config.sample_size)

        sample = None
        condition_input = None
        for idx, prompt in enumerate(prompts):
            content_latent = None
            if idx > 0:
                # Encode the previous chunk (mapped to [-1, 1]) with the VAE and keep
                # the latents after the first num_ctrl_frames frames.
                content = sample * 2 - 1
                content_latent = self.vae.encode(content).latent_dist.sample() * self.vae.config.scaling_factor
                content_latent = rearrange(content_latent, "F C W H -> 1 C F W H")
                content_latent = content_latent[:, :, num_ctrl_frames:].detach().clone()

            if hasattr(self.inference_pipeline, "noise_generator"):
                latents = self.inference_pipeline.noise_generator.sample_noise(
                    shape=shape, device=self.device, dtype=self.dtype,
                    generator=self.inference_generator, content=content_latent)
            else:
                latents = None

            if idx == 0:
                # Chunk 0 is the provided video, not a generated one.
                sample = cfg["video"]
            else:
                input_frames_conditioning = self._build_frames_conditioning(
                    conditioning_type, chunks_conditional, condition_input)
                input_frames = condition_input[:, num_ctrl_frames:].detach().clone()

                sample = self(prompt, input_frames=input_frames,
                              input_frames_conditioning=input_frames_conditioning, latents=latents)
                if hasattr(self.inference_pipeline, "reset_noise_generator_state"):
                    self.inference_pipeline.reset_noise_generator_state()

            condition_input = rearrange(sample, "F C W H -> 1 F C W H")
            condition_input = (2 * condition_input) - 1  # range: [-1,1]

            # store first 16 frames, then always last 8 of a chunk
            chunks_conditional.append(sample)

        if not inference_params.concat_video:
            return None

        for result_format in inference_params.result_formats:
            save_format = result_format.replace("eval_", "")
            # Rebuilt per format so every save works on a fresh IImage.
            merged_video = self._merge_chunks(chunks_conditional)

            filename = video_naming(prompts[0], save_format, batch_idx, 0)
            # Keep the directory and extension chosen by video_naming, replace the stem.
            target = (storage_fol / filename).absolute()
            result_file_video = (target.parent / (result_file_stem + target.suffix)).as_posix()
            self.result_processor.save_to_file(video=merged_video.torch(vmin=0, vmax=1), prompt=prompts[0],
                                               video_filename=result_file_video, prompt_on_vid=False)
        return None

    def _build_frames_conditioning(self, conditioning_type, chunks_conditional, condition_input):
        """Return the ``input_frames_conditioning`` tensor for the next chunk.

        "fixed" uses the first ``num_frame_conditioning`` frames of chunk 0,
        "last_chunk" the last ``num_frame_conditioning`` frames of the previous
        chunk, and "past" the first ``num_control_input_frames`` frames of every
        chunk so far. The result is rearranged to "1 F C W H" and mapped from
        [0, 1] to [-1, 1] (assumes chunks are in [0, 1] — TODO confirm).

        Raises:
            NotImplementedError: For any other ``conditioning_type``.
        """
        if conditioning_type == "fixed":
            context = [chunks_conditional[0][:self.unet_params.num_frame_conditioning]]
        elif conditioning_type == "last_chunk":
            # condition_input is already in [-1, 1] and "1 F C W H".
            return condition_input[:, -self.unet_params.num_frame_conditioning:].detach().clone()
        elif conditioning_type == "past":
            context = [chunk[:self.unet_params.num_control_input_frames] for chunk in chunks_conditional]
        else:
            raise NotImplementedError()

        context = [2 * frames - 1 for frames in context]
        frames_conditioning = torch.cat(context).detach().clone()
        return rearrange(frames_conditioning, "F C W H -> 1 F C W H")

    def _merge_chunks(self, chunks_conditional):
        """Merge all chunks into one IImage with IImage's ``&`` operator.

        Chunk 0 is used whole; from every later chunk the first
        ``num_control_input_frames`` frames are skipped (presumably the frames
        that overlap the previous chunk — confirm). Returns ``None`` for an
        empty list.
        """
        merged_video = None
        for chunk_idx, video in enumerate(chunks_conditional):
            if chunk_idx > 0:
                video = video[self.unet_params.num_control_input_frames:]
            current_video = IImage(video, vmin=0, vmax=1)
            if merged_video is None:
                merged_video = current_video
            else:
                merged_video &= current_video
        return merged_video

    def forward(self, prompt, input_frames=None, input_frames_conditioning=None, latents=None):
        """Run the inference pipeline for a single chunk.

        Args:
            prompt: Text prompt for the chunk.
            input_frames: Passed to the pipeline as ``image``.
            input_frames_conditioning: Passed through unchanged.
            latents: Optional initial latents; only forwarded when not ``None``.

        Returns:
            The pipeline output for ``return_dict=False`` and ``output_type="pt_t2v"``.
        """
        call_params = self.inference_params.to_dict()
        call_params["prompt"] = prompt
        call_params["image"] = input_frames
        call_params["num_frames"] = self.inference_params.video_length
        call_params["return_dict"] = False
        call_params["output_type"] = "pt_t2v"
        call_params["mask_generator"] = self.mask_generator
        # trainer.precision may be an int (e.g. 16) or a string (e.g. "16-mixed")
        # depending on the Lightning version; str() handles both.
        call_params["precision"] = "16" if str(self.trainer.precision).startswith("16") else "32"
        call_params["no_text_condition_control"] = self.opt_params.no_text_condition_control
        call_params["weight_control_sample"] = self.unet_params.weight_control_sample
        call_params["use_controlnet_mask"] = self.unet_params.use_controlnet_mask
        call_params["skip_controlnet_branch"] = self.opt_params.skip_controlnet_branch
        call_params["img_cond_resampler"] = self.resampler if self.unet_params.use_resampler else None
        call_params["img_cond_encoder"] = self.image_encoder if self.unet_params.use_resampler else None
        call_params["input_frames_conditioning"] = input_frames_conditioning
        call_params["cfg_text_image"] = self.unet_params.cfg_text_image
        call_params["use_of"] = self.unet_params.use_of
        if latents is not None:
            call_params["latents"] = latents
        sample = self.inference_pipeline(generator=self.inference_generator, **call_params)
        return sample