Spaces:
Running
on
T4
Running
on
T4
File size: 12,071 Bytes
983684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import numpy as np
import cv2
import dlib
import torch
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm
from model.vtoonify import VToonify
from model.bisenet.model import BiSeNet
from model.encoder.align_all_parallel import align_face
from util import save_image, load_image, visualize, load_psp_standalone, get_video_crop_parameter, tensor2cv2
class TestOptions():
    """Command-line options for VToonify image/video style transfer.

    ``parse()`` reads ``sys.argv``, fills in a default ``exstyle_path`` and
    prints every option (sorted by name) before returning the namespace.
    """

    def __init__(self):
        # (flag, add_argument keyword arguments), in the order shown by --help.
        # Built per instance so the list-valued --padding default is never shared.
        option_specs = [
            ("--content", dict(type=str, default='./data/077436.jpg', help="path of the content image/video")),
            ("--style_id", dict(type=int, default=26, help="the id of the style image")),
            ("--style_degree", dict(type=float, default=0.5, help="style degree for VToonify-D")),
            ("--color_transfer", dict(action="store_true", help="transfer the color of the style")),
            ("--ckpt", dict(type=str, default='./checkpoint/vtoonify_d_cartoon/vtoonify_s_d.pt', help="path of the saved model")),
            ("--output_path", dict(type=str, default='./output/', help="path of the output images")),
            ("--scale_image", dict(action="store_true", help="resize and crop the image to best fit the model")),
            ("--style_encoder_path", dict(type=str, default='./checkpoint/encoder.pt', help="path of the style encoder")),
            ("--exstyle_path", dict(type=str, default=None, help="path of the extrinsic style code")),
            ("--faceparsing_path", dict(type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")),
            ("--video", dict(action="store_true", help="if true, video stylization; if false, image stylization")),
            ("--cpu", dict(action="store_true", help="if true, only use cpu")),
            ("--backbone", dict(type=str, default='dualstylegan', help="dualstylegan | toonify")),
            ("--padding", dict(type=int, nargs=4, default=[200, 200, 200, 200], help="left, right, top, bottom paddings to the face center")),
            ("--batch_size", dict(type=int, default=4, help="batch size of frames when processing video")),
            ("--parsing_map_path", dict(type=str, default=None, help="path of the refined parsing map of the target video")),
        ]
        parser = argparse.ArgumentParser(description="Style Transfer")
        for flag, spec in option_specs:
            parser.add_argument(flag, **spec)
        self.parser = parser

    def parse(self):
        """Parse the command line, echo all options and return the namespace.

        When --exstyle_path is not given, it defaults to 'exstyle_code.npy' in
        the directory of --ckpt.
        """
        opt = self.parser.parse_args()
        self.opt = opt
        if opt.exstyle_path is None:
            checkpoint_dir = os.path.dirname(opt.ckpt)
            opt.exstyle_path = os.path.join(checkpoint_dir, 'exstyle_code.npy')
        options = vars(opt)
        print('Load options')
        for key in sorted(options):
            print(f'{key}: {options[key]}')
        return opt
if __name__ == "__main__":
parser = TestOptions()
args = parser.parse()
print('*'*98)
device = "cpu" if args.cpu else "cuda"
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
])
vtoonify = VToonify(backbone = args.backbone)
vtoonify.load_state_dict(torch.load(args.ckpt, map_location=lambda storage, loc: storage)['g_ema'])
vtoonify.to(device)
parsingpredictor = BiSeNet(n_classes=19)
parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
parsingpredictor.to(device).eval()
modelname = './checkpoint/shape_predictor_68_face_landmarks.dat'
if not os.path.exists(modelname):
import wget, bz2
wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname+'.bz2')
zipfile = bz2.BZ2File(modelname+'.bz2')
data = zipfile.read()
open(modelname, 'wb').write(data)
landmarkpredictor = dlib.shape_predictor(modelname)
pspencoder = load_psp_standalone(args.style_encoder_path, device)
if args.backbone == 'dualstylegan':
exstyles = np.load(args.exstyle_path, allow_pickle='TRUE').item()
stylename = list(exstyles.keys())[args.style_id]
exstyle = torch.tensor(exstyles[stylename]).to(device)
with torch.no_grad():
exstyle = vtoonify.zplus2wplus(exstyle)
if args.video and args.parsing_map_path is not None:
x_p_hat = torch.tensor(np.load(args.parsing_map_path))
print('Load models successfully!')
filename = args.content
basename = os.path.basename(filename).split('.')[0]
scale = 1
kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
print('Processing ' + os.path.basename(filename) + ' with vtoonify_' + args.backbone[0])
if args.video:
cropname = os.path.join(args.output_path, basename + '_input.mp4')
savename = os.path.join(args.output_path, basename + '_vtoonify_' + args.backbone[0] + '.mp4')
video_cap = cv2.VideoCapture(filename)
num = int(video_cap.get(7))
first_valid_frame = True
batch_frames = []
for i in tqdm(range(num)):
success, frame = video_cap.read()
if success == False:
assert('load video frames error')
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# We proprocess the video by detecting the face in the first frame,
# and resizing the frame so that the eye distance is 64 pixels.
# Centered on the eyes, we crop the first frame to almost 400x400 (based on args.padding).
# All other frames use the same resizing and cropping parameters as the first frame.
if first_valid_frame:
if args.scale_image:
paras = get_video_crop_parameter(frame, landmarkpredictor, args.padding)
if paras is None:
continue
h,w,top,bottom,left,right,scale = paras
H, W = int(bottom-top), int(right-left)
# for HR video, we apply gaussian blur to the frames to avoid flickers caused by bilinear downsampling
# this can also prevent over-sharp stylization results.
if scale <= 0.75:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
if scale <= 0.375:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
else:
H, W = frame.shape[0], frame.shape[1]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
videoWriter = cv2.VideoWriter(cropname, fourcc, video_cap.get(5), (W, H))
videoWriter2 = cv2.VideoWriter(savename, fourcc, video_cap.get(5), (4*W, 4*H))
# For each video, we detect and align the face in the first frame for pSp to obtain the style code.
# This style code is used for all other frames.
with torch.no_grad():
I = align_face(frame, landmarkpredictor)
I = transform(I).unsqueeze(dim=0).to(device)
s_w = pspencoder(I)
s_w = vtoonify.zplus2wplus(s_w)
if vtoonify.backbone == 'dualstylegan':
if args.color_transfer:
s_w = exstyle
else:
s_w[:,:7] = exstyle[:,:7]
first_valid_frame = False
elif args.scale_image:
if scale <= 0.75:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
if scale <= 0.375:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
batch_frames += [transform(frame).unsqueeze(dim=0).to(device)]
if len(batch_frames) == args.batch_size or (i+1) == num:
x = torch.cat(batch_frames, dim=0)
batch_frames = []
with torch.no_grad():
# parsing network works best on 512x512 images, so we predict parsing maps on upsmapled frames
# followed by downsampling the parsing maps
if args.video and args.parsing_map_path is not None:
x_p = x_p_hat[i+1-x.size(0):i+1].to(device)
else:
x_p = F.interpolate(parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
scale_factor=0.5, recompute_scale_factor=False).detach()
# we give parsing maps lower weight (1/16)
inputs = torch.cat((x, x_p/16.), dim=1)
# d_s has no effect when backbone is toonify
y_tilde = vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s = args.style_degree)
y_tilde = torch.clamp(y_tilde, -1, 1)
for k in range(y_tilde.size(0)):
videoWriter2.write(tensor2cv2(y_tilde[k].cpu()))
videoWriter.release()
videoWriter2.release()
video_cap.release()
else:
cropname = os.path.join(args.output_path, basename + '_input.jpg')
savename = os.path.join(args.output_path, basename + '_vtoonify_' + args.backbone[0] + '.jpg')
frame = cv2.imread(filename)
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
# We detect the face in the image, and resize the image so that the eye distance is 64 pixels.
# Centered on the eyes, we crop the image to almost 400x400 (based on args.padding).
if args.scale_image:
paras = get_video_crop_parameter(frame, landmarkpredictor, args.padding)
if paras is not None:
h,w,top,bottom,left,right,scale = paras
H, W = int(bottom-top), int(right-left)
# for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
if scale <= 0.75:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
if scale <= 0.375:
frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
with torch.no_grad():
I = align_face(frame, landmarkpredictor)
I = transform(I).unsqueeze(dim=0).to(device)
s_w = pspencoder(I)
s_w = vtoonify.zplus2wplus(s_w)
if vtoonify.backbone == 'dualstylegan':
if args.color_transfer:
s_w = exstyle
else:
s_w[:,:7] = exstyle[:,:7]
x = transform(frame).unsqueeze(dim=0).to(device)
# parsing network works best on 512x512 images, so we predict parsing maps on upsmapled frames
# followed by downsampling the parsing maps
x_p = F.interpolate(parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
scale_factor=0.5, recompute_scale_factor=False).detach()
# we give parsing maps lower weight (1/16)
inputs = torch.cat((x, x_p/16.), dim=1)
# d_s has no effect when backbone is toonify
y_tilde = vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s = args.style_degree)
y_tilde = torch.clamp(y_tilde, -1, 1)
cv2.imwrite(cropname, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
save_image(y_tilde[0].cpu(), savename)
print('Transfer style successfully!') |