Spaces:
Running
Running
TheEeeeLin
commited on
Commit
·
2c368dd
1
Parent(s):
06fbec3
Update human_matting.py
Browse files
hivision/creator/human_matting.py
CHANGED
@@ -95,7 +95,7 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
|
|
95 |
:param ctx: 上下文
|
96 |
"""
|
97 |
# 抠图
|
98 |
-
matting_image =
|
99 |
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
|
100 |
)
|
101 |
# 修复抠图
|
@@ -221,6 +221,38 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
|
|
221 |
return output_image
|
222 |
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
|
225 |
global RMBG_SESS
|
226 |
|
@@ -332,7 +364,8 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
|
|
332 |
load_start_time = time()
|
333 |
|
334 |
if BIREFNET_V1_LITE_SESS is None:
|
335 |
-
|
|
|
336 |
|
337 |
# 记录加载onnx模型的结束时间
|
338 |
load_end_time = time()
|
|
|
95 |
:param ctx: 上下文
|
96 |
"""
|
97 |
# 抠图
|
98 |
+
matting_image = get_modnet_matting_photographic_portrait_matting(
|
99 |
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
|
100 |
)
|
101 |
# 修复抠图
|
|
|
221 |
return output_image
|
222 |
|
223 |
|
224 |
+
def get_modnet_matting_photographic_portrait_matting(
|
225 |
+
input_image, checkpoint_path, ref_size=512
|
226 |
+
):
|
227 |
+
global MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS
|
228 |
+
|
229 |
+
if not os.path.exists(checkpoint_path):
|
230 |
+
print(f"Checkpoint file not found: {checkpoint_path}")
|
231 |
+
return None
|
232 |
+
|
233 |
+
if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
|
234 |
+
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
|
235 |
+
checkpoint_path, set_cpu=True
|
236 |
+
)
|
237 |
+
|
238 |
+
input_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_inputs()[0].name
|
239 |
+
output_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_outputs()[0].name
|
240 |
+
|
241 |
+
im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
|
242 |
+
|
243 |
+
matte = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.run(
|
244 |
+
[output_name], {input_name: im}
|
245 |
+
)
|
246 |
+
matte = (matte[0] * 255).astype("uint8")
|
247 |
+
matte = np.squeeze(matte)
|
248 |
+
mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
|
249 |
+
b, g, r = cv2.split(np.uint8(input_image))
|
250 |
+
|
251 |
+
output_image = cv2.merge((b, g, r, mask))
|
252 |
+
|
253 |
+
return output_image
|
254 |
+
|
255 |
+
|
256 |
def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
|
257 |
global RMBG_SESS
|
258 |
|
|
|
364 |
load_start_time = time()
|
365 |
|
366 |
if BIREFNET_V1_LITE_SESS is None:
|
367 |
+
print("首次加载birefnet-v1-lite模型...")
|
368 |
+
BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
|
369 |
|
370 |
# 记录加载onnx模型的结束时间
|
371 |
load_end_time = time()
|