import gc
import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from musepose.models.pose_guider import PoseGuider
from musepose.models.unet_2d_condition import UNet2DConditionModel
from musepose.models.unet_3d import UNet3DConditionModel
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from musepose.utils.util import get_fps, read_frames, save_videos_grid
from downloading_weights import download_models


class MusePoseInference:
    """Generate a pose-driven video from a reference image with MusePose.

    All models are (re)loaded onto CUDA at the start of every
    :meth:`infer_musepose` call and released again when that call ends,
    whether it succeeds or raises.
    """

    def __init__(self, model_dir, output_dir):
        """Record model/config locations and make sure ``output_dir`` exists.

        Args:
            model_dir: Root directory of the downloaded weights. It is expected
                to contain ``sd-image-variations-diffusers``, ``sd-vae-ft-mse``,
                ``image_encoder`` and ``MusePose/*.pth``.
            output_dir: Directory the generated videos are written to. It is
                created if missing.
        """
        self.image_gen_model_paths = {
            "pretrained_base_model": os.path.join(model_dir, "sd-image-variations-diffusers"),
            "pretrained_vae": os.path.join(model_dir, "sd-vae-ft-mse"),
            "image_encoder": os.path.join(model_dir, "image_encoder"),
        }
        self.musepose_model_paths = {
            "denoising_unet": os.path.join(model_dir, "MusePose", "denoising_unet.pth"),
            "reference_unet": os.path.join(model_dir, "MusePose", "reference_unet.pth"),
            "pose_guider": os.path.join(model_dir, "MusePose", "pose_guider.pth"),
            "motion_module": os.path.join(model_dir, "MusePose", "motion_module.pth"),
        }
        # Relative path: resolved against the current working directory.
        self.inference_config_path = os.path.join("configs", "inference_v2.yaml")

        # Model handles: filled in by infer_musepose, cleared by release_vram.
        self.vae = None
        self.reference_unet = None
        self.denoising_unet = None
        self.pose_guider = None
        self.image_enc = None
        self.pipe = None

        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def infer_musepose(
        self,
        ref_image_path: str,
        pose_video_path: str,
        weight_dtype: str,
        W: int,
        H: int,
        L: int,
        S: int,
        O: int,
        cfg: float,
        seed: int,
        steps: int,
        fps: Optional[int],
        skip: int,
    ):
        """Animate ``ref_image_path`` with the poses from ``pose_video_path``.

        Args:
            ref_image_path: Reference (appearance) image.
            pose_video_path: Video whose frames are used as pose conditioning.
            weight_dtype: ``"fp16"`` selects float16. Any other value selects
                float32.
            W, H: Generation width and height in pixels.
            L: Maximum number of source frames to use. It is capped at the
                video length and then reduced by ``skip``.
            S: Frames per sliding context window (``context_frames``).
            O: Overlap between consecutive windows (``context_overlap``).
                Must be smaller than ``S``.
            cfg: Forwarded to the pipeline. Presumably the classifier-free
                guidance scale.
            seed: Seed for the torch default generator used during sampling.
            steps: Number of denoising steps forwarded to the pipeline.
            fps: Output frame rate. ``None`` or a negative value means "use
                the source fps divided by ``skip + 1``".
            skip: Frames skipped between sampled pose frames. Every
                ``skip + 1``-th frame is kept.

        Returns:
            ``(output_path, output_path_demo)``: absolute paths of the
            generated video and of the 3-row demo video (reference / pose /
            result).

        Raises:
            ValueError: If ``S <= O`` or no pose frames remain to process.
        """
        # Validate before the (potentially long) weight download.
        if S <= O:
            raise ValueError(
                f"slice length S ({S}) must be greater than overlap O ({O})"
            )

        download_models(model_dir=self.model_dir)

        print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
        print(f"Input Image Path: {ref_image_path}")
        print(f"Pose Video Path: {pose_video_path}")
        print(f"Dtype: {weight_dtype}")
        print(f"Width: {W}")
        print(f"Height: {H}")
        print(f"Video Frame Length: {L}")
        print(f"VIDEO SLICE FRAME LENGTH:: {S}")
        print(f"VIDEO SLICE OVERLAP_FRAME NUMBER: {O}")
        print(f"CFG: {cfg}")
        print(f"Seed: {seed}")
        print(f"Steps: {steps}")
        print(f"FPS: {fps}")
        print(f"Skip: {skip}")

        image_file_name = os.path.splitext(os.path.basename(ref_image_path))[0]
        pose_video_file_name = os.path.splitext(os.path.basename(pose_video_path))[0]
        output_file_name = f"img_{image_file_name}_pose_{pose_video_file_name}"
        output_path = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}.mp4'))
        output_path_demo = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}_demo.mp4'))

        dtype = torch.float16 if weight_dtype == "fp16" else torch.float32
        width, height = W, H

        try:
            self._load_pipeline(dtype)
            # torch.manual_seed seeds the global default generator and returns it.
            generator = torch.manual_seed(seed)

            print("image: ", ref_image_path, "pose_video: ", pose_video_path)
            with Image.open(ref_image_path) as img:
                ref_image_pil = img.convert("RGB")

            pose_images = read_frames(pose_video_path)
            src_fps = get_fps(pose_video_path)
            print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
            L = min(L, len(pose_images))

            pose_transform = transforms.Compose(
                [transforms.Resize((height, width)), transforms.ToTensor()]
            )

            stride = skip + 1
            pose_images = pose_images[::stride]
            print("processing length:", len(pose_images))
            src_fps = src_fps // stride
            print("fps", src_fps)
            # ceil(L / stride): number of sampled frames that fall within the
            # first L source frames. Floor division dropped the last one and
            # could reach 0 when L < stride.
            L = (L + skip) // stride
            if L <= 0:
                raise ValueError(
                    f"no pose frames to process (video: {pose_video_path})"
                )

            pose_list = list(pose_images[:L])
            pose_tensor_list = [pose_transform(frame) for frame in pose_list]
            # Output videos are rescaled back to the pose video's frame size.
            original_width, original_height = pose_list[-1].size

            # Pad with copies of the last frame so the padded length fills
            # whole S-frame windows advancing by S - O.
            last_segment_frame_num = (L - S) % (S - O)
            pad_frame_num = (S - O - last_segment_frame_num) % (S - O)
            pose_list.extend([pose_list[-1]] * pad_frame_num)
            pose_tensor_list.extend([pose_tensor_list[-1]] * pad_frame_num)

            ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
            ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
            ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L)

            pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
            pose_tensor = pose_tensor.transpose(0, 1).unsqueeze(0)  # (1, c, f, h, w)

            video = self.pipe(
                ref_image_pil,
                pose_list,
                width,
                height,
                len(pose_list),
                steps,
                cfg,
                generator=generator,
                context_frames=S,
                context_stride=1,
                context_overlap=O,
            ).videos

            out_fps = src_fps if fps is None or fps < 0 else fps

            # Drop the padding frames before saving.
            result = self.scale_video(video[:, :, :L], original_width, original_height)
            save_videos_grid(result, output_path, n_rows=1, fps=out_fps)

            demo = torch.cat([ref_image_tensor, pose_tensor[:, :, :L], video[:, :, :L]], dim=0)
            demo = self.scale_video(demo, original_width, original_height)
            save_videos_grid(demo, output_path_demo, n_rows=3, fps=out_fps)
        finally:
            # Always free GPU memory, including when loading/inference/saving fails.
            self.release_vram()

        return output_path, output_path_demo

    def _load_pipeline(self, weight_dtype):
        """Load every MusePose sub-model onto CUDA and assemble ``self.pipe``.

        Args:
            weight_dtype: torch dtype used for all models and the pipeline.
        """
        self.vae = AutoencoderKL.from_pretrained(
            self.image_gen_model_paths["pretrained_vae"],
        ).to("cuda", dtype=weight_dtype)

        self.reference_unet = UNet2DConditionModel.from_pretrained(
            self.image_gen_model_paths["pretrained_base_model"],
            subfolder="unet",
        ).to(dtype=weight_dtype, device="cuda")

        infer_config = OmegaConf.load(self.inference_config_path)
        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            Path(self.image_gen_model_paths["pretrained_base_model"]),
            Path(self.musepose_model_paths["motion_module"]),
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=weight_dtype, device="cuda")

        self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
            dtype=weight_dtype, device="cuda"
        )

        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
            self.image_gen_model_paths["image_encoder"]
        ).to(dtype=weight_dtype, device="cuda")

        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # NOTE(review): torch.load unpickles the file. Only point these paths
        # at trusted checkpoints.
        # strict=False: presumably the checkpoint does not cover every key of
        # the 3D UNet (motion module loaded separately) — confirm.
        self.denoising_unet.load_state_dict(
            torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
            strict=False,
        )
        self.reference_unet.load_state_dict(
            torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
        )
        self.pose_guider.load_state_dict(
            torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
        )

        self.pipe = Pose2VideoPipeline(
            vae=self.vae,
            image_encoder=self.image_enc,
            reference_unet=self.reference_unet,
            denoising_unet=self.denoising_unet,
            pose_guider=self.pose_guider,
            scheduler=scheduler,
        ).to("cuda", dtype=weight_dtype)

    def release_vram(self):
        """Drop all model references and return cached CUDA memory.

        Safe to call when nothing is loaded.
        """
        for model_name in ("vae", "reference_unet", "denoising_unet", "pose_guider", "image_enc", "pipe"):
            setattr(self, model_name, None)
        # Collect first so that tensors kept alive only by reference cycles
        # are freed before the caching allocator releases its unused blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def scale_video(video, width, height):
        """Resize the last two (spatial) dimensions of a 5-D tensor.

        In this file it is called with tensors laid out as
        ``(batch, channels, frames, h, w)``. The first three dimensions are
        preserved, and the result has shape ``(batch, channels, frames,
        height, width)``.

        Args:
            video: 5-D tensor to resize.
            width: Target width in pixels.
            height: Target height in pixels.
        """
        leading_shape = video.shape[:3]
        # Merge dims 0 and 1 so frames act as interpolate's channel axis:
        # (b*c, f, h, w). reshape (not view) also accepts non-contiguous input.
        flat = video.reshape(-1, *video.shape[2:])
        scaled = F.interpolate(flat, size=(height, width), mode="bilinear", align_corners=False)
        return scaled.reshape(*leading_shape, height, width)