Spaces:
Running
Running
import os | |
import uuid | |
import glob | |
import shutil | |
from pathlib import Path | |
from multiprocessing.pool import Pool | |
import gradio as gr | |
import torch | |
from torchvision import transforms | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import tqdm | |
from modules.networks.faceshifter import FSGenerator | |
from inference.alignment import norm_crop, norm_crop_with_M, paste_back | |
from inference.utils import save, get_5_from_98, get_detector, get_lmk | |
from third_party.PIPNet.lib.tools import get_lmk_model, demo_image | |
from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark | |
from inference.tricks import Trick | |
def make_abs_path(fn):
    """Return the absolute path of ``fn`` resolved against this file's directory."""
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(os.path.join(here, fn))


# Name of the generator backend; the model-loading code branches on it.
fs_model_name = 'faceshifter'
# Side length passed to the face-alignment helpers and to the generator.
in_size = 256

# Settings forwarded to the generator / trainer as ``mouth_net_param``.
mouth_net_param = dict(
    use=True,
    feature_dim=128,
    crop_param=(28, 56, 84, 112),
    weight_path=make_abs_path("./weights/arcface/mouth_net_28_56_84_112.pth"),
)

trick = Trick()

# Image -> tensor, then Normalize(0.5, 0.5), i.e. (x - 0.5) / 0.5.
T = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
# Inverse direction used when collecting swapped video frames.
tensor2pil_transform = transforms.ToPILImage()
def extract_generator(ckpt: str, pt: str):
    """Extract the generator weights from a training checkpoint.

    Builds the FaceShifter training module that matches the module-level
    ``in_size``, loads ``ckpt`` into it (non-strict) and saves the
    ``state_dict`` of its ``generator`` attribute to ``pt``.

    Args:
        ckpt: Path of the training checkpoint; must contain a ``"state_dict"`` entry.
        pt: Destination path for the extracted generator weights.

    Raises:
        ValueError: If ``in_size`` is neither 256 nor 512.
    """
    print('[extract_generator] loading ckpt...')
    # Function-scope imports, as in the original: the trainer package and yaml
    # are only touched when weights actually have to be extracted.
    from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
    import yaml

    config_file = make_abs_path('../../trainer/faceshifter/config.yaml')
    with open(config_file, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    if in_size not in (256, 512):
        raise ValueError('Not supported in_size.')
    if in_size == 256:
        net = FaceshifterPL(n_layers=3, num_D=3, config=config)
    else:
        net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)

    state = torch.load(ckpt, map_location="cpu")["state_dict"]
    net.load_state_dict(state, strict=False)
    net.eval()

    generator = net.generator
    torch.save(generator.state_dict(), pt)
    print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')
# ---- load model ----
# Loads the generator selected by ``fs_model_name`` into ``fs_model`` and defines
# the matching ``infer_batch_to_img(i_s, i_t, post=False)`` wrapper. Every
# backend accepts ``post`` because swap_image() always passes it by keyword.
if fs_model_name == 'faceshifter':
    pt_path = make_abs_path("./weights/extracted/G_mouth1_t38_post.pth")
    # No training checkpoint is configured for the default weights, so they
    # cannot be re-extracted. To use a 512 model, enable one of the
    # (pt_path, ckpt_path) pairs below.
    ckpt_path = None
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
    if not os.path.exists(pt_path) or 't512' in pt_path:
        if ckpt_path is None:
            # Previously this fell through to an undefined ``ckpt_path`` (NameError).
            raise FileNotFoundError(
                f'Generator weights at {pt_path} must be extracted, '
                f'but no ckpt_path is configured to extract them from.'
            )
        extract_generator(ckpt_path, pt_path)
    fs_model = FSGenerator(
        make_abs_path("./weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
        mouth_net_param=mouth_net_param,
        in_size=in_size,
        downup=in_size == 512,
    )
    fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
    fs_model.eval()

    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap source batch ``i_s`` onto target batch ``i_t``; return the first result.

        With ``post`` set, the target's region for parsing labels 0 and 17 is
        pasted back over the output through a smoothed mask, and the mouth is
        refined when ``in_size`` is 256.
        """
        i_r = fs_model(i_s, i_t)[0]  # x, id_vector, att
        # NOTE(review): the flattened source does not show whether the mouth
        # refinement belongs inside ``if post``; it is kept there — confirm.
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
            i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r
        img_r = trick.tensor_to_arr(i_r)[0]
        return img_r

elif fs_model_name in ('simswap_triplet', 'simswap_vanilla'):
    from modules.networks.simswap import Generator_Adain_Upsample
    sw_model = Generator_Adain_Upsample(
        input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
        mouth_net_param=mouth_net_param
    )
    if fs_model_name == 'simswap_triplet':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_triplet_5/epoch=12-step=782999.ckpt")
    else:  # 'simswap_vanilla' is the only other name this branch admits
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
    sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
    sw_model.eval()
    fs_model = sw_model

    # The identity and mouth networks come from the full training checkpoint.
    from trainer.simswap.simswap_pl import SimSwapPL
    import yaml
    with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param
    net = SimSwapPL(config=config, use_official_arc='off' in pt_path)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()
    sw_mouth_net = net.mouth_net  # maybe None
    sw_netArc = net.netArc
    # This backend is moved to CUDA unconditionally.
    fs_model = fs_model.cuda()
    sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
    sw_netArc = sw_netArc.cuda()

    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap ``i_s`` onto ``i_t`` with the SimSwap generator; return the first result.

        With ``post`` set, the target's region for parsing labels 0 and 17 is
        pasted back through a smoothed mask. The output is clamped to [-1, 1]
        before conversion.
        """
        i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net,)
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
        i_r = i_r.clamp(-1, 1)
        i_r = trick.tensor_to_arr(i_r)[0]
        return i_r

elif fs_model_name == 'simswap_official':
    from simswap.image_infer import SimSwapOfficialImageInfer
    fs_model = SimSwapOfficialImageInfer()
    pt_path = 'Simswap Official'
    mouth_net_param = {
        "use": False
    }

    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Run the official SimSwap inference and clamp the output to [-1, 1].

        ``post`` is accepted only so this wrapper can be called like the other
        backends' wrappers; no post-processing is applied here.
        """
        i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
        i_r = i_r.clamp(-1, 1)
        return i_r

else:
    raise ValueError('Not supported fs_model_name.')
print(f'[demo] model loaded from {pt_path}')
def swap_image(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    paste_back=True,
    use_post=False,
    use_gpen=False,
    in_size=256,
):
    """Swap the face from ``source_image`` onto the image at ``target_path``.

    Both images are aligned and cropped to ``in_size`` from detected landmarks,
    run through the module-level ``infer_batch_to_img`` and written to
    ``out_path`` as ``out_<target name>`` (the swapped crop) and, via ``save``,
    ``paste_back_out_<target name>``. Three debug images
    (``cropped_source.png``, ``cropped_target.png``, ``result.png``) are also
    written to the current working directory.

    Args:
        source_image: Path of the source (identity) image.
        target_path: Path of the target image; its last ``/`` component names the outputs.
        out_path: Output directory, created if missing.
        transform: Callable turning an aligned crop into a tensor.
        G: Generator; put in eval mode (and on CUDA when ``gpu_mode``) if it is
            a ``torch.nn.Module``. Inference itself goes through
            ``infer_batch_to_img``.
        align_source: Alignment mode for the source crop.
        align_target: Alignment mode for the target crop.
        gpu_mode: Move tensors (and ``G``) to CUDA.
        paste_back: Unused in this function.
        use_post: Forwarded to ``infer_batch_to_img`` as ``post``.
        use_gpen: Forwarded to ``save`` as ``use_post``.
        in_size: Side length of the aligned crops.
    """
    out_name = "out_" + target_path.split("/")[-1]

    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    # Source: detect 98 landmarks, reduce to 5, align and crop.
    src_rgb = np.array(Image.open(source_image).convert("RGB"))
    lmk_net, lmk_detector = get_lmk_model()
    src_lmk5 = get_5_from_98(demo_image(src_rgb, lmk_net, lmk_detector)[0])
    src_crop = norm_crop(src_rgb, src_lmk5, in_size, mode=align_source, borderValue=0.0)
    src_tensor = transform(src_crop).unsqueeze(0)

    # Target: same procedure, keeping the affine matrix and the untouched frame
    # so the result can be pasted back afterwards.
    tgt_rgb = np.array(Image.open(target_path).convert("RGB"))
    tgt_full = tgt_rgb.copy()
    tgt_lmk5 = get_5_from_98(demo_image(tgt_rgb, lmk_net, lmk_detector)[0])
    tgt_crop, M = norm_crop_with_M(tgt_rgb, tgt_lmk5, in_size, mode=align_target, borderValue=0.0)
    tgt_tensor = transform(tgt_crop).unsqueeze(0)

    if gpu_mode:
        tgt_tensor = tgt_tensor.cuda()
        src_tensor = src_tensor.cuda()

    # Debug dumps (channel order reversed for cv2).
    cv2.imwrite('cropped_source.png', trick.tensor_to_arr(src_tensor)[0, :, :, ::-1])
    cv2.imwrite('cropped_target.png', trick.tensor_to_arr(tgt_tensor)[0, :, :, ::-1])

    # Both inputs are crops of side ``in_size``.
    swapped = infer_batch_to_img(src_tensor, tgt_tensor, post=use_post)
    cv2.imwrite('result.png', swapped[:, :, ::-1])

    os.makedirs(out_path, exist_ok=True)
    crop_file = os.path.join(out_path, out_name)
    Image.fromarray(swapped.astype(np.uint8)).save(crop_file)

    pasted_file = os.path.join(out_path, "paste_back_" + out_name)
    save((swapped, M, tgt_full, pasted_file, None),
         trick=trick, use_post=use_gpen)
def process_video(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    frames=9999999,
    use_tddfav2=False,
    landmark_smooth="kalman",
):
    """Swap the source face onto the frames of a target video.

    When ``target_path`` is not a directory it is split into
    ``./tmp/frame_%05d.png`` with ffmpeg (stale pngs in ``./tmp/`` and
    ``out_path`` are removed first); otherwise its ``*.png`` files are used
    directly. Landmarks are extracted for every frame, optionally smoothed,
    and each frame is aligned, swapped and collected in memory. Nothing is
    pasted back or written here except creating ``out_path``.

    Args:
        source_image: Path of the source (identity) image.
        target_path: Video file, or a directory of png frames.
        out_path: Directory the per-frame output names point into.
        transform: Callable turning an aligned crop into a tensor.
        G: Generator; put in eval mode (and on CUDA when ``gpu_mode``) if it is
            a ``torch.nn.Module``.
        align_source: Alignment mode for the source crop.
        align_target: Alignment mode for the target crops.
        gpu_mode: Move tensors (and ``G``) to CUDA.
        frames: Maximum number of frames to swap.
        use_tddfav2: Use ``get_detector`` / ``get_lmk`` instead of the default
            landmark model for the source image.
        landmark_smooth: ``'kalman'``, ``'savgol'`` or ``'cancel'`` (no smoothing).

    Returns:
        ``(targets, t_facial_masks, Ms, original_frames, names, fps)``: swapped
        crops, paste masks, affine matrices, original frames, output file
        names (one entry per processed frame) and the frame rate.

    Raises:
        KeyError: If ``landmark_smooth`` is not a supported choice.
    """
    import subprocess  # function-scope import: not provided by the module header

    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    ''' Target video to frames (.png) '''
    fps = 25.0
    if not os.path.isdir(target_path):
        vidcap = cv2.VideoCapture(target_path)
        try:
            # A capture that could not be opened reports 0; keep the default then.
            fps = vidcap.get(cv2.CAP_PROP_FPS) or fps
        finally:
            vidcap.release()  # the capture handle was previously leaked

        # Best-effort removal of frames left over from an earlier run.
        try:
            for match in glob.glob(os.path.join("./tmp/", "*.png")):
                os.remove(match)
            for match in glob.glob(os.path.join(out_path, "*.png")):
                os.remove(match)
        except OSError as e:
            print(e)

        os.makedirs("./tmp/", exist_ok=True)
        # Argument list instead of a shell string: the uploaded path may contain
        # spaces or shell metacharacters.
        subprocess.run(
            [
                "ffmpeg", "-i", str(target_path),
                "-qscale:v", "1", "-qmin", "1", "-qmax", "1", "-vsync", "0",
                "./tmp/frame_%05d.png",
            ],
            check=False,
        )
        target_path = "./tmp/"
    globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))

    ''' Get target landmarks '''
    print('[Extracting target landmarks...]')
    if not use_tddfav2:
        align_net, align_detector = get_lmk_model()
    else:
        align_net, align_detector = get_detector(gpu_mode=gpu_mode)
    target_lmks = []
    for frame_path in tqdm.tqdm(globbed_images):
        target = np.array(Image.open(frame_path).convert("RGB"))
        lmk = demo_image(target, align_net, align_detector)
        lmk = lmk[0]
        target_lmks.append(lmk)

    ''' Landmark smoothing '''
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24, so the builtin is used directly.
    target_lmks = np.array(target_lmks, np.float32)  # (#frames, 98, 2)
    if landmark_smooth == 'kalman':
        target_lmks = kalman_filter_landmark(target_lmks,
                                             process_noise=0.01,
                                             measure_noise=0.01).astype(int)
    elif landmark_smooth == 'savgol':
        target_lmks = savgol_filter_landmark(target_lmks).astype(int)
    elif landmark_smooth == 'cancel':
        target_lmks = target_lmks.astype(int)
    else:
        raise KeyError('Not supported landmark_smooth choice')

    ''' Crop source image '''
    source_img = np.array(Image.open(source_image).convert("RGB"))
    if not use_tddfav2:
        lmk = get_5_from_98(demo_image(source_img, align_net, align_detector)[0])
    else:
        lmk = get_lmk(source_img, align_net, align_detector)
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)
    if gpu_mode:
        source_img = source_img.cuda()

    ''' Process by frames '''
    targets = []
    t_facial_masks = []
    Ms = []
    original_frames = []
    names = []
    for count, image in enumerate(tqdm.tqdm(globbed_images)):
        # Stop after exactly ``frames`` frames (the old post-increment check
        # ``count > frames`` let one extra frame through).
        if count >= frames:
            break
        names.append(os.path.join(out_path, Path(image).name))
        target = np.array(Image.open(image).convert("RGB"))
        original_frames.append(target)

        ''' Crop target frames '''
        lmk = get_5_from_98(target_lmks[count])
        target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
        target = transform(target).unsqueeze(0)  # in [-1,1]
        if gpu_mode:
            target = target.cuda()

        ''' Finetune paste masks '''
        target_facial_mask = trick.get_any_mask(target,
                                                par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze()  # in [0,1]
        # ``np.float`` (alias of builtin ``float``) was removed in NumPy 1.24.
        target_facial_mask = target_facial_mask.cpu().numpy().astype(float)
        target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks)  # in [0,1]
        t_facial_masks.append(target_facial_mask)

        ''' Face swapping '''
        with torch.no_grad():
            if 'faceshifter' in fs_model_name:
                output = G(source_img, target)
                # infer_batch_to_img takes ``[0]`` of this same call, so a tuple
                # result must be unwrapped before the tensor arithmetic below.
                if isinstance(output, tuple):
                    output = output[0]
                target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                target_hair_mask = trick.smooth_mask(target_hair_mask)
                output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                output = trick.finetune_mouth(source_img, target, output)
            elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
                output = fs_model(source=source_img, target=target,
                                  net_arc=sw_netArc, mouth_net=sw_mouth_net,)
                # NOTE(review): nesting of the mouth refinement under this
                # ``if`` is inferred from the flattened source — confirm.
                if 'vanilla' not in fs_model_name:
                    target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                    target_hair_mask = trick.smooth_mask(target_hair_mask)
                    output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                    output = trick.finetune_mouth(source_img, target, output)
                output = output.clamp(-1, 1)
            elif 'simswap_official' in fs_model_name:
                output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
                output = output.clamp(-1, 1)

            # Take the first image of the batch and map it from [-1,1] to [0,1].
            if isinstance(output, tuple):
                target = output[0][0] * 0.5 + 0.5
            else:
                target = output[0] * 0.5 + 0.5
        targets.append(np.array(tensor2pil_transform(target)))
        Ms.append(M)

    os.makedirs(out_path, exist_ok=True)
    return targets, t_facial_masks, Ms, original_frames, names, fps
def swap_image_gr(img1, img2, use_post=False, use_gpen=False, ):
    """Gradio callback: swap the face of ``img1`` onto ``img2``.

    Both inputs are stored under a fresh per-request directory inside
    ``./online_data`` (next to this file), ``swap_image`` is run on them with
    'ffhq' alignment, and the pasted-back result is read back and returned.
    Reads the module-level ``use_gpu``, ``T``, ``fs_model`` and ``in_size``.

    Args:
        img1: Source image array; its last axis is reversed before writing with cv2.
        img2: Target image array, handled the same way.
        use_post: Forwarded to ``swap_image`` as ``use_post``.
        use_gpen: Forwarded to ``swap_image`` as ``use_gpen``.

    Returns:
        The result image read with cv2, with its last axis reversed.

    Raises:
        ValueError: If either image is missing (an empty input component
            delivers ``None``).
        RuntimeError: If the expected output file cannot be read back.
    """
    # Fail early with a readable message instead of "'NoneType' is not subscriptable".
    if img1 is None or img2 is None:
        raise ValueError("Both a source image and a target image are required.")

    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    target_path = os.path.join(data_dir, "target.png")
    # swap_image names its pasted-back output "paste_back_out_<target name>".
    filename = "paste_back_out_target.png"
    out_path = os.path.join(data_dir, filename)
    cv2.imwrite(source_path, img1[:, :, ::-1])
    cv2.imwrite(target_path, img2[:, :, ::-1])
    swap_image(
        source_path,
        target_path,
        data_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        align_target='ffhq',
        align_source='ffhq',
        use_post=use_post,
        use_gpen=use_gpen,
        in_size=in_size,
    )
    out = cv2.imread(out_path)
    # cv2.imread signals a missing or unreadable file by returning None.
    if out is None:
        raise RuntimeError(f"Face swapping produced no readable result at {out_path}")
    return out[..., ::-1]
def swap_video_gr(img1, target_path, frames=9999999):
    """Gradio callback: swap the face of ``img1`` onto the video at ``target_path``.

    Runs ``process_video`` to obtain the swapped crops, pastes every frame back
    (sequentially or through a process pool), encodes the frames in
    ``<request dir>/out`` to ``output.mp4`` with ffmpeg (copying the target's
    first audio stream when it has one) and removes the intermediate pngs.
    Reads the module-level ``use_gpu``, ``T`` and ``fs_model``.

    Args:
        img1: Source image array; its last axis is reversed before writing with cv2.
        target_path: Path of the target video.
        frames: Maximum number of frames, forwarded to ``process_video``.

    Returns:
        Path of the encoded video.

    Raises:
        ValueError: If the source image or the target video is missing.
    """
    import subprocess  # function-scope import: not provided by the module header

    # Fail early with a readable message when an input component was left empty.
    if img1 is None or not target_path:
        raise ValueError("Both a source image and a target video are required.")

    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    cv2.imwrite(source_path, img1[:, :, ::-1])
    out_dir = os.path.join(data_dir, "out")
    out_name = "output.mp4"
    targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
        source_path,
        target_path,
        out_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        frames=frames,
        align_target='ffhq',
        align_source='ffhq',
        use_tddfav2=False,
    )

    pool_process = 170  # number of worker processes used for pasting back
    audio = True
    concat = False

    if pool_process <= 1:
        for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
            zip(targets, Ms, original_frames, names, t_facial_masks)
        ):
            # Without a crop or an affine matrix the frame is written unchanged.
            if M is None or target is None:
                Image.fromarray(original_target.astype(np.uint8)).save(name)
                continue
            Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
    else:
        with Pool(pool_process) as pool:
            pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))

    video_save_path = os.path.join(out_dir, out_name)
    swapped_frames = f"{out_dir}/frame_%05d.png"
    # ffmpeg is invoked with argument lists rather than shell strings so that
    # uploaded paths containing spaces or shell metacharacters are passed verbatim.
    if audio:
        print("use audio")
        subprocess.run(
            [
                "ffmpeg", "-y", "-r", str(fps), "-i", swapped_frames, "-i", str(target_path),
                # "1:a:0?" makes the audio mapping optional.
                "-map", "0:v:0", "-map", "1:a:0?", "-c:a", "copy",
                "-c:v", "libx264", "-r", str(fps), "-crf", "10", "-pix_fmt", "yuv420p",
                video_save_path,
            ],
            check=False,
        )
    else:
        print("no audio")
        # Encode the swapped frames; this branch used to read the untouched
        # frames extracted to ./tmp/.
        subprocess.run(
            [
                "ffmpeg", "-y", "-r", str(fps), "-i", swapped_frames,
                "-c:v", "libx264", "-r", str(fps), "-crf", "10", "-pix_fmt", "yuv420p",
                video_save_path,
            ],
            check=False,
        )

    # ffmpeg -i left.mp4 -i right.mp4 -filter_complex hstack output.mp4
    if concat:
        concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
        subprocess.run(
            [
                "ffmpeg", "-y", "-i", str(target_path), "-i", video_save_path,
                "-filter_complex", "hstack", concat_video_save_path,
            ],
            check=False,
        )

    # Delete temporary files; ./tmp/ does not exist when the target was a
    # directory of frames, hence ignore_errors.
    shutil.rmtree("./tmp/", ignore_errors=True)
    for match in glob.glob(os.path.join(out_dir, "*.png")):
        os.remove(match)
    print(video_save_path)
    return video_save_path
# Resolved at import time: swap_image_gr / swap_video_gr read this global, so it
# must also exist when the module is imported rather than run as a script.
use_gpu = torch.cuda.is_available()

if __name__ == "__main__":
    # Two-tab Gradio UI: single-image swapping and video swapping.
    # NOTE(review): the nesting of the widgets below is reconstructed from a
    # source whose indentation was lost; confirm the intended layout.
    with gr.Blocks() as demo:
        gr.Markdown("SuperSwap")
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=3):
                    image1_input = gr.Image(label='source')
                    image2_input = gr.Image(label='target')
                    use_post = gr.Checkbox(label="Post-Process")
                    use_gpen = gr.Checkbox(label="Super Resolution")
                with gr.Column(scale=2):
                    image_output = gr.Image()
                    image_button = gr.Button("换脸")  # label means "swap face"
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=3):
                    image3_input = gr.Image(label='source')
                    video_input = gr.Video(label='target')
                with gr.Column(scale=2):
                    video_output = gr.Video()
                    video_button = gr.Button("换脸")  # label means "swap face"
        # Wire the buttons to their callbacks.
        image_button.click(
            swap_image_gr,
            inputs=[image1_input, image2_input, use_post, use_gpen],
            outputs=image_output,
        )
        video_button.click(
            swap_video_gr,
            inputs=[image3_input, video_input],
            outputs=video_output,
        )
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)