ReHiFace-S / HifiFaceAPI_parallel_trt_roi_realtime_sr_api.py
GuijiAI's picture
Upload 117 files
89cf463 verified
raw
history blame
9.43 kB
import os
import cv2
import time
import numpy as np
import numexpr as ne
from multiprocessing.dummy import Process, Queue
from options.hifi_test_options import HifiTestOptions
from HifiFaceAPI_parallel_base import Consumer0Base, Consumer2Base, Consumer3Base,Consumer1BaseONNX
from color_transfer import color_transfer
def np_norm(x):
return (x - np.average(x)) / np.std(x)
def reverse2wholeimage_hifi_trt_roi(swaped_img, mat_rev, img_mask, frame, roi_img, roi_box):
target_image = cv2.warpAffine(swaped_img, mat_rev, roi_img.shape[:2][::-1], borderMode=cv2.BORDER_REPLICATE)[
...,
::-1]
local_dict = {
'img_mask': img_mask,
'target_image': target_image,
'roi_img': roi_img,
}
img = ne.evaluate('img_mask * (target_image * 255)+(1 - img_mask) * roi_img', local_dict=local_dict,
global_dict=None)
img = img.astype(np.uint8)
frame[roi_box[1]:roi_box[3], roi_box[0]:roi_box[2]] = img
return frame
def get_max_face(np_rois):
roi_areas = []
for index in range(np_rois.shape[0]):
roi_areas.append((np_rois[index, 2] - np_rois[index, 0]) * (np_rois[index, 3] - np_rois[index, 1]))
return np.argmax(np.array(roi_areas))
class Consumer0(Consumer0Base):
def __init__(self, opt, frame_queue_in, queue_list: list, block=True, fps_counter=False, align_method='68'):
super().__init__(opt, frame_queue_in, None, queue_list, block, fps_counter)
self.align_method = align_method
def run(self):
counter = 0
start_time = time.time()
kpss_old = None
rois_old = faces_old = Ms_old = masks_old = None
while True:
frame = self.frame_queue_in.get()
if frame is None:
break
try:
_, bboxes, kpss = self.scrfd_detector.get_bboxes(frame, max_num=0)
if self.align_method == '5class':
rois, faces, Ms, masks = self.mtcnn_detector.align_multi_for_scrfd(
frame, bboxes, kpss, limit=1, min_face_size=30,
crop_size=(self.crop_size, self.crop_size), apply_roi=True, detector=None
)
else:
rois, faces, Ms, masks = self.face_alignment.forward(
frame, bboxes, kpss, limit=5, min_face_size=30,
crop_size=(self.crop_size, self.crop_size), apply_roi=True
)
except (TypeError, IndexError, ValueError) as e:
self.queue_list[0].put([None, frame])
continue
if len(faces)==0:
self.queue_list[0].put([None, frame])
continue
elif len(faces)==1:
face = np.array(faces[0])
mat = Ms[0]
roi_box = rois[0]
else:
max_index = get_max_face(np.array(rois))
face = np.array(faces[max_index])
mat = Ms[max_index]
roi_box = rois[max_index]
roi_img = frame[roi_box[1]:roi_box[3], roi_box[0]:roi_box[2]]
#The default normalization to the range of -1 to 1, where the model input is in RGB format
face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
self.queue_list[0].put([face, mat, [], frame, roi_img, roi_box])
if self.fps_counter:
counter += 1
if (time.time() - start_time) > 10:
print("Consumer0 FPS: {}".format(counter / (time.time() - start_time)))
counter = 0
start_time = time.time()
self.queue_list[0].put(None)
print('co stop')
class Consumer1(Consumer1BaseONNX):
def __init__(self, opt, feature_list, queue_list: list, block=True, fps_counter=False):
super().__init__(opt, feature_list, queue_list, block, fps_counter)
def run(self):
counter = 0
start_time = time.time()
while True:
something_in = self.queue_list[0].get()
if something_in is None:
break
elif len(something_in) == 2:
self.queue_list[1].put([None, something_in[1]])
continue
if len(self.feature_list) > 1:
self.feature_list.pop(0)
image_latent = self.feature_list[0][0]
mask_out, swap_face_out = self.predict(something_in[0], image_latent[0].reshape(1, -1))
mask = cv2.warpAffine(mask_out[0][0].astype(np.float32), something_in[1],
something_in[4].shape[:2][::-1])
mask[mask > 0.2] = 1
mask = mask[:, :, np.newaxis].astype(np.uint8)
swap_face = swap_face_out[0].transpose((1, 2, 0)).astype(np.float32)
self.queue_list[1].put(
[swap_face, something_in[1], mask, something_in[3], something_in[4], something_in[5], something_in[0]])
if self.fps_counter:
counter += 1
if (time.time() - start_time) > 10:
print("Consumer1 FPS: {}".format(counter / (time.time() - start_time)))
counter = 0
start_time = time.time()
self.queue_list[1].put(None)
print('c1 stop')
class Consumer2(Consumer2Base):
def __init__(self, queue_list: list, frame_queue_out, block=True, fps_counter=False):
super().__init__(queue_list, frame_queue_out, block, fps_counter)
def forward_func(self, something_in):
if len(something_in) == 2:
frame = something_in[1]
frame_out = frame.astype(np.uint8)
else:
swap_face = ((something_in[0] + 1) / 2)
frame_out = reverse2wholeimage_hifi_trt_roi(
swap_face, something_in[1], something_in[2],
something_in[3], something_in[4], something_in[5]
)
self.frame_queue_out.put(frame_out)
# cv2.imshow('output', frame_out)
# cv2.waitKey(1)
class Consumer3(Consumer3Base):
def __init__(self, queue_list, block=True, fps_counter=False, use_gfpgan=True, sr_weight=1.0,
use_color_trans=False, color_trans_mode=''):
super().__init__(queue_list, block, fps_counter)
self.use_gfpgan = use_gfpgan
self.sr_weight = sr_weight
self.use_color_trans = use_color_trans
self.color_trans_mode = color_trans_mode
def forward_func(self, something_in):
if len(something_in) == 2:
self.queue_list[1].put([None, something_in[1]])
else:
swap_face = something_in[0]
target_face = (something_in[6] / 255).astype(np.float32)
if self.use_gfpgan:
sr_face = self.gfp.forward(swap_face)
if self.sr_weight != 1.0:
sr_face = cv2.addWeighted(sr_face, alpha=self.sr_weight, src2=swap_face, beta=1.0 - self.sr_weight, gamma=0, dtype=cv2.CV_32F)
if self.use_color_trans:
transed_face = color_transfer(self.color_trans_mode, (sr_face + 1) / 2, target_face)
result_face = (transed_face * 2) - 1
else:
result_face = sr_face
else:
if self.use_color_trans:
transed_face = color_transfer(self.color_trans_mode, (swap_face + 1) / 2, target_face)
result_face = (transed_face * 2) - 1
else:
result_face = swap_face
self.queue_list[1].put([result_face, something_in[1], something_in[2], something_in[3],
something_in[4], something_in[5]])
class HifiFaceRealTime:
def __init__(self, feature_dict_list_, frame_queue_in, frame_queue_out, gpu=True, model_name='er8_bs1', align_method='68',
use_gfpgan=True, sr_weight=1.0, use_color_trans=False, color_trans_mode='rct'):
self.opt = HifiTestOptions().parse()
if model_name != '':
self.opt.model_name = model_name
self.opt.input_size = 256
self.feature_dict_list = feature_dict_list_
self.frame_queue_in = frame_queue_in
self.frame_queue_out = frame_queue_out
self.gpu = gpu
self.align_method = align_method
self.use_gfpgan = use_gfpgan
self.sr_weight = sr_weight
self.use_color_trans = use_color_trans
self.color_trans_mode = color_trans_mode
def forward(self):
self.q0 = Queue(2)
self.q1 = Queue(2)
self.q2 = Queue(2)
self.c0 = Consumer0(self.opt, self.frame_queue_in, [self.q0], fps_counter=False, align_method=self.align_method)
self.c1 = Consumer1(self.opt, self.feature_dict_list, [self.q0, self.q1], fps_counter=False)
self.c3 = Consumer3([self.q1, self.q2], fps_counter=False,
use_gfpgan=self.use_gfpgan, sr_weight=self.sr_weight,
use_color_trans=self.use_color_trans, color_trans_mode=self.color_trans_mode)
self.c2 = Consumer2([self.q2], self.frame_queue_out, fps_counter=False)
self.c0.start()
self.c1.start()
self.c3.start()
self.c2.start()
self.c0.join()
self.c1.join()
self.c3.join()
self.c2.join()
return