"""Gradio demo for face swapping on images and videos.

Loads one face-swap generator at import time (selected by ``fs_model_name``)
and, when run as a script, serves an image tab and a video tab through Gradio.
"""
import os
import uuid
import glob
import shutil
import subprocess
from pathlib import Path
from multiprocessing.pool import Pool

import gradio as gr
import torch
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import tqdm

from modules.networks.faceshifter import FSGenerator
from inference.alignment import norm_crop, norm_crop_with_M, paste_back
from inference.utils import save, get_5_from_98, get_detector, get_lmk
from third_party.PIPNet.lib.tools import get_lmk_model, demo_image
from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark
from inference.tricks import Trick


def make_abs_path(fn):
    """Resolve ``fn`` relative to the directory that contains this file."""
    return os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))


fs_model_name = 'faceshifter'
in_size = 256
# Defined at import time (not only under __main__) so that swap_image_gr /
# swap_video_gr also work when this module is imported instead of executed.
use_gpu = torch.cuda.is_available()

mouth_net_param = {
    "use": True,
    "feature_dim": 128,
    "crop_param": (28, 56, 84, 112),
    "weight_path": make_abs_path("./weights/arcface/mouth_net_28_56_84_112.pth"),
}

trick = Trick()

# HWC uint8 image -> CHW float tensor normalised to [-1, 1].
T = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)
tensor2pil_transform = transforms.ToPILImage()


def _run_ffmpeg(*args):
    """Run ``ffmpeg`` with an explicit argument list.

    No shell is involved, so paths containing spaces or shell metacharacters
    (e.g. user-uploaded file names) are passed through verbatim.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    subprocess.run(["ffmpeg", *args], check=True)


def extract_generator(ckpt: str, pt: str):
    """Extract the generator weights from a Lightning checkpoint.

    Args:
        ckpt: path of the training checkpoint (must contain ``"state_dict"``).
        pt: output path for the generator's ``state_dict``.

    Raises:
        ValueError: if the module-level ``in_size`` is neither 256 nor 512.
    """
    print(f'[extract_generator] loading ckpt...')
    from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
    import yaml

    # NOTE: yaml.FullLoader on a repository-local config file, not user input.
    with open(make_abs_path('../../trainer/faceshifter/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    if in_size == 256:
        net = FaceshifterPL(n_layers=3, num_D=3, config=config)
    elif in_size == 512:
        net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)
    else:
        raise ValueError('Not supported in_size.')
    checkpoint = torch.load(ckpt, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()

    G = net.generator
    torch.save(G.state_dict(), pt)
    print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')


# ---------------------------------------------------------------------------
# Load model. Each branch defines ``fs_model``, ``pt_path`` and
# ``infer_batch_to_img(i_s, i_t, post=False)`` returning an image array.
# ---------------------------------------------------------------------------
if fs_model_name == 'faceshifter':
    pt_path = make_abs_path("./weights/extracted/G_mouth1_t38_post.pth")
    # Training checkpoint used to (re-)extract pt_path; set it together with pt_path.
    ckpt_path = None
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
    if not os.path.exists(pt_path) or 't512' in pt_path:
        if ckpt_path is None:
            raise FileNotFoundError(
                f'Generator weights not found at {pt_path} and no ckpt_path is configured to extract them from.'
            )
        extract_generator(ckpt_path, pt_path)

    fs_model = FSGenerator(
        make_abs_path("./weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
        mouth_net_param=mouth_net_param,
        in_size=in_size,
        downup=in_size == 512,
    )
    fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
    fs_model.eval()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap the identity of ``i_s`` onto ``i_t`` and return the first result as an array.

        Args:
            i_s: aligned source tensor batch, normalised by ``T``.
            i_t: aligned target tensor batch, normalised by ``T``.
            post: if True, blend back target regions from parsing classes
                0 and 17 (named "hair mask" here) and, at in_size 256, apply
                ``trick.finetune_mouth``.
        """
        i_r = fs_model(i_s, i_t)[0]  # x, id_vector, att

        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
            i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r

        img_r = trick.tensor_to_arr(i_r)[0]
        return img_r

elif fs_model_name == 'simswap_triplet' or fs_model_name == 'simswap_vanilla':
    from modules.networks.simswap import Generator_Adain_Upsample

    sw_model = Generator_Adain_Upsample(
        input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
        mouth_net_param=mouth_net_param,
    )
    if fs_model_name == 'simswap_triplet':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_triplet_5/epoch=12-step=782999.ckpt")
    elif fs_model_name == 'simswap_vanilla':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
    else:
        pt_path = None
        ckpt_path = None
    sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
    sw_model.eval()
    fs_model = sw_model

    # The ArcFace and mouth networks live in the training module, so load it too.
    from trainer.simswap.simswap_pl import SimSwapPL
    import yaml

    with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    net = SimSwapPL(config=config, use_official_arc='off' in pt_path)

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()

    sw_mouth_net = net.mouth_net  # maybe None
    sw_netArc = net.netArc
    if use_gpu:
        fs_model = fs_model.cuda()
        sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
        sw_netArc = sw_netArc.cuda()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap ``i_s`` onto ``i_t`` with SimSwap; optionally blend back the "hair mask" region."""
        i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net)
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
        i_r = i_r.clamp(-1, 1)
        i_r = trick.tensor_to_arr(i_r)[0]
        return i_r

elif fs_model_name == 'simswap_official':
    from simswap.image_infer import SimSwapOfficialImageInfer

    fs_model = SimSwapOfficialImageInfer()
    pt_path = 'Simswap Official'
    mouth_net_param = {
        "use": False
    }

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap with the official SimSwap pipeline.

        ``post`` is accepted for signature compatibility with the other
        backends (swap_image always passes it) and is ignored here. The result
        is converted with ``trick.tensor_to_arr`` like the other backends, so
        callers receive an image array rather than a tensor.
        """
        i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
        i_r = i_r.clamp(-1, 1)
        return trick.tensor_to_arr(i_r)[0]

else:
    raise ValueError('Not supported fs_model_name.')

print(f'[demo] model loaded from {pt_path}')


def swap_image(
        source_image,
        target_path,
        out_path,
        transform,
        G,
        align_source="arcface",
        align_target="set1",
        gpu_mode=True,
        paste_back=True,
        use_post=False,
        use_gpen=False,
        in_size=256,
):
    """Swap the face in ``source_image`` onto the face in ``target_path``.

    Writes ``out_<name>`` (the swapped crop) and ``paste_back_out_<name>``
    (via ``inference.utils.save``) into ``out_path``. Debug crops
    (cropped_source.png, cropped_target.png, result.png) are written to the
    current working directory.

    Args:
        source_image: path of the identity (source) image.
        target_path: path of the image whose face is replaced.
        out_path: output directory (created if missing).
        transform: image -> normalised tensor transform (e.g. ``T``).
        G: generator; moved to eval mode / GPU when it is an ``nn.Module``.
            Inference itself goes through the module-level ``infer_batch_to_img``.
        align_source, align_target: alignment modes for ``norm_crop``.
        gpu_mode: move tensors (and ``G``) to CUDA.
        paste_back: unused; kept for interface compatibility (it shadows the
            imported ``paste_back`` function inside this function only).
        use_post: forwarded to ``infer_batch_to_img`` as ``post``.
        use_gpen: forwarded to ``save`` as ``use_post``.
        in_size: crop size used for alignment.
    """
    name = "out_" + os.path.basename(target_path)
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    source_img = np.array(Image.open(source_image).convert("RGB"))
    net, detector = get_lmk_model()
    lmk = get_5_from_98(demo_image(source_img, net, detector)[0])
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)

    target = np.array(Image.open(target_path).convert("RGB"))
    original_target = target.copy()
    lmk = get_5_from_98(demo_image(target, net, detector)[0])
    target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
    target = transform(target).unsqueeze(0)
    if gpu_mode:
        target = target.cuda()
        source_img = source_img.cuda()

    # Debug dumps (RGB -> BGR for OpenCV); concurrent requests overwrite these files.
    cv2.imwrite('cropped_source.png', trick.tensor_to_arr(source_img)[0, :, :, ::-1])
    cv2.imwrite('cropped_target.png', trick.tensor_to_arr(target)[0, :, :, ::-1])

    # Both inputs are aligned crops of size in_size x in_size.
    result = infer_batch_to_img(source_img, target, post=use_post)

    cv2.imwrite('result.png', result[:, :, ::-1])

    os.makedirs(out_path, exist_ok=True)
    Image.fromarray(result.astype(np.uint8)).save(os.path.join(out_path, name))
    save((result, M, original_target, os.path.join(out_path, "paste_back_" + name), None),
         trick=trick, use_post=use_gpen)


def process_video(
        source_image,
        target_path,
        out_path,
        transform,
        G,
        align_source="arcface",
        align_target="set1",
        gpu_mode=True,
        frames=9999999,
        use_tddfav2=False,
        landmark_smooth="kalman",
):
    """Swap the source face onto every frame of a video (or a directory of .png frames).

    A video file is first decoded by ffmpeg into ``./tmp/frame_%05d.png``.
    Target landmarks are smoothed over time before cropping.

    Args:
        source_image: path of the identity (source) image.
        target_path: video file, or a directory that already holds .png frames.
        out_path: directory for the later pasted-back frames (created if missing).
        transform: image -> normalised tensor transform (e.g. ``T``).
        G: generator, used directly for the faceshifter backend.
        align_source, align_target: alignment modes for ``norm_crop``.
        gpu_mode: move tensors (and ``G``) to CUDA.
        frames: maximum number of frames to process.
        use_tddfav2: use ``get_detector`` instead of the PIPNet landmark model.
        landmark_smooth: one of ``'kalman'``, ``'savgol'`` or ``'cancel'``.

    Returns:
        ``(targets, t_facial_masks, Ms, original_frames, names, fps)`` where
        ``targets`` are swapped crops as uint8 arrays, ``Ms`` the alignment
        matrices, ``names`` the output frame paths inside ``out_path``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails to decode the video.
        RuntimeError: if no frames are found.
        KeyError: for an unsupported ``landmark_smooth`` value.
    """
    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    # Target video to frames (.png).
    fps = 25.0
    if not os.path.isdir(target_path):
        vidcap = cv2.VideoCapture(target_path)
        try:
            video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        finally:
            vidcap.release()
        # OpenCV reports 0 when it cannot read the fps; keep the default then.
        if video_fps and video_fps > 0:
            fps = video_fps
        try:
            for match in glob.glob(os.path.join("./tmp/", "*.png")):
                os.remove(match)
            for match in glob.glob(os.path.join(out_path, "*.png")):
                os.remove(match)
        except OSError as e:
            # Best-effort cleanup of stale frames from a previous run.
            print(e)
        os.makedirs("./tmp/", exist_ok=True)
        _run_ffmpeg(
            "-i", target_path, "-qscale:v", "1", "-qmin", "1", "-qmax", "1",
            "-vsync", "0", "./tmp/frame_%05d.png",
        )
        target_path = "./tmp/"
    globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))
    if not globbed_images:
        raise RuntimeError(f'No .png frames found in {target_path}')

    # Get target landmarks.
    print('[Extracting target landmarks...]')
    if not use_tddfav2:
        align_net, align_detector = get_lmk_model()
    else:
        align_net, align_detector = get_detector(gpu_mode=gpu_mode)
    target_lmks = []
    for frame_path in tqdm.tqdm(globbed_images):
        target = np.array(Image.open(frame_path).convert("RGB"))
        lmk = demo_image(target, align_net, align_detector)
        target_lmks.append(lmk[0])

    # Landmark smoothing across frames. ``int`` replaces the removed ``np.int`` alias.
    target_lmks = np.array(target_lmks, np.float32)  # (#frames, 98, 2)
    if landmark_smooth == 'kalman':
        target_lmks = kalman_filter_landmark(target_lmks, process_noise=0.01, measure_noise=0.01).astype(int)
    elif landmark_smooth == 'savgol':
        target_lmks = savgol_filter_landmark(target_lmks).astype(int)
    elif landmark_smooth == 'cancel':
        target_lmks = target_lmks.astype(int)
    else:
        raise KeyError('Not supported landmark_smooth choice')

    # Crop source image.
    source_img = np.array(Image.open(source_image).convert("RGB"))
    if not use_tddfav2:
        lmk = get_5_from_98(demo_image(source_img, align_net, align_detector)[0])
    else:
        lmk = get_lmk(source_img, align_net, align_detector)
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)
    if gpu_mode:
        source_img = source_img.cuda()

    # Process frame by frame.
    targets = []
    t_facial_masks = []
    Ms = []
    original_frames = []
    names = []
    for idx, image in enumerate(tqdm.tqdm(globbed_images)):
        names.append(os.path.join(out_path, Path(image).name))
        target = np.array(Image.open(image).convert("RGB"))
        original_frames.append(target)

        # Crop target frame with the smoothed landmarks.
        lmk = get_5_from_98(target_lmks[idx])
        target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
        target = transform(target).unsqueeze(0)  # in [-1,1]
        if gpu_mode:
            target = target.cuda()

        # Finetune paste masks.
        target_facial_mask = trick.get_any_mask(target, par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze()  # in [0,1]
        target_facial_mask = target_facial_mask.cpu().numpy().astype(np.float64)
        target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks)  # in [0,1]
        t_facial_masks.append(target_facial_mask)

        # Face swapping.
        with torch.no_grad():
            if 'faceshifter' in fs_model_name:
                output = G(source_img, target)
                if isinstance(output, tuple):  # generator returns (x, id_vector, att)
                    output = output[0]
                target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                target_hair_mask = trick.smooth_mask(target_hair_mask)
                output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                # Same in_size gate as the faceshifter infer_batch_to_img.
                if in_size == 256:
                    output = trick.finetune_mouth(source_img, target, output)
            elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
                output = fs_model(source=source_img, target=target, net_arc=sw_netArc, mouth_net=sw_mouth_net)
                if 'vanilla' not in fs_model_name:
                    target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                    target_hair_mask = trick.smooth_mask(target_hair_mask)
                    output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                    output = trick.finetune_mouth(source_img, target, output)
                output = output.clamp(-1, 1)
            elif 'simswap_official' in fs_model_name:
                output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
                output = output.clamp(-1, 1)

        if isinstance(output, tuple):
            output = output[0]
        # [-1, 1] -> [0, 1]; clamp so ToPILImage's uint8 conversion cannot wrap around.
        swapped = (output[0] * 0.5 + 0.5).clamp(0, 1).cpu()
        targets.append(np.array(tensor2pil_transform(swapped)))
        Ms.append(M)
        if idx + 1 >= frames:
            break
    os.makedirs(out_path, exist_ok=True)
    return targets, t_facial_masks, Ms, original_frames, names, fps


def swap_image_gr(img1, img2, use_post=False, use_gpen=False):
    """Gradio handler: swap the face of RGB array ``img1`` onto ``img2``.

    Inputs are written into a fresh per-request directory under
    ``./online_data``; returns the pasted-back result as an RGB array.
    """
    if img1 is None or img2 is None:
        raise ValueError('Both a source and a target image are required.')
    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid4().hex  # uuid4: random, does not embed the host MAC like uuid1
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    target_path = os.path.join(data_dir, "target.png")
    filename = "paste_back_out_target.png"
    out_path = os.path.join(data_dir, filename)

    # Gradio gives RGB; OpenCV writes BGR.
    cv2.imwrite(source_path, img1[:, :, ::-1])
    cv2.imwrite(target_path, img2[:, :, ::-1])
    swap_image(
        source_path,
        target_path,
        data_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        align_target='ffhq',
        align_source='ffhq',
        use_post=use_post,
        use_gpen=use_gpen,
        in_size=in_size,
    )
    out = cv2.imread(out_path)
    if out is None:
        raise RuntimeError(f'Expected result image was not written: {out_path}')
    return out[..., ::-1]


def swap_video_gr(img1, target_path, frames=9999999):
    """Gradio handler: swap the face of RGB array ``img1`` into the video at ``target_path``.

    Returns the path of the encoded ``output.mp4`` (audio is copied from the
    input video when present).
    """
    if img1 is None or not target_path:
        raise ValueError('Both a source image and a target video are required.')
    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid4().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    cv2.imwrite(source_path, img1[:, :, ::-1])
    out_dir = os.path.join(data_dir, "out")
    out_name = "output.mp4"
    targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
        source_path,
        target_path,
        out_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        frames=frames,
        align_target='ffhq',
        align_source='ffhq',
        use_tddfav2=False,
    )

    # NOTE(review): 170 workers is far above typical core counts; left unchanged
    # because pool_process <= 1 switches to the paste_back path, which is not the
    # same code path as save().
    pool_process = 170
    audio = True
    concat = False

    if pool_process <= 1:
        for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
                zip(targets, Ms, original_frames, names, t_facial_masks)
        ):
            if M is None or target is None:
                Image.fromarray(original_target.astype(np.uint8)).save(name)
                continue
            Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
    else:
        with Pool(pool_process) as pool:
            pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))

    video_save_path = os.path.join(out_dir, out_name)
    frame_pattern = os.path.join(out_dir, "frame_%05d.png")
    fps_arg = str(fps)
    if audio:
        print("use audio")
        # "1:a:0?" makes the audio stream optional, so silent inputs still work.
        _run_ffmpeg(
            "-y", "-r", fps_arg, "-i", frame_pattern, "-i", target_path,
            "-map", "0:v:0", "-map", "1:a:0?", "-c:a", "copy",
            "-c:v", "libx264", "-r", fps_arg, "-crf", "10", "-pix_fmt", "yuv420p",
            video_save_path,
        )
    else:
        print("no audio")
        # Encode the swapped frames in out_dir (./tmp holds the unswapped originals).
        _run_ffmpeg(
            "-y", "-r", fps_arg, "-i", frame_pattern,
            "-c:v", "libx264", "-r", fps_arg, "-crf", "10", "-pix_fmt", "yuv420p",
            video_save_path,
        )

    if concat:
        # Side-by-side comparison: original | swapped.
        concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
        _run_ffmpeg(
            "-y", "-i", target_path, "-i", video_save_path,
            "-filter_complex", "hstack", concat_video_save_path,
        )

    # Delete temporary frames; ./tmp does not exist when the input was a frame directory.
    shutil.rmtree("./tmp/", ignore_errors=True)
    for match in glob.glob(os.path.join(out_dir, "*.png")):
        os.remove(match)

    print(video_save_path)
    return video_save_path


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("SuperSwap")
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=3):
                    image1_input = gr.Image(label='source')
                    image2_input = gr.Image(label='target')
                    use_post = gr.Checkbox(label="Post-Process")
                    use_gpen = gr.Checkbox(label="Super Resolution")
                with gr.Column(scale=2):
                    image_output = gr.Image()
                    image_button = gr.Button("换脸")  # "swap face"
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=3):
                    image3_input = gr.Image(label='source')
                    video_input = gr.Video(label='target')
                with gr.Column(scale=2):
                    video_output = gr.Video()
                    video_button = gr.Button("换脸")  # "swap face"

        image_button.click(
            swap_image_gr,
            inputs=[image1_input, image2_input, use_post, use_gpen],
            outputs=image_output,
        )
        video_button.click(
            swap_video_gr,
            inputs=[image3_input, video_input],
            outputs=video_output,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)