|
import logging |
|
import os |
|
import cv2 |
|
import time |
|
import copy |
|
import dill |
|
import numpy as np
import torch
|
from ultralytics import YOLO |
|
import safetensors.torch |
|
import gradio as gr |
|
from gradio_i18n import Translate, gettext as _ |
|
from ultralytics.utils import LOGGER as ultralytics_logger |
|
from enum import Enum |
|
from typing import Optional, Union, List, Dict, Tuple
|
|
|
from modules.utils.paths import * |
|
from modules.utils.image_helper import * |
|
from modules.utils.video_helper import * |
|
from modules.live_portrait.model_downloader import * |
|
from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper |
|
from modules.utils.camera import get_rotation_matrix |
|
from modules.utils.helper import load_yaml |
|
from modules.utils.constants import * |
|
from modules.config.inference_config import InferenceConfig |
|
from modules.live_portrait.spade_generator import SPADEDecoder |
|
from modules.live_portrait.warping_network import WarpingNetwork |
|
from modules.live_portrait.motion_extractor import MotionExtractor |
|
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor |
|
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork |
|
|
|
|
|
class LivePortraitInferencer: |
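    """
    Wraps the LivePortrait sub-models (appearance feature extractor, motion
    extractor, warping module, SPADE generator, stitcher) behind a single
    interface for expression editing and driving-video animation. Missing
    weights are downloaded on demand; results are written under `output_dir`.
    """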
|
def __init__(self, |
|
model_dir: str = MODELS_DIR, |
|
output_dir: str = OUTPUTS_DIR): |
|
self.model_dir = model_dir |
|
self.output_dir = output_dir |
|
relative_dirs = [ |
|
os.path.join(self.model_dir, "animal"), |
|
os.path.join(self.output_dir, "videos"), |
|
os.path.join(self.output_dir, "temp"), |
|
os.path.join(self.output_dir, "temp", "video_frames"), |
|
os.path.join(self.output_dir, "temp", "video_frames", "out"), |
|
] |
|
for dir_path in relative_dirs: |
|
os.makedirs(dir_path, exist_ok=True) |
|
|
|
self.model_config = load_yaml(MODEL_CONFIG)["model_params"] |
|
|
|
self.appearance_feature_extractor = None |
|
self.motion_extractor = None |
|
self.warping_module = None |
|
self.spade_generator = None |
|
self.stitching_retargeting_module = None |
|
self.pipeline = None |
|
self.detect_model = None |
|
self.device = self.get_device() |
|
self.model_type = ModelType.HUMAN.value |
|
|
|
self.mask_img = None |
|
self.temp_img_idx = 0 |
|
self.src_image = None |
|
self.src_image_list = None |
|
self.sample_image = None |
|
self.driving_images = None |
|
self.driving_values = None |
|
self.crop_factor = None |
|
self.psi = None |
|
self.psi_list = None |
|
self.d_info = None |
|
|
|
def load_models(self, |
|
model_type: str = ModelType.HUMAN.value, |
|
progress=gr.Progress()): |
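        """
        Download (if missing) and load the five LivePortrait sub-models for
        `model_type`, move them to the detected device, rebuild the
        LivePortraitWrapper pipeline around them, and load the YOLO detector
        used for face/animal cropping.
        """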
|
if isinstance(model_type, ModelType): |
|
model_type = model_type.value |
|
if model_type not in [mode.value for mode in ModelType]: |
|
model_type = ModelType.HUMAN.value |
|
|
|
self.model_type = model_type |
|
if model_type == ModelType.ANIMAL.value: |
|
model_dir = os.path.join(self.model_dir, "animal") |
|
else: |
|
model_dir = self.model_dir |
|
|
|
self.download_if_no_models( |
|
model_type=model_type |
|
) |
|
|
|
total_models_num = 5 |
|
progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...") |
|
appearance_feat_config = self.model_config["appearance_feature_extractor_params"] |
|
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device) |
|
self.appearance_feature_extractor = self.load_safe_tensor( |
|
self.appearance_feature_extractor, |
|
os.path.join(model_dir, "appearance_feature_extractor.safetensors") |
|
) |
|
|
|
progress(1/total_models_num, desc="Loading Motion Extractor model...") |
|
motion_ext_config = self.model_config["motion_extractor_params"] |
|
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device) |
|
self.motion_extractor = self.load_safe_tensor( |
|
self.motion_extractor, |
|
os.path.join(model_dir, "motion_extractor.safetensors") |
|
) |
|
|
|
progress(2/total_models_num, desc="Loading Warping Module model...") |
|
warping_module_config = self.model_config["warping_module_params"] |
|
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device) |
|
self.warping_module = self.load_safe_tensor( |
|
self.warping_module, |
|
os.path.join(model_dir, "warping_module.safetensors") |
|
) |
|
|
|
progress(3/total_models_num, desc="Loading Spade generator model...") |
|
spaded_decoder_config = self.model_config["spade_generator_params"] |
|
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device) |
|
self.spade_generator = self.load_safe_tensor( |
|
self.spade_generator, |
|
os.path.join(model_dir, "spade_generator.safetensors") |
|
) |
|
|
|
progress(4/total_models_num, desc="Loading Stitcher model...") |
|
stitcher_config = self.model_config["stitching_retargeting_module_params"] |
|
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device) |
|
self.stitching_retargeting_module = self.load_safe_tensor( |
|
self.stitching_retargeting_module, |
|
os.path.join(model_dir, "stitching_retargeting_module.safetensors"), |
|
True |
|
) |
|
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module} |
|
|
|
        # `self.model_type` was already updated above, so the old
        # `model_type != self.model_type` guard could never fire; rebuild the
        # wrapper unconditionally so it references the freshly loaded modules.
        self.pipeline = LivePortraitWrapper(
            InferenceConfig(),
            self.appearance_feature_extractor,
            self.motion_extractor,
            self.warping_module,
            self.spade_generator,
            self.stitching_retargeting_module
        )
|
|
|
        det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL.value else "face_yolov8n"
|
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device) |
|
|
|
def edit_expression(self, |
|
model_type: str = ModelType.HUMAN.value, |
|
rotate_pitch: float = 0, |
|
rotate_yaw: float = 0, |
|
rotate_roll: float = 0, |
|
blink: float = 0, |
|
eyebrow: float = 0, |
|
wink: float = 0, |
|
pupil_x: float = 0, |
|
pupil_y: float = 0, |
|
aaa: float = 0, |
|
eee: float = 0, |
|
woo: float = 0, |
|
smile: float = 0, |
|
src_ratio: float = 1, |
|
sample_ratio: float = 1, |
|
sample_parts: str = SamplePart.ALL.value, |
|
crop_factor: float = 2.3, |
|
                        src_image: Optional[Union[str, np.ndarray]] = None,
                        sample_image: Optional[np.ndarray] = None) -> Optional[np.ndarray]:
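        """
        Re-render `src_image` with the requested expression offsets (head
        rotation, blink, eyebrow, wink, pupils, mouth shapes, smile),
        optionally blending in expression and/or rotation sampled from
        `sample_image`. Returns the edited full-size RGB frame, or None when
        no source image is given.
        """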
|
if isinstance(model_type, ModelType): |
|
model_type = model_type.value |
|
if model_type not in [mode.value for mode in ModelType]: |
|
            model_type = ModelType.HUMAN.value
|
|
|
if self.pipeline is None or model_type != self.model_type: |
|
self.load_models( |
|
model_type=model_type |
|
) |
|
|
|
try: |
|
rotate_yaw = -rotate_yaw |
|
|
|
if src_image is not None: |
|
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor: |
|
self.crop_factor = crop_factor |
|
self.psi = self.prepare_source(src_image, crop_factor) |
|
self.src_image = src_image |
|
else: |
|
return None |
|
|
|
psi = self.psi |
|
s_info = psi.x_s_info |
|
|
|
s_exp = s_info['exp'] * src_ratio |
|
s_exp[0, 5] = s_info['exp'][0, 5] |
|
s_exp += s_info['kp'] |
|
|
|
es = ExpressionSet() |
|
|
|
            if isinstance(sample_image, np.ndarray) and sample_image.size > 0:
|
if id(self.sample_image) != id(sample_image): |
|
self.sample_image = sample_image |
|
                    d_image_np = (sample_image * 255).astype(np.uint8)
                    if d_image_np.ndim == 3:
                        d_image_np = d_image_np[np.newaxis, ...]
|
d_face = self.crop_face(d_image_np[0], 1.7) |
|
i_d = self.prepare_src_image(d_face) |
|
self.d_info = self.pipeline.get_kp_info(i_d) |
|
self.d_info['exp'][0, 5, 0] = 0 |
|
self.d_info['exp'][0, 5, 1] = 0 |
|
|
|
|
|
            if sample_parts in (SamplePart.ONLY_EXPRESSION.value, SamplePart.ALL.value):
|
es.e += self.d_info['exp'] * sample_ratio |
|
            if sample_parts in (SamplePart.ONLY_ROTATION.value, SamplePart.ALL.value):
|
rotate_pitch += self.d_info['pitch'] * sample_ratio |
|
rotate_yaw += self.d_info['yaw'] * sample_ratio |
|
rotate_roll += self.d_info['roll'] * sample_ratio |
|
elif sample_parts == SamplePart.ONLY_MOUTH.value: |
|
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20)) |
|
elif sample_parts == SamplePart.ONLY_EYES.value: |
|
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16)) |
|
|
|
es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile, |
|
rotate_pitch, rotate_yaw, rotate_roll) |
|
|
|
new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1], |
|
s_info['roll'] + es.r[2]) |
|
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t'] |
|
|
|
x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new) |
|
|
|
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new) |
|
crop_out = self.pipeline.parse_output(crop_out['out'])[0] |
|
|
|
            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), flags=cv2.INTER_LINEAR)
|
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8) |
|
|
|
            temp_out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png")
            out_img_path = get_auto_incremental_file_path(OUTPUTS_DIR, "png")
|
save_image(numpy_array=crop_out, output_path=temp_out_img_path) |
|
save_image(numpy_array=out, output_path=out_img_path) |
|
|
|
return out |
|
        except Exception:
            logging.exception("Failed to edit expression")
            raise
|
|
|
def create_video(self, |
|
model_type: str = ModelType.HUMAN.value, |
|
retargeting_eyes: float = 1, |
|
retargeting_mouth: float = 1, |
|
crop_factor: float = 2.3, |
|
src_image: Optional[str] = None, |
|
driving_vid_path: Optional[str] = None, |
|
progress: gr.Progress = gr.Progress() |
|
): |
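        """
        Animate `src_image` with the motion of `driving_vid_path`: extract
        the driving frames, compute each frame's motion delta relative to the
        first driving frame, warp and re-composite the source, and encode the
        results back into a video. Returns the output video path.
        """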
|
        if isinstance(model_type, ModelType):
            model_type = model_type.value
        if self.pipeline is None or model_type != self.model_type:
|
self.load_models( |
|
model_type=model_type |
|
) |
|
|
|
vid_info = get_video_info(vid_input=driving_vid_path) |
|
|
|
if src_image is not None: |
|
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor: |
|
self.crop_factor = crop_factor |
|
self.src_image = src_image |
|
|
|
self.psi_list = [self.prepare_source(src_image, crop_factor)] |
|
|
|
progress(0, desc="Extracting frames from the video..") |
|
        driving_images = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames"))
        vid_sound = extract_sound(driving_vid_path)
|
|
|
driving_length = 0 |
|
if driving_images is not None: |
|
if id(driving_images) != id(self.driving_images): |
|
self.driving_images = driving_images |
|
self.driving_values = self.prepare_driving_video(driving_images) |
|
driving_length = len(self.driving_values) |
|
|
|
total_length = len(driving_images) |
|
|
|
c_i_es = ExpressionSet() |
|
c_o_es = ExpressionSet() |
|
d_0_es = None |
|
|
|
psi = None |
|
for i in range(total_length): |
|
|
|
if i == 0: |
|
psi = self.psi_list[i] |
|
s_info = psi.x_s_info |
|
s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t'])) |
|
|
|
new_es = ExpressionSet(es=s_es) |
|
|
|
if i < driving_length: |
|
d_i_info = self.driving_values[i] |
|
d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) |
|
|
|
if d_0_es is None: |
|
d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t'])) |
|
|
|
self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16)) |
|
self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20)) |
|
|
|
new_es.e += d_i_info['exp'] - d_0_es.e |
|
new_es.r += d_i_r - d_0_es.r |
|
new_es.t += d_i_info['t'] - d_0_es.t |
|
|
|
r_new = get_rotation_matrix( |
|
s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2]) |
|
d_new = new_es.s * (new_es.e @ r_new) + new_es.t |
|
d_new = self.pipeline.stitching(psi.x_s_user, d_new) |
|
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new) |
|
crop_out = self.pipeline.parse_output(crop_out['out'])[0] |
|
|
|
            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                                flags=cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
|
|
out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png") |
|
save_image(out, out_frame_path) |
|
|
|
progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..") |
|
|
|
        # Frames were written under output_dir above; read them back from the
        # same location rather than the global TEMP_VIDEO_OUT_FRAMES_DIR,
        # which only matches when output_dir is left at its default.
        video_path = create_video_from_frames(
            os.path.join(self.output_dir, "temp", "video_frames", "out"),
            frame_rate=vid_info.frame_rate,
            output_dir=os.path.join(self.output_dir, "videos")
        )
|
|
|
return video_path |
|
|
|
def download_if_no_models(self, |
|
model_type: str = ModelType.HUMAN.value, |
|
                              progress=gr.Progress()):
|
progress(0, desc="Downloading models...") |
|
|
|
if isinstance(model_type, ModelType): |
|
model_type = model_type.value |
|
if model_type == ModelType.ANIMAL.value: |
|
models_urls_dic = MODELS_ANIMAL_URL |
|
model_dir = os.path.join(self.model_dir, "animal") |
|
else: |
|
models_urls_dic = MODELS_URL |
|
model_dir = self.model_dir |
|
|
|
for model_name, model_url in models_urls_dic.items(): |
|
if model_url.endswith(".pt"): |
|
model_name += ".pt" |
|
elif model_url.endswith(".n2x"): |
|
model_name += ".n2x" |
|
else: |
|
model_name += ".safetensors" |
|
model_path = os.path.join(model_dir, model_name) |
|
if not os.path.exists(model_path): |
|
download_model(model_path, model_url) |
|
|
|
@staticmethod |
|
def load_safe_tensor(model, file_path, is_stitcher=False): |
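        """
        Load a .safetensors checkpoint into `model` and switch it to eval
        mode. With `is_stitcher`, only keys starting with "retarget_shoulder"
        are kept, and the "retarget_shoulder_module." prefix is stripped to
        match the network's parameter names.
        """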
|
def filter_stitcher(checkpoint, prefix): |
|
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if |
|
key.startswith(prefix)} |
|
return filtered_checkpoint |
|
|
|
if is_stitcher: |
|
model.load_state_dict(filter_stitcher(safetensors.torch.load_file(file_path), 'retarget_shoulder')) |
|
else: |
|
model.load_state_dict(safetensors.torch.load_file(file_path)) |
|
model.eval() |
|
return model |
|
|
|
@staticmethod |
|
def get_device(): |
|
if torch.cuda.is_available(): |
|
return "cuda" |
|
elif torch.backends.mps.is_available(): |
|
return "mps" |
|
else: |
|
return "cpu" |
|
|
|
def get_temp_img_name(self): |
|
self.temp_img_idx += 1 |
|
return "expression_edit_preview" + str(self.temp_img_idx) + ".png" |
|
|
|
@staticmethod |
|
    def parsing_command(command, motion_link):
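        """
        Parse a multi-line motion script. Each non-empty line has the form
        "index=change:keep": `index` selects an ExpressionSet from
        `motion_link` (0 means neutral), `change` is the number of transition
        frames (the set is divided by it to get a per-frame delta) and `keep`
        the number of hold frames. Returns (command list, total frame count),
        or (None, None) if a line fails to parse.
        """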
|
        command = command.replace(' ', '')
|
lines = command.split('\n') |
|
|
|
cmd_list = [] |
|
|
|
total_length = 0 |
|
|
|
        for i, line in enumerate(lines, start=1):
|
if not line: |
|
continue |
|
try: |
|
cmds = line.split('=') |
|
idx = int(cmds[0]) |
|
                if idx == 0:
                    es = ExpressionSet()
                else:
                    es = ExpressionSet(es=motion_link[idx])
|
cmds = cmds[1].split(':') |
|
change = int(cmds[0]) |
|
keep = int(cmds[1]) |
|
except Exception as e: |
|
print(f"(AdvancedLivePortrait) Command Err Line {i}: {line}, :{e}") |
|
return None, None |
|
|
|
total_length += change + keep |
|
es.div(change) |
|
cmd_list.append(Command(es, change, keep)) |
|
|
|
return cmd_list, total_length |
|
|
|
def get_face_bboxes(self, image_rgb): |
|
pred = self.detect_model(image_rgb, conf=0.7, device=self.device) |
|
return pred[0].boxes.xyxy.cpu().numpy() |
|
|
|
    def detect_face(self, image_rgb, crop_factor, sort=True):
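        """
        Run the YOLO detector and pick the face whose center is closest to
        the horizontal center of `image_rgb` (boxes narrower than 30 px are
        ignored). Returns a square crop region [x1, y1, x2, y2] scaled by
        `crop_factor`; when `sort` is True the region is shifted/clamped to
        stay within the image bounds.
        """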
|
original_logger_level = ultralytics_logger.level |
|
ultralytics_logger.setLevel(logging.CRITICAL + 1) |
|
|
|
        bboxes = self.get_face_bboxes(image_rgb)
        # Restore the logger as soon as the detector call is done so the
        # early returns below don't leave it silenced.
        ultralytics_logger.setLevel(original_logger_level)
|
w, h = get_rgb_size(image_rgb) |
|
|
|
|
|
|
|
cx = w / 2 |
|
min_diff = w |
|
best_box = None |
|
for x1, y1, x2, y2 in bboxes: |
|
bbox_w = x2 - x1 |
|
if bbox_w < 30: continue |
|
diff = abs(cx - (x1 + bbox_w / 2)) |
|
if diff < min_diff: |
|
best_box = [x1, y1, x2, y2] |
|
|
|
min_diff = diff |
|
|
|
        if best_box is None:
|
print("Failed to detect face!!") |
|
return [0, 0, w, h] |
|
|
|
x1, y1, x2, y2 = best_box |
|
|
|
|
|
bbox_w = x2 - x1 |
|
bbox_h = y2 - y1 |
|
|
|
crop_w = bbox_w * crop_factor |
|
crop_h = bbox_h * crop_factor |
|
|
|
crop_w = max(crop_h, crop_w) |
|
crop_h = crop_w |
|
|
|
kernel_x = int(x1 + bbox_w / 2) |
|
kernel_y = int(y1 + bbox_h / 2) |
|
|
|
new_x1 = int(kernel_x - crop_w / 2) |
|
new_x2 = int(kernel_x + crop_w / 2) |
|
new_y1 = int(kernel_y - crop_h / 2) |
|
new_y2 = int(kernel_y + crop_h / 2) |
|
|
|
if not sort: |
|
return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)] |
|
|
|
if new_x1 < 0: |
|
new_x2 -= new_x1 |
|
new_x1 = 0 |
|
elif w < new_x2: |
|
new_x1 -= (new_x2 - w) |
|
new_x2 = w |
|
if new_x1 < 0: |
|
new_x2 -= new_x1 |
|
new_x1 = 0 |
|
|
|
if new_y1 < 0: |
|
new_y2 -= new_y1 |
|
new_y1 = 0 |
|
elif h < new_y2: |
|
new_y1 -= (new_y2 - h) |
|
new_y2 = h |
|
if new_y1 < 0: |
|
new_y2 -= new_y1 |
|
new_y1 = 0 |
|
|
|
if w < new_x2 and h < new_y2: |
|
over_x = new_x2 - w |
|
over_y = new_y2 - h |
|
over_min = min(over_x, over_y) |
|
new_x2 -= over_min |
|
new_y2 -= over_min |
|
|
|
|
return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)] |
|
|
|
@staticmethod |
|
def retargeting(delta_out, driving_exp, factor, idxes): |
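        """
        Add `driving_exp`'s values, scaled by `factor`, onto `delta_out` for
        the selected keypoint indices of the 1x21x3 expression tensor (e.g.
        eye indices (11, 13, 15, 16) or mouth indices (14, 17, 19, 20)).
        """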
|
for idx in idxes: |
|
delta_out[0, idx] += driving_exp[0, idx] * factor |
|
|
|
@staticmethod |
|
def calc_face_region(square, dsize): |
|
region = copy.deepcopy(square) |
|
is_changed = False |
|
if dsize[0] < region[2]: |
|
region[2] = dsize[0] |
|
is_changed = True |
|
if dsize[1] < region[3]: |
|
region[3] = dsize[1] |
|
is_changed = True |
|
|
|
return region, is_changed |
|
|
|
@staticmethod |
|
def expand_img(rgb_img, square): |
|
crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1) |
|
        new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
                                 flags=cv2.INTER_LINEAR)
|
return new_img |
|
|
|
def prepare_src_image(self, img): |
|
if isinstance(img, str): |
|
img = image_path_to_array(img) |
|
|
|
if len(img.shape) <= 3: |
|
img = img[np.newaxis, ...] |
|
|
|
d, h, w, c = img.shape |
|
img = img[0] |
|
input_shape = [256, 256] |
|
if h != input_shape[0] or w != input_shape[1]: |
|
            interpolation = cv2.INTER_AREA if h > 256 else cv2.INTER_LINEAR
            x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation=interpolation)
|
else: |
|
x = img.copy() |
|
|
|
if x.ndim == 3: |
|
x = x[np.newaxis].astype(np.float32) / 255. |
|
elif x.ndim == 4: |
|
x = x.astype(np.float32) / 255. |
|
else: |
|
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') |
|
x = np.clip(x, 0, 1) |
|
x = torch.from_numpy(x).permute(0, 3, 1, 2) |
|
x = x.to(self.device) |
|
return x |
|
|
|
def get_mask_img(self): |
|
if self.mask_img is None: |
|
self.mask_img = cv2.imread(MASK_TEMPLATES, cv2.IMREAD_COLOR) |
|
return self.mask_img |
|
|
|
def crop_face(self, img_rgb, crop_factor): |
|
crop_region = self.detect_face(img_rgb, crop_factor) |
|
face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb)) |
|
face_img = rgb_crop(img_rgb, face_region) |
|
if is_changed: face_img = self.expand_img(face_img, crop_region) |
|
return face_img |
|
|
|
def prepare_source(self, source_image, crop_factor, is_video=False, tracking=False): |
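        """
        Detect and crop the face in `source_image` (path or array), then
        precompute everything the pipeline needs: keypoint info, 3D
        appearance features, transformed keypoints, the crop-to-fullsize
        affine matrix, and the blending mask. Returns one PreparedSrcImg for
        stills, or a list of them when `is_video` is True.
        """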
|
|
|
|
|
|
|
if isinstance(source_image, str): |
|
source_image = image_path_to_array(source_image) |
|
|
|
if len(source_image.shape) <= 3: |
|
source_image = source_image[np.newaxis, ...] |
|
|
|
psi_list = [] |
|
for img_rgb in source_image: |
|
if tracking or len(psi_list) == 0: |
|
crop_region = self.detect_face(img_rgb, crop_factor) |
|
face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb)) |
|
|
|
s_x = (face_region[2] - face_region[0]) / 512. |
|
s_y = (face_region[3] - face_region[1]) / 512. |
|
crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y) |
|
                mask_ori = cv2.warpAffine(self.get_mask_img(), crop_trans_m, get_rgb_size(img_rgb), flags=cv2.INTER_LINEAR)
|
mask_ori = mask_ori.astype(np.float32) / 255. |
|
|
|
if is_changed: |
|
s = (crop_region[2] - crop_region[0]) / 512. |
|
crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s) |
|
|
|
face_img = rgb_crop(img_rgb, face_region) |
|
if is_changed: face_img = self.expand_img(face_img, crop_region) |
|
i_s = self.prepare_src_image(face_img) |
|
x_s_info = self.pipeline.get_kp_info(i_s) |
|
f_s_user = self.pipeline.extract_feature_3d(i_s) |
|
x_s_user = self.pipeline.transform_keypoint(x_s_info) |
|
psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori) |
|
            if not is_video:
|
return psi |
|
psi_list.append(psi) |
|
|
|
return psi_list |
|
|
|
def prepare_driving_video(self, face_images): |
|
|
|
out_list = [] |
|
for f_img in face_images: |
|
i_d = self.prepare_src_image(f_img) |
|
d_info = self.pipeline.get_kp_info(i_d) |
|
out_list.append(d_info) |
|
|
|
return out_list |
|
|
|
@staticmethod |
|
def calc_fe(x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile, |
|
rotate_pitch, rotate_yaw, rotate_roll): |
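        """
        Apply hand-tuned per-keypoint offsets for each expression slider
        (smile, mouth, eee, woo, wink, pupil, eyes, eyebrow) directly to the
        keypoint tensor `x_d_new`, and fold the sliders that also affect head
        pose into the returned (pitch, yaw, roll) offset tensor.
        """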
|
|
|
x_d_new[0, 20, 1] += smile * -0.01 |
|
x_d_new[0, 14, 1] += smile * -0.02 |
|
x_d_new[0, 17, 1] += smile * 0.0065 |
|
x_d_new[0, 17, 2] += smile * 0.003 |
|
x_d_new[0, 13, 1] += smile * -0.00275 |
|
x_d_new[0, 16, 1] += smile * -0.00275 |
|
x_d_new[0, 3, 1] += smile * -0.0035 |
|
x_d_new[0, 7, 1] += smile * -0.0035 |
|
|
|
x_d_new[0, 19, 1] += mouth * 0.001 |
|
x_d_new[0, 19, 2] += mouth * 0.0001 |
|
x_d_new[0, 17, 1] += mouth * -0.0001 |
|
rotate_pitch -= mouth * 0.05 |
|
|
|
x_d_new[0, 20, 2] += eee * -0.001 |
|
x_d_new[0, 20, 1] += eee * -0.001 |
|
|
|
x_d_new[0, 14, 1] += eee * -0.001 |
|
|
|
x_d_new[0, 14, 1] += woo * 0.001 |
|
x_d_new[0, 3, 1] += woo * -0.0005 |
|
x_d_new[0, 7, 1] += woo * -0.0005 |
|
x_d_new[0, 17, 2] += woo * -0.0005 |
|
|
|
x_d_new[0, 11, 1] += wink * 0.001 |
|
x_d_new[0, 13, 1] += wink * -0.0003 |
|
x_d_new[0, 17, 0] += wink * 0.0003 |
|
x_d_new[0, 17, 1] += wink * 0.0003 |
|
x_d_new[0, 3, 1] += wink * -0.0003 |
|
rotate_roll -= wink * 0.1 |
|
rotate_yaw -= wink * 0.1 |
|
|
|
if 0 < pupil_x: |
|
x_d_new[0, 11, 0] += pupil_x * 0.0007 |
|
x_d_new[0, 15, 0] += pupil_x * 0.001 |
|
else: |
|
x_d_new[0, 11, 0] += pupil_x * 0.001 |
|
x_d_new[0, 15, 0] += pupil_x * 0.0007 |
|
|
|
x_d_new[0, 11, 1] += pupil_y * -0.001 |
|
x_d_new[0, 15, 1] += pupil_y * -0.001 |
|
eyes -= pupil_y / 2. |
|
|
|
x_d_new[0, 11, 1] += eyes * -0.001 |
|
x_d_new[0, 13, 1] += eyes * 0.0003 |
|
x_d_new[0, 15, 1] += eyes * -0.001 |
|
x_d_new[0, 16, 1] += eyes * 0.0003 |
|
x_d_new[0, 1, 1] += eyes * -0.00025 |
|
x_d_new[0, 2, 1] += eyes * 0.00025 |
|
|
|
if 0 < eyebrow: |
|
x_d_new[0, 1, 1] += eyebrow * 0.001 |
|
x_d_new[0, 2, 1] += eyebrow * -0.001 |
|
else: |
|
x_d_new[0, 1, 0] += eyebrow * -0.001 |
|
x_d_new[0, 2, 0] += eyebrow * 0.001 |
|
x_d_new[0, 1, 1] += eyebrow * 0.0003 |
|
x_d_new[0, 2, 1] += eyebrow * -0.0003 |
|
|
|
return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll]) |
|
|
|
|
|
class ExpressionSet: |
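    """
    Bundle of the four motion components LivePortrait manipulates: expression
    deltas `e` (1x21x3), rotation `r` (pitch, yaw, roll), scale `s`, and
    translation `t`, with element-wise arithmetic helpers.
    """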
|
def __init__(self, erst=None, es=None): |
|
if es is not None: |
|
self.e = copy.deepcopy(es.e) |
|
self.r = copy.deepcopy(es.r) |
|
self.s = copy.deepcopy(es.s) |
|
self.t = copy.deepcopy(es.t) |
|
elif erst is not None: |
|
self.e = erst[0] |
|
self.r = erst[1] |
|
self.s = erst[2] |
|
self.t = erst[3] |
|
else: |
|
self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(self.get_device()) |
|
self.r = torch.Tensor([0, 0, 0]) |
|
self.s = 0 |
|
self.t = 0 |
|
|
|
def div(self, value): |
|
self.e /= value |
|
self.r /= value |
|
self.s /= value |
|
self.t /= value |
|
|
|
def add(self, other): |
|
self.e += other.e |
|
self.r += other.r |
|
self.s += other.s |
|
self.t += other.t |
|
|
|
def sub(self, other): |
|
self.e -= other.e |
|
self.r -= other.r |
|
self.s -= other.s |
|
self.t -= other.t |
|
|
|
def mul(self, value): |
|
self.e *= value |
|
self.r *= value |
|
self.s *= value |
|
self.t *= value |
|
|
|
@staticmethod |
|
def get_device(): |
|
if torch.cuda.is_available(): |
|
return "cuda" |
|
elif torch.backends.mps.is_available(): |
|
return "mps" |
|
else: |
|
return "cpu" |
|
|
|
|
|
def logging_time(original_fn): |
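    """Decorator that prints the wall-clock time of the wrapped call."""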
|
def wrapper_fn(*args, **kwargs): |
|
start_time = time.time() |
|
result = original_fn(*args, **kwargs) |
|
end_time = time.time() |
|
print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time)) |
|
return result |
|
|
|
return wrapper_fn |
|
|
|
|
|
def save_exp_data(file_name: str, save_exp: Optional[ExpressionSet] = None) -> str:
|
if save_exp is None or not file_name: |
|
return file_name |
|
|
|
with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), "wb") as f: |
|
dill.dump(save_exp, f) |
|
|
|
return file_name |
|
|
|
|
|
def load_exp_data(file_name, ratio):
|
with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), 'rb') as f: |
|
es = dill.load(f) |
|
es.mul(ratio) |
|
return es |
|
|
|
|
|
def handle_exp_data(code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None): |
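    """
    Apply up to five manual keypoint tweaks on top of `add_exp` (or a fresh
    neutral set). Each `code` packs a keypoint index and axis as
    `index * 10 + axis`; the paired `value` is scaled by 0.001 and added.
    """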
|
if add_exp is None: |
|
es = ExpressionSet() |
|
else: |
|
es = ExpressionSet(es=add_exp) |
|
|
|
codes = [code1, code2, code3, code4, code5] |
|
values = [value1, value2, value3, value4, value5] |
|
    for code, value in zip(codes, values):
        idx = code // 10
        r = code % 10
        es.e[0, idx, r] += value * 0.001
|
|
|
return es |
|
|
|
|
|
def print_exp_data(cut_noise, exp=None): |
|
if exp is None: |
|
return exp |
|
|
|
    cut_list = []
    e = exp.e * 1000
    for idx in range(21):
        for r in range(3):
            a = abs(e[0, idx, r])
            if cut_noise < a:
                cut_list.append((a, e[0, idx, r], idx * 10 + r))

    sorted_list = sorted(cut_list, reverse=True, key=lambda item: item[0])
|
print(f"sorted_list: {[[item[2], round(float(item[1]), 1)] for item in sorted_list]}") |
|
return exp |
|
|
|
|
|
class Command: |
|
def __init__(self, |
|
es: ExpressionSet, |
|
change, |
|
keep): |
|
self.es = es |
|
self.change = change |
|
self.keep = keep |
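

# Minimal usage sketch (hedged: "portrait.png" is a hypothetical placeholder
# path, not a file shipped with the repo): load the human models once, then
# edit a still image's expression.
#
#     inferencer = LivePortraitInferencer()
#     inferencer.load_models(model_type=ModelType.HUMAN.value)
#     edited = inferencer.edit_expression(
#         model_type=ModelType.HUMAN.value,
#         src_image="portrait.png",
#         smile=1.5,
#         rotate_yaw=10,
#     )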
|
|
|
|