"""Run `bash scripts/download_models.sh` first to prepare the weight files."""
import os
import shutil
from argparse import Namespace

from cog import BasePredictor, Input, Path

from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.test_audio2coeff import Audio2Coeff
from src.utils.preprocess import CropAndExtract
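
# Directory that scripts/download_models.sh populates with the pretrained weights.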
checkpoints = "checkpoints"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"
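
        # Models for cropping the source and extracting 3DMM coefficients: the
        # dlib 68-point landmark detector, the face reconstruction network, and
        # the Basel Face Model (BFM) fitting data.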
        path_of_lm_croper = os.path.join(
            checkpoints, "shape_predictor_68_face_landmarks.dat"
        )
        path_of_net_recon_model = os.path.join(checkpoints, "epoch_20.pth")
        dir_of_BFM_fitting = os.path.join(checkpoints, "BFM_Fitting")
        wav2lip_checkpoint = os.path.join(checkpoints, "wav2lip.pth")
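
        # Checkpoints for the audio-to-coefficient networks and the face renderer.
        # The "auido" spelling is intentional: it matches the filenames the
        # released weights and configs ship with.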
        audio2pose_checkpoint = os.path.join(checkpoints, "auido2pose_00140-model.pth")
        audio2pose_yaml_path = os.path.join("src", "config", "auido2pose.yaml")

        audio2exp_checkpoint = os.path.join(checkpoints, "auido2exp_00300-model.pth")
        audio2exp_yaml_path = os.path.join("src", "config", "auido2exp.yaml")

        free_view_checkpoint = os.path.join(
            checkpoints, "facevid2vid_00189-model.pth.tar"
        )
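
        # Crops the source and extracts its 3DMM coefficients.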
        self.preprocess_model = CropAndExtract(
            path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device
        )
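
        # Maps the driven audio to head-pose and expression coefficients.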
        self.audio_to_coeff = Audio2Coeff(
            audio2pose_checkpoint,
            audio2pose_yaml_path,
            audio2exp_checkpoint,
            audio2exp_yaml_path,
            wav2lip_checkpoint,
            device,
        )
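
        # Two renderers: "full" pairs the free-view generator with the still-mode
        # config for full-frame output; "others" uses the standard config for the
        # crop and resize modes.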
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00109-model.pth.tar"),
                os.path.join("src", "config", "facerender_still.yaml"),
                device,
            ),
            "others": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00229-model.pth.tar"),
                os.path.join("src", "config", "facerender.yaml"),
                device,
            ),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image; it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio; accepts .wav and .mp4 files",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="How to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="Path to a reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="Path to a reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="Crop back to the original video for full-body animation; applies when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model"""
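
        # Pick the renderer that matches the chosen preprocessing mode.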
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )
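
        # Start from the script defaults, then overlay this request's inputs.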
        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = "cuda"
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)
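
        # Start each prediction from a clean results directory.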
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)
print("3DMM Extraction for source image")
|
|
first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
|
|
args.pic_path, first_frame_dir, preprocess, source_image_flag=True
|
|
)
|
|
if first_coeff_path is None:
|
|
print("Can't get the coeffs of the input")
|
|
return
|
|
|
|
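
        # Optionally extract coefficients from a reference video that drives eye blinking.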
        if ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(
                os.path.split(ref_eyeblink)[-1]
            )[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None
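
        # Likewise for a reference video that drives head pose; reuse the
        # eye-blink coefficients when both references are the same video.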
        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None
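
        # audio2coeff: predict per-frame motion coefficients from the audio.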
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )
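
        # Render the talking-head video from the predicted coefficients and
        # enhance the face.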
print("coeff2video")
|
|
data = get_facerender_data(
|
|
coeff_path,
|
|
crop_pic_path,
|
|
first_coeff_path,
|
|
args.audio_path,
|
|
args.batch_size,
|
|
args.input_yaw,
|
|
args.input_pitch,
|
|
args.input_roll,
|
|
expression_scale=args.expression_scale,
|
|
still_mode=still,
|
|
preprocess=preprocess,
|
|
)
|
|
animate_from_coeff.generate(
|
|
data, results_dir, args.pic_path, crop_info,
|
|
enhancer=enhancer, background_enhancer=args.background_enhancer,
|
|
preprocess=preprocess)
|
|
|
|
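
        # The renderer writes the enhanced video (its name ends in "enhanced.mp4")
        # into results_dir; copy it to a stable path for Cog to return.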
        output = "/tmp/out.mp4"
        mp4_path = os.path.join(
            results_dir,
            [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0],
        )
        shutil.copy(mp4_path, output)

        return Path(output)
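

# Options the predictor does not expose as Inputs, presumably mirroring the
# original inference script's argparse defaults.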
def load_default():
    return Namespace(
        pose_style=0,
        batch_size=2,
        expression_scale=1.0,
        input_yaw=None,
        input_pitch=None,
        input_roll=None,
        background_enhancer=None,
        face3dvis=False,
        net_recon="resnet50",
        init_path=None,
        use_last_fc=False,
        bfm_folder="./checkpoints/BFM_Fitting/",
        bfm_model="BFM_model_front.mat",
        focal=1015.0,
        center=112.0,
        camera_d=10.0,
        z_near=5.0,
        z_far=15.0,
    )
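

# A minimal sketch of a local invocation, assuming the Cog CLI is installed and
# the weights have been downloaded (the example file paths below are hypothetical):
#
#   cog predict \
#       -i source_image=@examples/source.png \
#       -i driven_audio=@examples/audio.wav \
#       -i preprocess=full \
#       -i still=true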