TheEeeeLin commited on
Commit
06fbec3
·
1 Parent(s): 3cdc8a1
.gitattributes CHANGED
@@ -1,37 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- assets/demoImage.png filter=lfs diff=lfs merge=lfs -text
37
- hivision/creator/weights/rmbg-1.4.onnx filter=lfs diff=lfs merge=lfs -text
 
1
+ hivision/creator/weights/birefnet-v1-lite.onnx filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -17,5 +17,6 @@ build
17
  test/temp/*
18
  !test/temp/.gitkeep
19
  !hivision/creator/weights/rmbg-1.4.onnx
 
20
 
21
  .python-version
 
17
  test/temp/*
18
  !test/temp/.gitkeep
19
  !hivision/creator/weights/rmbg-1.4.onnx
20
+ !hivision/creator/weights/birefnet-v1-lite.onnx
21
 
22
  .python-version
app.py CHANGED
@@ -444,7 +444,7 @@ if __name__ == "__main__":
444
  minimum=0.1,
445
  maximum=0.5,
446
  value=0.2,
447
- step=0.02,
448
  label="面部比例",
449
  interactive=True,
450
  )
@@ -453,7 +453,7 @@ if __name__ == "__main__":
453
  minimum=0.02,
454
  maximum=0.5,
455
  value=0.12,
456
- step=0.02,
457
  label="头距顶距离",
458
  interactive=True,
459
  )
 
444
  minimum=0.1,
445
  maximum=0.5,
446
  value=0.2,
447
+ step=0.01,
448
  label="面部比例",
449
  interactive=True,
450
  )
 
453
  minimum=0.02,
454
  maximum=0.5,
455
  value=0.12,
456
+ step=0.01,
457
  label="头距顶距离",
458
  interactive=True,
459
  )
hivision/.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demoImage.png filter=lfs diff=lfs merge=lfs -text
37
+ hivision/creator/weights/rmbg-1.4.onnx filter=lfs diff=lfs merge=lfs -text
hivision/creator/choose_handler.py CHANGED
@@ -9,10 +9,17 @@ def choose_handler(creator, matting_model_option=None, face_detect_option=None):
9
  creator.matting_handler = extract_human_mnn_modnet
10
  elif matting_model_option == "rmbg-1.4":
11
  creator.matting_handler = extract_human_rmbg
 
 
 
 
12
  else:
13
  creator.matting_handler = extract_human
14
 
15
- if face_detect_option == "face_plusplus":
 
 
 
16
  creator.detection_handler = detect_face_face_plusplus
17
  else:
18
  creator.detection_handler = detect_face_mtcnn
 
9
  creator.matting_handler = extract_human_mnn_modnet
10
  elif matting_model_option == "rmbg-1.4":
11
  creator.matting_handler = extract_human_rmbg
12
+ # elif matting_model_option == "birefnet-portrait":
13
+ # creator.matting_handler = extract_human_birefnet_portrait
14
+ elif matting_model_option == "birefnet-v1-lite":
15
+ creator.matting_handler = extract_human_birefnet_lite
16
  else:
17
  creator.matting_handler = extract_human
18
 
19
+ if (
20
+ face_detect_option == "face_plusplus"
21
+ or face_detect_option == "face++ (联网API)"
22
+ ):
23
  creator.detection_handler = detect_face_face_plusplus
24
  else:
25
  creator.detection_handler = detect_face_mtcnn
hivision/creator/face_detector.py CHANGED
@@ -65,6 +65,8 @@ def detect_face_face_plusplus(ctx: Context):
65
  api_key = os.getenv("FACE_PLUS_API_KEY")
66
  api_secret = os.getenv("FACE_PLUS_API_SECRET")
67
 
 
 
68
  image = ctx.origin_image
69
  # 将图片转为 base64, 且不大于2MB(Face++ API接口限制)
70
  image_base64 = resize_image_to_kb_base64(image, 2000, mode="max")
 
65
  api_key = os.getenv("FACE_PLUS_API_KEY")
66
  api_secret = os.getenv("FACE_PLUS_API_SECRET")
67
 
68
+ print("调用了face++")
69
+
70
  image = ctx.origin_image
71
  # 将图片转为 base64, 且不大于2MB(Face++ API接口限制)
72
  image_base64 = resize_image_to_kb_base64(image, 2000, mode="max")
hivision/creator/human_matting.py CHANGED
@@ -14,6 +14,7 @@ from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze
14
  from .context import Context
15
  import cv2
16
  import os
 
17
 
18
 
19
  WEIGHTS = {
@@ -31,6 +32,9 @@ WEIGHTS = {
31
  "mnn_hivision_modnet.mnn",
32
  ),
33
  "rmbg-1.4": os.path.join(os.path.dirname(__file__), "weights", "rmbg-1.4.onnx"),
 
 
 
34
  }
35
 
36
  ONNX_DEVICE = (
@@ -39,26 +43,36 @@ ONNX_DEVICE = (
39
  else "CPUExecutionProvider"
40
  )
41
 
 
 
 
 
 
42
 
43
- def load_onnx_model(checkpoint_path):
44
  providers = (
45
  ["CUDAExecutionProvider", "CPUExecutionProvider"]
46
  if ONNX_DEVICE == "CUDAExecutionProvider"
47
  else ["CPUExecutionProvider"]
48
  )
49
 
50
- try:
51
- sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
52
- except Exception as e:
53
- if ONNX_DEVICE == "CUDAExecutionProvider":
54
- print(f"Failed to load model with CUDAExecutionProvider: {e}")
55
- print("Falling back to CPUExecutionProvider")
56
- # 尝试使用CPU加载模型
57
- sess = onnxruntime.InferenceSession(
58
- checkpoint_path, providers=["CPUExecutionProvider"]
59
- )
60
- else:
61
- raise e # 如果是CPU执行失败,重新抛出异常
 
 
 
 
 
62
 
63
  return sess
64
 
@@ -103,6 +117,22 @@ def extract_human_rmbg(ctx: Context):
103
  ctx.matting_image = ctx.processing_image.copy()
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def hollow_out_fix(src: np.ndarray) -> np.ndarray:
107
  """
108
  修补抠图区域,作为抠图模型精度不够的补充
@@ -165,22 +195,22 @@ def read_modnet_image(input_image, ref_size=512):
165
  return im, width, length
166
 
167
 
168
- # sess = None
169
-
170
-
171
  def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
 
 
172
  if not os.path.exists(checkpoint_path):
173
  print(f"Checkpoint file not found: {checkpoint_path}")
174
  return None
175
 
176
- sess = load_onnx_model(checkpoint_path)
 
177
 
178
- input_name = sess.get_inputs()[0].name
179
- output_name = sess.get_outputs()[0].name
180
 
181
  im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
182
 
183
- matte = sess.run([output_name], {input_name: im})
184
  matte = (matte[0] * 255).astype("uint8")
185
  matte = np.squeeze(matte)
186
  mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
@@ -192,6 +222,8 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
192
 
193
 
194
  def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
 
 
195
  if not os.path.exists(checkpoint_path):
196
  print(f"Checkpoint file not found: {checkpoint_path}")
197
  return None
@@ -202,7 +234,8 @@ def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
202
  image = image.resize(model_input_size, Image.BILINEAR)
203
  return image
204
 
205
- sess = load_onnx_model(checkpoint_path)
 
206
 
207
  orig_image = Image.fromarray(input_image)
208
  image = resize_rmbg_image(orig_image)
@@ -213,7 +246,7 @@ def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
213
  im_np = (im_np - 0.5) / 0.5 # Normalize to [-1, 1]
214
 
215
  # Inference
216
- result = sess.run(None, {sess.get_inputs()[0].name: im_np})[0]
217
 
218
  # Post process
219
  result = np.squeeze(result)
@@ -271,3 +304,64 @@ def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
271
  output_image = cv2.merge((b, g, r, mask))
272
 
273
  return output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from .context import Context
15
  import cv2
16
  import os
17
+ from time import time
18
 
19
 
20
  WEIGHTS = {
 
32
  "mnn_hivision_modnet.mnn",
33
  ),
34
  "rmbg-1.4": os.path.join(os.path.dirname(__file__), "weights", "rmbg-1.4.onnx"),
35
+ "birefnet-v1-lite": os.path.join(
36
+ os.path.dirname(__file__), "weights", "birefnet-v1-lite.onnx"
37
+ ),
38
  }
39
 
40
  ONNX_DEVICE = (
 
43
  else "CPUExecutionProvider"
44
  )
45
 
46
+ HIVISION_MODNET_SESS = None
47
+ MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = None
48
+ RMBG_SESS = None
49
+ BIREFNET_V1_LITE_SESS = None
50
+
51
 
52
+ def load_onnx_model(checkpoint_path, set_cpu=False):
53
  providers = (
54
  ["CUDAExecutionProvider", "CPUExecutionProvider"]
55
  if ONNX_DEVICE == "CUDAExecutionProvider"
56
  else ["CPUExecutionProvider"]
57
  )
58
 
59
+ if set_cpu:
60
+ sess = onnxruntime.InferenceSession(
61
+ checkpoint_path, providers=["CPUExecutionProvider"]
62
+ )
63
+ else:
64
+ try:
65
+ sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
66
+ except Exception as e:
67
+ if ONNX_DEVICE == "CUDAExecutionProvider":
68
+ print(f"Failed to load model with CUDAExecutionProvider: {e}")
69
+ print("Falling back to CPUExecutionProvider")
70
+ # 尝试使用CPU加载模型
71
+ sess = onnxruntime.InferenceSession(
72
+ checkpoint_path, providers=["CPUExecutionProvider"]
73
+ )
74
+ else:
75
+ raise e # 如果是CPU执行失败,重新抛出异常
76
 
77
  return sess
78
 
 
117
  ctx.matting_image = ctx.processing_image.copy()
118
 
119
 
120
+ # def extract_human_birefnet_portrait(ctx: Context):
121
+ # matting_image = get_birefnet_portrait_matting(
122
+ # ctx.processing_image, WEIGHTS["birefnet-portrait"]
123
+ # )
124
+ # ctx.processing_image = matting_image
125
+ # ctx.matting_image = ctx.processing_image.copy()
126
+
127
+
128
+ def extract_human_birefnet_lite(ctx: Context):
129
+ matting_image = get_birefnet_portrait_matting(
130
+ ctx.processing_image, WEIGHTS["birefnet-v1-lite"]
131
+ )
132
+ ctx.processing_image = matting_image
133
+ ctx.matting_image = ctx.processing_image.copy()
134
+
135
+
136
  def hollow_out_fix(src: np.ndarray) -> np.ndarray:
137
  """
138
  修补抠图区域,作为抠图模型精度不够的补充
 
195
  return im, width, length
196
 
197
 
 
 
 
198
  def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
199
+ global HIVISION_MODNET_SESS
200
+
201
  if not os.path.exists(checkpoint_path):
202
  print(f"Checkpoint file not found: {checkpoint_path}")
203
  return None
204
 
205
+ if HIVISION_MODNET_SESS is None:
206
+ HIVISION_MODNET_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
207
 
208
+ input_name = HIVISION_MODNET_SESS.get_inputs()[0].name
209
+ output_name = HIVISION_MODNET_SESS.get_outputs()[0].name
210
 
211
  im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
212
 
213
+ matte = HIVISION_MODNET_SESS.run([output_name], {input_name: im})
214
  matte = (matte[0] * 255).astype("uint8")
215
  matte = np.squeeze(matte)
216
  mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
 
222
 
223
 
224
  def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
225
+ global RMBG_SESS
226
+
227
  if not os.path.exists(checkpoint_path):
228
  print(f"Checkpoint file not found: {checkpoint_path}")
229
  return None
 
234
  image = image.resize(model_input_size, Image.BILINEAR)
235
  return image
236
 
237
+ if RMBG_SESS is None:
238
+ RMBG_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
239
 
240
  orig_image = Image.fromarray(input_image)
241
  image = resize_rmbg_image(orig_image)
 
246
  im_np = (im_np - 0.5) / 0.5 # Normalize to [-1, 1]
247
 
248
  # Inference
249
+ result = RMBG_SESS.run(None, {RMBG_SESS.get_inputs()[0].name: im_np})[0]
250
 
251
  # Post process
252
  result = np.squeeze(result)
 
304
  output_image = cv2.merge((b, g, r, mask))
305
 
306
  return output_image
307
+
308
+
309
+ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
310
+ global BIREFNET_V1_LITE_SESS
311
+
312
+ if not os.path.exists(checkpoint_path):
313
+ print(f"Checkpoint file not found: {checkpoint_path}")
314
+ return None
315
+
316
+ def transform_image(image):
317
+ image = image.resize((1024, 1024)) # Resize to 1024x1024
318
+ image = (
319
+ np.array(image, dtype=np.float32) / 255.0
320
+ ) # Convert to numpy array and normalize to [0, 1]
321
+ image = (image - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225] # Normalize
322
+ image = np.transpose(image, (2, 0, 1)) # Change from (H, W, C) to (C, H, W)
323
+ image = np.expand_dims(image, axis=0) # Add batch dimension
324
+ return image.astype(np.float32) # Ensure the output is float32
325
+
326
+ orig_image = Image.fromarray(input_image)
327
+ input_images = transform_image(
328
+ orig_image
329
+ ) # This will already have the correct shape
330
+
331
+ # 记录加载onnx模型的开始时间
332
+ load_start_time = time()
333
+
334
+ if BIREFNET_V1_LITE_SESS is None:
335
+ BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
336
+
337
+ # 记录加载onnx模型的结束时间
338
+ load_end_time = time()
339
+
340
+ # 打印加载onnx模型所花的时间
341
+ print(f"Loading ONNX model took {load_end_time - load_start_time:.4f} seconds")
342
+
343
+ input_name = BIREFNET_V1_LITE_SESS.get_inputs()[0].name
344
+ print(onnxruntime.get_device(), BIREFNET_V1_LITE_SESS.get_providers())
345
+
346
+ time_st = time()
347
+ pred_onnx = BIREFNET_V1_LITE_SESS.run(None, {input_name: input_images})[
348
+ -1
349
+ ] # Use float32 input
350
+ pred_onnx = np.squeeze(pred_onnx) # Use numpy to squeeze
351
+ result = 1 / (1 + np.exp(-pred_onnx)) # Sigmoid function using numpy
352
+ print(f"Inference time: {time() - time_st:.4f} seconds")
353
+
354
+ # Convert to PIL image
355
+ im_array = (result * 255).astype(np.uint8)
356
+ pil_im = Image.fromarray(
357
+ im_array, mode="L"
358
+ ) # Ensure mask is single channel (L mode)
359
+
360
+ # Resize the mask to match the original image size
361
+ pil_im = pil_im.resize(orig_image.size, Image.BILINEAR)
362
+
363
+ # Paste the mask on the original image
364
+ new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
365
+ new_im.paste(orig_image, mask=pil_im)
366
+
367
+ return np.array(new_im)
hivision/creator/weights/birefnet-v1-lite.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5600024376f572a557870a5eb0afb1e5961636bef4e1e22132025467d0f03333
3
+ size 224005088