TheEeeeLin commited on
Commit
2c368dd
·
1 Parent(s): 06fbec3

Update human_matting.py

Browse files
Files changed (1) hide show
  1. hivision/creator/human_matting.py +35 -2
hivision/creator/human_matting.py CHANGED
@@ -95,7 +95,7 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
95
  :param ctx: 上下文
96
  """
97
  # 抠图
98
- matting_image = get_modnet_matting(
99
  ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
100
  )
101
  # 修复抠图
@@ -221,6 +221,38 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
221
  return output_image
222
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
225
  global RMBG_SESS
226
 
@@ -332,7 +364,8 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
332
  load_start_time = time()
333
 
334
  if BIREFNET_V1_LITE_SESS is None:
335
- BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
 
336
 
337
  # 记录加载onnx模型的结束时间
338
  load_end_time = time()
 
95
  :param ctx: 上下文
96
  """
97
  # 抠图
98
+ matting_image = get_modnet_matting_photographic_portrait_matting(
99
  ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
100
  )
101
  # 修复抠图
 
221
  return output_image
222
 
223
 
224
+ def get_modnet_matting_photographic_portrait_matting(
225
+ input_image, checkpoint_path, ref_size=512
226
+ ):
227
+ global MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS
228
+
229
+ if not os.path.exists(checkpoint_path):
230
+ print(f"Checkpoint file not found: {checkpoint_path}")
231
+ return None
232
+
233
+ if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
234
+ MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
235
+ checkpoint_path, set_cpu=True
236
+ )
237
+
238
+ input_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_inputs()[0].name
239
+ output_name = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.get_outputs()[0].name
240
+
241
+ im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
242
+
243
+ matte = MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS.run(
244
+ [output_name], {input_name: im}
245
+ )
246
+ matte = (matte[0] * 255).astype("uint8")
247
+ matte = np.squeeze(matte)
248
+ mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
249
+ b, g, r = cv2.split(np.uint8(input_image))
250
+
251
+ output_image = cv2.merge((b, g, r, mask))
252
+
253
+ return output_image
254
+
255
+
256
  def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
257
  global RMBG_SESS
258
 
 
364
  load_start_time = time()
365
 
366
  if BIREFNET_V1_LITE_SESS is None:
367
+ print("首次加载birefnet-v1-lite模型...")
368
+ BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
369
 
370
  # 记录加载onnx模型的结束时间
371
  load_end_time = time()