# coding: utf-8

"""
Pipeline of LivePortrait (CPU-optimized version)
"""

import torch
torch.set_num_threads(4)  # Limit the number of threads to reduce memory usage

import cv2
import numpy as np
import pickle
import os
import os.path as osp
from rich.progress import track
import gc

from .config.argument_config import ArgumentConfig
from .config.inference_config import InferenceConfig
from .config.crop_config import CropConfig
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper_cpu import mkdir, basename, dct2cpu, is_video, is_template, show_memory_usage
from .utils.rprint import rlog as log
from .live_portrait_wrapper_cpu import LivePortraitWrapperCPU as wrapper
# Alternative (non-CPU) wrapper; swap the import above to use it.
# from .live_portrait_wrapper import LivePortraitWrapper as wrapper


def make_abs_path(fn):
    """Return ``fn`` joined onto the directory that contains this module."""
    return osp.join(osp.dirname(osp.realpath(__file__)), fn)


def _add_audio_in_place(wfp, audio_src, wfp_with_audio):
    """Mux the audio of ``audio_src`` into the video at ``wfp``.

    The muxed video is written to ``wfp_with_audio`` first and then moved over
    ``wfp`` with ``os.replace``, so ``wfp`` ends up holding the version with audio.
    """
    add_audio_to_video(wfp, audio_src, wfp_with_audio)
    os.replace(wfp_with_audio, wfp)
    log(f"Replace {wfp} with {wfp_with_audio}")


class LiveCPUPortraitPipeline(object):
    """Animate a source portrait with a driving video or a pickled motion template."""

    # Class-level default so instances created without __init__ still work.
    batch_size: int = 128

    def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig, batch_size: int = 128):
        """
        Args:
            inference_cfg: inference configuration passed to the CPU wrapper.
            crop_cfg: configuration passed to the face ``Cropper``.
            batch_size: number of frames animated between the garbage-collection /
                memory-report checkpoints in ``execute``. Values below 1 are
                clamped to 1.
        """
        self.live_portrait_wrapper: wrapper = wrapper(cfg=inference_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg)
        self.mem_mon = show_memory_usage()
        self.batch_size = max(1, int(batch_size))

    def execute(self, args: ArgumentConfig):
        """Animate ``args.source_image`` with the motion from ``args.driving_info``.

        Writes ``<source>--<driving>.mp4`` into ``args.output_dir``. For video
        driving input it also writes ``<source>--<driving>_concat.mp4``. When the
        driving video has an audio stream, that audio is muxed into both outputs.

        Returns:
            tuple: ``(wfp, wfp_concat)``. ``wfp`` is the path of the animated video.
            ``wfp_concat`` is the path of the concatenated video, or ``None`` for
            template input.

        Raises:
            ValueError: if ``args.driving_info`` is neither a video nor a template.
        """
        inference_cfg = self.live_portrait_wrapper.cfg  # for convenience

        ######## process source portrait ########
        img_rgb = load_image_rgb(args.source_image)
        log(f"resizing source image to {inference_cfg.ref_max_shape}x{inference_cfg.ref_max_shape}")
        img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        log(f"processing image from {args.source_image}")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_lmk = crop_info['lmk_crop']
        img_crop_256x256 = crop_info['img_crop_256x256']

        if inference_cfg.flag_do_crop:
            log("Cropping source image.")
            I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
        else:
            log(f"Load source image from {args.source_image}")
            I_s = self.live_portrait_wrapper.prepare_source(img_rgb)

        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp']
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)

        # Decide lip zeroing for this call only. Keeping the flag local means a
        # source whose lips fall below the threshold does not permanently switch
        # the feature off on the shared config for later execute() calls.
        flag_lip_zero = inference_cfg.flag_lip_zero
        lip_delta_before_animation = None
        if flag_lip_zero:
            c_d_lip_before_animation = [0.]
            combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(
                c_d_lip_before_animation, source_lmk)
            if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
                flag_lip_zero = False
            else:
                lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(
                    x_s, combined_lip_ratio_tensor_before_animation)

        ######## process driving info ########
        output_fps = 10  # default fps
        flag_is_video = is_video(args.driving_info)
        if flag_is_video:
            log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
            output_fps = int(get_fps(args.driving_info))
            log(f'The FPS of {args.driving_info} is: {output_fps}')
            driving_rgb_lst = load_driving_info(args.driving_info)
            # NOTE(review): driving frames are resized to 128x128 here, although the
            # surrounding names suggest 256x256. This is presumably a deliberate CPU
            # speed/memory trade-off; confirm the wrapper expects this size.
            driving_rgb_lst_resized = [cv2.resize(frame, (128, 128)) for frame in driving_rgb_lst]
            I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_resized)
            n_frames = I_d_lst.shape[0]
            if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
                driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
                input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(
                    source_lmk, driving_lmk_lst)
        elif is_template(args.driving_info):
            log(f"Load from video templates {args.driving_info}")
            # SECURITY: pickle.load can execute arbitrary code embedded in the file;
            # only load templates from trusted sources.
            with open(args.driving_info, 'rb') as f:
                template_lst, driving_lmk_lst = pickle.load(f)
            n_frames = template_lst[0]['n_frames']
            input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(
                source_lmk, driving_lmk_lst)
        else:
            raise ValueError("Unsupported driving types!")

        ######## prepare for pasteback ########
        I_p_paste_lst = []
        if inference_cfg.flag_pasteback:
            mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'],
                                          dsize=(img_rgb.shape[1], img_rgb.shape[0]))

        # Frames are processed in batches so memory can be reclaimed and reported
        # periodically; all generated frames are still kept in I_p_lst.
        batch_size = self.batch_size
        n_batches = (n_frames + batch_size - 1) // batch_size  # ceiling division
        I_p_lst = []
        R_d_0, x_d_0_info = None, None
        log(f'Number of frames:{n_frames} processing in {n_batches} batches')

        for start in range(0, n_frames, batch_size):
            end = min(start + batch_size, n_frames)
            for i in track(range(start, end), description='Animating.....', total=end - start):
                log(f'Processing frame {i+1}/{n_frames}')
                if flag_is_video:
                    I_d_i = I_d_lst[i]
                    x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
                    R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
                else:
                    x_d_i_info = dct2cpu(template_lst[i])
                    R_d_i = x_d_i_info['R_d']

                # The first driving frame is the reference for relative motion.
                if i == 0:
                    R_d_0 = R_d_i
                    x_d_0_info = x_d_i_info

                if inference_cfg.flag_relative:
                    # Apply the driving motion relative to frame 0 on top of the source pose.
                    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
                else:
                    R_new = R_d_i
                    delta_new = x_d_i_info['exp']
                    scale_new = x_s_info['scale']
                    t_new = x_d_i_info['t']

                t_new[..., 2].fill_(0)  # zero tz
                x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

                if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
                    if flag_lip_zero:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + \
                            lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if inference_cfg.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio_lst[i]
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if inference_cfg.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio_lst[i]
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)

                    # In relative mode the retargeting deltas are applied to the
                    # source keypoints; otherwise to the driven keypoints.
                    base = x_s if inference_cfg.flag_relative else x_d_i_new
                    x_d_i_new = base + \
                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                    if inference_cfg.flag_stitching:
                        x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)

                # Check memory usage periodically
                show_memory_usage()

                out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
                I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
                I_p_lst.append(I_p_i)
                log(f'Generated {len(I_p_lst)} frames ')

                if inference_cfg.flag_pasteback:
                    I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
                    I_p_paste_lst.append(I_p_i_to_ori_blend)

            # Reclaim memory after each batch and report usage.
            gc.collect()  # Force garbage collection
            show_memory_usage()

        ######## write outputs ########
        mkdir(args.output_dir)
        wfp_concat = None
        # Only a driving video can carry an audio track; templates are pickle files.
        flag_has_audio = flag_is_video and has_audio_stream(args.driving_info)
        name_prefix = f'{basename(args.source_image)}--{basename(args.driving_info)}'

        if flag_is_video:
            frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
            wfp_concat = osp.join(args.output_dir, f'{name_prefix}_concat.mp4')
            images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
            if flag_has_audio:
                _add_audio_in_place(wfp_concat, args.driving_info,
                                    osp.join(args.output_dir, f'{name_prefix}_concat_with_audio.mp4'))

        wfp = osp.join(args.output_dir, f'{name_prefix}.mp4')
        frames_out = I_p_paste_lst if inference_cfg.flag_pasteback else I_p_lst
        images2video(frames_out, wfp=wfp, fps=output_fps)
        if flag_has_audio:
            _add_audio_in_place(wfp, args.driving_info,
                                osp.join(args.output_dir, f'{name_prefix}_with_audio.mp4'))

        return wfp, wfp_concat