#!/usr/bin/env python # -*- coding: UTF-8 -*- ''' @Project :EchoMimic @File :pipeline_echo_mimic.py @Author :juzhen.czy @Date :2024/3/4 17:44 ''' # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py import inspect import math from dataclasses import dataclass from typing import Callable, List, Optional, Union import numpy as np import torch from diffusers import DiffusionPipeline import torch.nn.functional as F from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging from diffusers.utils.torch_utils import randn_tensor from einops import rearrange from tqdm import tqdm from transformers import CLIPImageProcessor from src.models.mutual_self_attention import ReferenceAttentionControl from src.pipelines.context import get_context_scheduler from src.pipelines.utils import get_tensor_interpolation_method @dataclass class Audio2VideoPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] class Audio2VideoPipeline(DiffusionPipeline): _optional_components = [] def __init__( self, vae, reference_unet, denoising_unet, audio_guider, face_locator, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], image_proj_model=None, tokenizer=None, text_encoder=None, ): super().__init__() self.register_modules( vae=vae, reference_unet=reference_unet, denoising_unet=denoising_unet, audio_guider=audio_guider, face_locator=face_locator, scheduler=scheduler, image_proj_model=image_proj_model, tokenizer=tokenizer, text_encoder=text_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.clip_image_processor = CLIPImageProcessor() self.ref_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True ) def enable_vae_slicing(self): self.vae.enable_slicing() def disable_vae_slicing(self): self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def decode_latents(self, latents): video_length = latents.shape[2] latents = 1 / 0.18215 * latents latents = rearrange(latents, "b c f h w -> (b f) c h w") video = [] for frame_idx in tqdm(range(latents.shape[0])): video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 video = video.cpu().float().numpy() return video def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set( inspect.signature(self.scheduler.step).parameters.keys() ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set( inspect.signature(self.scheduler.step).parameters.keys() ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def prepare_latents( self, batch_size, num_channels_latents, width, height, video_length, dtype, device, generator ): shape = ( batch_size, num_channels_latents, # context_frame_length, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) latents_seg = randn_tensor( shape, generator=generator, device=device, dtype=dtype ) latents = latents_seg latents = torch.clamp(latents, -1.5, 1.5) latents = latents * self.scheduler.init_noise_sigma return latents def _encode_prompt( self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer( prompt, padding="longest", return_tensors="pt" ).input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) if ( hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None text_embeddings = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) text_embeddings = text_embeddings.view( bs_embed * num_videos_per_prompt, seq_len, -1 ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if ( hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None uncond_embeddings = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) uncond_embeddings = uncond_embeddings.view( batch_size * num_videos_per_prompt, seq_len, -1 ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings def interpolate_latents( self, latents: torch.Tensor, interpolation_factor: int, device ): if interpolation_factor < 2: return latents new_latents = torch.zeros( ( latents.shape[0], latents.shape[1], ((latents.shape[2] - 1) * interpolation_factor) + 1, latents.shape[3], latents.shape[4], ), device=latents.device, dtype=latents.dtype, ) org_video_length = latents.shape[2] rate = [i / interpolation_factor for i in range(interpolation_factor)][1:] new_index = 0 v0 = None v1 = None for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]): v0 = latents[:, :, i0, :, :] v1 = latents[:, :, i1, :, :] new_latents[:, :, new_index, :, :] = v0 new_index += 1 for f in rate: v = get_tensor_interpolation_method()( v0.to(device=device), v1.to(device=device), f ) new_latents[:, :, new_index, :, :] = v.to(latents.device) new_index += 1 new_latents[:, :, new_index, :, :] = v1 new_index += 1 return new_latents @torch.no_grad() def __call__( self, ref_image, audio_path, face_mask_tensor, width, height, video_length, num_inference_steps, guidance_scale, num_images_per_prompt=1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "tensor", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, context_schedule="uniform", context_frames=12, context_stride=1, context_overlap=0, context_batch_size=1, interpolation_factor=1, audio_sample_rate=16000, fps=25, audio_margin=2, **kwargs, ): # Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps batch_size = 1 reference_control_writer = ReferenceAttentionControl( self.reference_unet, do_classifier_free_guidance=do_classifier_free_guidance, mode="write", batch_size=batch_size, fusion_blocks="full", ) reference_control_reader = ReferenceAttentionControl( self.denoising_unet, do_classifier_free_guidance=do_classifier_free_guidance, mode="read", batch_size=batch_size, fusion_blocks="full", ) whisper_feature = self.audio_guider.audio2feat(audio_path) whisper_chunks = self.audio_guider.feature2chunks(feature_array=whisper_feature, fps=fps) print("whisper_chunks:", whisper_chunks.shape) audio_frame_num = whisper_chunks.shape[0] audio_fea_final = torch.Tensor(whisper_chunks).to(dtype=self.vae.dtype, device=self.vae.device) audio_fea_final = audio_fea_final.unsqueeze(0) print("audio_fea_final:", audio_fea_final.shape) video_length = min(video_length, audio_frame_num) if video_length < audio_frame_num: audio_fea_final = audio_fea_final[:, :video_length, :, :] num_channels_latents = self.denoising_unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, width, height, video_length, audio_fea_final.dtype, device, generator ) # print(video_length, latents.shape) face_locator_tensor = self.face_locator(face_mask_tensor) uc_face_locator_tensor = torch.zeros_like(face_locator_tensor) face_locator_tensor = torch.cat([uc_face_locator_tensor, face_locator_tensor], dim=0) # Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Prepare ref image latents ref_image_tensor = self.ref_image_processor.preprocess( ref_image, height=height, width=width ) ref_image_tensor = ref_image_tensor.to( dtype=self.vae.dtype, device=self.vae.device ) ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean ref_image_latents = ref_image_latents * 0.18215 # (b , 4, h, w) context_scheduler = get_context_scheduler(context_schedule) # denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order context_queue = list( context_scheduler( 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap, ) ) print("ref_image_latents shape:", ref_image_latents.shape) print("face_mask_tensor shape:", face_mask_tensor.shape) print("face_locator_tensor shape:", face_locator_tensor.shape) with self.progress_bar(total=num_inference_steps) as progress_bar: for t_i, t in enumerate(timesteps): noise_pred = torch.zeros( ( latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:], ), device=latents.device, dtype=latents.dtype, ) counter = torch.zeros( (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype, ) # 1. Forward reference image if t_i == 0: self.reference_unet( ref_image_latents, torch.zeros_like(t), encoder_hidden_states=None, return_dict=False, ) reference_control_reader.update(reference_control_writer, do_classifier_free_guidance=True) num_context_batches = math.ceil(len(context_queue) / context_batch_size) global_context = [] for j in range(num_context_batches): global_context.append( context_queue[ j * context_batch_size : (j + 1) * context_batch_size ] ) for context in global_context: new_context = [[0 for _ in range(len(context[c_j]))] for c_j in range(len(context))] for c_j in range(len(context)): for c_i in range(len(context[c_j])): new_context[c_j][c_i] = (context[c_j][c_i] + t_i * 2) % video_length latent_model_input = ( torch.cat([latents[:, :, c] for c in new_context]) .to(device) .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1) ) audio_latents = torch.cat([audio_fea_final[:, c] for c in new_context]).to(device) audio_latents = torch.cat([torch.zeros_like(audio_latents), audio_latents], 0) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t ) pred = self.denoising_unet( latent_model_input, t, encoder_hidden_states=None, audio_cond_fea=audio_latents, face_musk_fea=face_locator_tensor, return_dict=False, )[0] for j, c in enumerate(new_context): noise_pred[:, :, c] = noise_pred[:, :, c] + pred counter[:, :, c] = counter[:, :, c] + 1 # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs ).prev_sample if t_i == len(timesteps) - 1 or ( (t_i + 1) > num_warmup_steps and (t_i + 1) % self.scheduler.order == 0 ): progress_bar.update() reference_control_reader.clear() reference_control_writer.clear() if interpolation_factor > 0: latents = self.interpolate_latents(latents, interpolation_factor, device) # Post-processing images = self.decode_latents(latents) # (b, c, f, h, w) # Convert to tensor if output_type == "tensor": images = torch.from_numpy(images) if not return_dict: return images return Audio2VideoPipelineOutput(videos=images) def smooth_f_axis(self, tensor, smoothing_coef=0.2): """ 对5维tensor的F轴进行平滑,首尾帧保持原始值,其他帧应用平滑操作。 参数: tensor -- 输入的5D Tensor,形状为(B, C, F, W, H) smoothing_coef -- 平滑前后帧的系数,默认为0.3 返回: smoothed_tensor -- 平滑后的5D Tensor """ # 生成平滑核 weight_kernel = torch.tensor([smoothing_coef, 1 - 2 * smoothing_coef, smoothing_coef], device=tensor.device, dtype=tensor.dtype).view(1, 1, -1) b, c, f, w, h = tensor.size() # 对于第三轴(F轴)仅平滑内部帧 (1:-1) # 首先,移动要平滑的轴到最后一个维度以便应用conv1d tensor_moved = rearrange(tensor, "b c f h w -> b (c h w) f") # 对除了两端帧外的所有帧应用conv1d internal_frames = F.conv1d(tensor_moved, weight_kernel) print("tensor:", tensor.shape) # 重新整理输出形状为 (B, F, C, W, H) internal_frames = rearrange(internal_frames, "b (c h w) f -> b c f h w", c=c, h=h, w=w) print("internal_frames:", internal_frames.shape) # 将第一帧和最后一帧保持不变,合并结果 # 首帧 tensor[:, :, 0:1, :, :], 中间帧 internal_frames[:, :, 1:-1, :, :], 最后帧 tensor[:, :, -1:, :, :] smoothed_tensor = torch.cat( [tensor[:, :, 0:1, :, :], internal_frames, tensor[:, :, -1:, :, :]], dim=2) return smoothed_tensor