import os

import numpy as np
import cv2
from moviepy.editor import VideoFileClip

from .face_det import FaceAnalysis
from .super_resolution import BSRGAN
from dofaker.face_swap import get_swapper_model
from dofaker.face_enhance import GFPGAN
class FaceSwapper:

    def __init__(self,
                 face_det_model='buffalo_l',
                 face_swap_model='inswapper',
                 image_sr_model='bsrgan',
                 face_enhance_model='gfpgan',
                 face_det_model_dir='weights/models',
                 face_swap_model_dir='weights/models',
                 image_sr_model_dir='weights/models',
                 face_enhance_model_dir='weights/models',
                 face_sim_thre=0.5,
                 log_iters=10,
                 use_enhancer=True,
                 use_sr=True,
                 scale=1):
        self.face_sim_thre = face_sim_thre
        self.log_iters = log_iters
        # Face detection and embedding model.
        self.det_model = FaceAnalysis(name=face_det_model,
                                      root=face_det_model_dir)
        self.det_model.prepare(ctx_id=1, det_size=(640, 640))
        # Face swapping model.
        self.swapper_model = get_swapper_model(name=face_swap_model,
                                               root=face_swap_model_dir)
        # Optional face enhancement model.
        if use_enhancer:
            self.face_enhance = GFPGAN(name=face_enhance_model,
                                       root=face_enhance_model_dir)
        else:
            self.face_enhance = None
        # Optional whole-image super-resolution. Without it the frames keep
        # their original size, so the output scale must stay 1.
        if use_sr:
            self.sr = BSRGAN(name=image_sr_model,
                             root=image_sr_model_dir,
                             scale=scale)
            self.scale = scale
        else:
            self.sr = None
            self.scale = 1
    def run(self,
            input_path: str,
            dst_face_paths,
            src_face_paths,
            output_dir='output'):
        if isinstance(dst_face_paths, str):
            dst_face_paths = [dst_face_paths]
        if isinstance(src_face_paths, str):
            src_face_paths = [src_face_paths]
        # Dispatch on the input extension: images are swapped directly,
        # everything else is treated as a video.
        if input_path.lower().endswith(('jpg', 'jpeg', 'webp', 'png', 'bmp')):
            return self.swap_image(input_path, dst_face_paths, src_face_paths,
                                   output_dir)
        else:
            return self.swap_video(input_path, dst_face_paths, src_face_paths,
                                   output_dir)
    def swap_video(self,
                   input_video_path,
                   dst_face_paths,
                   src_face_paths,
                   output_dir='output'):
        assert os.path.exists(input_video_path), \
            'The input video path {} does not exist.'.format(input_video_path)
        os.makedirs(output_dir, exist_ok=True)
        src_faces = self.get_faces(src_face_paths)
        if dst_face_paths is not None:
            dst_faces = self.get_faces(dst_face_paths)
            dst_face_embeddings = self.get_faces_embeddings(dst_faces)
            assert len(dst_faces) == len(src_faces), \
                'The number of faces detected in the source images must equal the number of target faces.'

        video = cv2.VideoCapture(input_video_path)
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_size = (width, height)
        print('video fps: {}, total_frames: {}, width: {}, height: {}'.format(
            fps, total_frames, width, height))

        video_name = os.path.basename(input_video_path).split('.')[0]
        four_cc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        temp_video_path = os.path.join(output_dir,
                                       'temp_{}.mp4'.format(video_name))
        save_video_path = os.path.join(output_dir, '{}.mp4'.format(video_name))
        # Swapped frames go to a temporary video first; the source audio is
        # merged back in afterwards.
        output_video = cv2.VideoWriter(
            temp_video_path, four_cc, fps,
            (int(frame_size[0] * self.scale), int(frame_size[1] * self.scale)))

        i = 0
        while video.isOpened():
            ret, frame = video.read()
            if ret:
                if dst_face_paths is not None:
                    swapped_image = self.swap_faces(frame,
                                                    dst_face_embeddings,
                                                    src_faces=src_faces)
                else:
                    swapped_image = self.swap_all_faces(frame,
                                                        src_faces=src_faces)
                i += 1
                if i % self.log_iters == 0:
                    print('processing {}/{}'.format(i, total_frames))
                output_video.write(swapped_image)
            else:
                break
        video.release()
        output_video.release()

        self.add_audio_to_video(input_video_path, temp_video_path,
                                save_video_path)
        os.remove(temp_video_path)
        return save_video_path
    def swap_image(self,
                   image_path,
                   dst_face_paths,
                   src_face_paths,
                   output_dir='output'):
        os.makedirs(output_dir, exist_ok=True)
        src_faces = self.get_faces(src_face_paths)
        if dst_face_paths is not None:
            dst_faces = self.get_faces(dst_face_paths)
            dst_face_embeddings = self.get_faces_embeddings(dst_faces)
            assert len(dst_faces) == len(src_faces), \
                'The number of faces detected in the source images must equal the number of target faces.'
        image = cv2.imread(image_path)
        if dst_face_paths is not None:
            swapped_image = self.swap_faces(image,
                                            dst_face_embeddings,
                                            src_faces=src_faces)
        else:
            swapped_image = self.swap_all_faces(image, src_faces=src_faces)
        base_name = os.path.basename(image_path)
        save_path = os.path.join(output_dir, base_name)
        cv2.imwrite(save_path, swapped_image)
        return save_path
    def add_audio_to_video(self, src_video_path, target_video_path,
                           save_video_path):
        # Copy the audio track of the original video onto the swapped video.
        audio = VideoFileClip(src_video_path).audio
        target_video = VideoFileClip(target_video_path)
        target_video = target_video.set_audio(audio)
        target_video.write_videofile(save_video_path)
        return save_video_path
    def get_faces(self, image_paths):
        if isinstance(image_paths, str):
            image_paths = [image_paths]
        faces = []
        for image_path in image_paths:
            image = cv2.imread(image_path)
            assert image is not None, \
                'The source image is None, please check the format of image {}.'.format(image_path)
            img_faces = self.det_model.get(image, max_num=1)
            assert len(img_faces) == 1, \
                'Expected exactly 1 detected face in image {}, but got {}; please make sure the image contains a single face.'.format(
                    image_path, len(img_faces))
            faces += img_faces
        return faces
    def swap_faces(self, image, dst_face_embeddings: np.ndarray,
                   src_faces: list) -> np.ndarray:
        res = image.copy()
        image_faces = self.det_model.get(image)
        if len(image_faces) == 0:
            return res
        image_face_embeddings = self.get_faces_embeddings(image_faces)
        # Cosine similarity between every target face and every face detected
        # in the current frame (the embeddings are already L2-normalized).
        sim = np.dot(dst_face_embeddings, image_face_embeddings.T)
        for i in range(dst_face_embeddings.shape[0]):
            index = np.where(sim[i] > self.face_sim_thre)[0].tolist()
            for idx in index:
                res = self.swapper_model.get(res,
                                             image_faces[idx],
                                             src_faces[i],
                                             paste_back=True)
                if self.face_enhance is not None:
                    res = self.face_enhance.get(res,
                                                image_faces[idx],
                                                paste_back=True)
        if self.sr is not None:
            res = self.sr.get(res, image_format='bgr')
        return res
    def swap_all_faces(self, image, src_faces: list) -> np.ndarray:
        assert len(src_faces) == 1, \
            'When replacing all faces in the target, the number of source faces must be 1, but got {}.'.format(
                len(src_faces))
        res = image.copy()
        image_faces = self.det_model.get(image)
        if len(image_faces) == 0:
            return res
        for image_face in image_faces:
            res = self.swapper_model.get(res,
                                         image_face,
                                         src_faces[0],
                                         paste_back=True)
            if self.face_enhance is not None:
                res = self.face_enhance.get(res, image_face, paste_back=True)
        if self.sr is not None:
            res = self.sr.get(res, image_format='bgr')
        return res
    def get_faces_embeddings(self, faces):
        feats = []
        for face in faces:
            feats.append(face.normed_embedding)
        if len(feats) == 1:
            feats = np.array(feats, dtype=np.float32).reshape(1, -1)
        else:
            feats = np.array(feats, dtype=np.float32)
        return feats