TheEeeeLin committed on
Commit 1c25fe3
1 Parent(s): 23cd1cf

update new model

.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 assets/demoImage.png filter=lfs diff=lfs merge=lfs -text
+hivision/creator/weights/rmbg-1.4.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,10 +1,11 @@
 *.pyc
-**/__pycache__
+**/__pycache__/
 .idea
 .vscode/*
 .DS_Store
-app/output/*.jpg
+.env
 demo/kb_output/*.jpg
+demo/kb_output/*.png
 # build outputs
 dist
 build
@@ -12,5 +13,9 @@ build
 *.pth
 *.pt
 *.onnx
+*.mnn
 test/temp/*
-!test/temp/.gitkeep
+!test/temp/.gitkeep
+!hivision/creator/weights/rmbg-1.4.onnx
+
+.python-version
app.py CHANGED
@@ -7,8 +7,7 @@ from hivision.creator.layout_calculator import (
     generate_layout_photo,
     generate_layout_image,
 )
-from hivision.creator.human_matting import *
-from hivision.creator.face_detector import *
+from hivision.creator.choose_handler import choose_handler
 import pathlib
 import numpy as np
 from demo.utils import csv_to_size_list
@@ -150,15 +149,7 @@ def idphoto_inference(
         idphoto_json["custom_image_kb"] = None
 
     creator = IDCreator()
-    if matting_model_option == "modnet_photographic_portrait_matting":
-        creator.matting_handler = extract_human_modnet_photographic_portrait_matting
-    else:
-        creator.matting_handler = extract_human
-
-    if face_detect_option == "mtcnn":
-        creator.detection_handler = detect_face_mtcnn
-    else:
-        creator.detection_handler = detect_face_face_plusplus
+    choose_handler(creator, matting_model_option, face_detect_option)
 
     change_bg_only = idphoto_json["size_mode"] in ["只换底", "Only Change Background"]
     # Generate the ID photo
@@ -294,28 +285,28 @@ def idphoto_inference(
 
 
 if __name__ == "__main__":
-    # argparser = argparse.ArgumentParser()
-    # argparser.add_argument(
-    #     "--port", type=int, default=7860, help="The port number of the server"
-    # )
-    # argparser.add_argument(
-    #     "--host", type=str, default="127.0.0.1", help="The host of the server"
-    # )
-    # argparser.add_argument(
-    #     "--root_path",
-    #     type=str,
-    #     default=None,
-    #     help="The root path of the server, default is None (='/'), e.g. '/myapp'",
-    # )
-
-    # args = argparser.parse_args()
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument(
+        "--port", type=int, default=7860, help="The port number of the server"
+    )
+    argparser.add_argument(
+        "--host", type=str, default="127.0.0.1", help="The host of the server"
+    )
+    argparser.add_argument(
+        "--root_path",
+        type=str,
+        default=None,
+        help="The root path of the server, default is None (='/'), e.g. '/myapp'",
+    )
+
+    args = argparser.parse_args()
 
     language = ["中文", "English"]
 
     matting_model_list = [
         os.path.splitext(file)[0]
         for file in os.listdir(os.path.join(root_dir, "hivision/creator/weights"))
-        if file.endswith(".onnx")
+        if file.endswith(".onnx") or file.endswith(".mnn")
     ]
     DEFAULT_MATTING_MODEL = "modnet_photographic_portrait_matting"
     if DEFAULT_MATTING_MODEL in matting_model_list:
@@ -366,7 +357,7 @@ if __name__ == "__main__":
             content = f.read()
         return content
 
-    demo = gr.Blocks(css=css)
+    demo = gr.Blocks(title="HivisionIDPhotos", css=css)
 
     with demo:
         gr.HTML(load_description(os.path.join(root_dir, "assets/title.md")))
@@ -669,7 +660,7 @@ if __name__ == "__main__":
     demo.launch(
         # server_name=args.host,
        # server_port=args.port,
-        # show_api=False,
+        show_api=False,
        # favicon_path=os.path.join(root_dir, "assets/hivision_logo.png"),
        # root_path=args.root_path,
    )
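Note: the commit uncomments the argparse block but still leaves the server_name/server_port/root_path keyword arguments of demo.launch commented out. A minimal sketch (not part of the commit) of how those parsed flags could be wired into Gradio once re-enabled; it assumes only that gradio and argparse are available, and uses a placeholder UI instead of the real app layout:

    import argparse
    import gradio as gr

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--port", type=int, default=7860)
    argparser.add_argument("--host", type=str, default="127.0.0.1")
    argparser.add_argument("--root_path", type=str, default=None)
    args = argparser.parse_args()

    demo = gr.Blocks(title="HivisionIDPhotos")  # css omitted in this sketch
    with demo:
        gr.Markdown("HivisionIDPhotos")  # placeholder UI for this sketch

    demo.launch(
        server_name=args.host,    # host from --host
        server_port=args.port,    # port from --port
        show_api=False,           # hide the API docs, as in the commit
        root_path=args.root_path, # optional URL prefix, e.g. "/myapp"
    )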
hivision/creator/__init__.py CHANGED
@@ -41,8 +41,7 @@ class IDCreator:
         """
         # Handlers
         self.matting_handler: ContextHandler = extract_human
-        self.detection_handler: ContextHandler = detect_face_face_plusplus
-        # self.detection_handler: ContextHandler = detect_face_mtcnn
+        self.detection_handler: ContextHandler = detect_face_mtcnn
 
         # Context
         self.ctx = None
hivision/creator/choose_handler.py ADDED
@@ -0,0 +1,18 @@
+from hivision.creator.human_matting import *
+from hivision.creator.face_detector import *
+
+
+def choose_handler(creator, matting_model_option=None, face_detect_option=None):
+    if matting_model_option == "modnet_photographic_portrait_matting":
+        creator.matting_handler = extract_human_modnet_photographic_portrait_matting
+    elif matting_model_option == "mnn_hivision_modnet":
+        creator.matting_handler = extract_human_mnn_modnet
+    elif matting_model_option == "rmbg-1.4":
+        creator.matting_handler = extract_human_rmbg
+    else:
+        creator.matting_handler = extract_human
+
+    if face_detect_option == "face_plusplus":
+        creator.detection_handler = detect_face_face_plusplus
+    else:
+        creator.detection_handler = detect_face_mtcnn
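The new helper centralizes the handler selection that was previously inlined in app.py: it swaps the matting/detection callables on an IDCreator instance. A short usage sketch (not part of the commit) that only relies on names visible in this diff; importing IDCreator from hivision.creator follows the __init__.py shown above:

    from hivision.creator import IDCreator
    from hivision.creator.choose_handler import choose_handler

    creator = IDCreator()
    print(creator.detection_handler.__name__)   # "detect_face_mtcnn" -- new default from __init__.py

    # Pick the new RMBG-1.4 matting model and Face++ detection;
    # unknown or None options fall back to extract_human and MTCNN.
    choose_handler(creator, matting_model_option="rmbg-1.4", face_detect_option="face_plusplus")
    print(creator.matting_handler.__name__)     # "extract_human_rmbg"
    print(creator.detection_handler.__name__)   # "detect_face_face_plusplus"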
hivision/creator/human_matting.py CHANGED
@@ -25,8 +25,43 @@ WEIGHTS = {
         "weights",
         "modnet_photographic_portrait_matting.onnx",
     ),
+    "mnn_hivision_modnet": os.path.join(
+        os.path.dirname(__file__),
+        "weights",
+        "mnn_hivision_modnet.mnn",
+    ),
+    "rmbg-1.4": os.path.join(os.path.dirname(__file__), "weights", "rmbg-1.4.onnx"),
 }
 
+ONNX_DEVICE = (
+    "CUDAExecutionProvider"
+    if onnxruntime.get_device() == "GPU"
+    else "CPUExecutionProvider"
+)
+
+
+def load_onnx_model(checkpoint_path):
+    providers = (
+        ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        if ONNX_DEVICE == "CUDAExecutionProvider"
+        else ["CPUExecutionProvider"]
+    )
+
+    try:
+        sess = onnxruntime.InferenceSession(checkpoint_path, providers=providers)
+    except Exception as e:
+        if ONNX_DEVICE == "CUDAExecutionProvider":
+            print(f"Failed to load model with CUDAExecutionProvider: {e}")
+            print("Falling back to CPUExecutionProvider")
+            # Try loading the model on CPU instead
+            sess = onnxruntime.InferenceSession(
+                checkpoint_path, providers=["CPUExecutionProvider"]
+            )
+        else:
+            raise e  # re-raise if loading on CPU failed as well
+
+    return sess
+
 
 def extract_human(ctx: Context):
     """
@@ -50,10 +85,24 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
         ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
     )
     # Fix the matting result
+    ctx.processing_image = matting_image
+    ctx.matting_image = ctx.processing_image.copy()
+
+
+def extract_human_mnn_modnet(ctx: Context):
+    matting_image = get_mnn_modnet_matting(
+        ctx.processing_image, WEIGHTS["mnn_hivision_modnet"]
+    )
     ctx.processing_image = hollow_out_fix(matting_image)
     ctx.matting_image = ctx.processing_image.copy()
 
 
+def extract_human_rmbg(ctx: Context):
+    matting_image = get_rmbg_matting(ctx.processing_image, WEIGHTS["rmbg-1.4"])
+    ctx.processing_image = matting_image
+    ctx.matting_image = ctx.processing_image.copy()
+
+
 def hollow_out_fix(src: np.ndarray) -> np.ndarray:
     """
     Patch hollow regions in the matte, compensating for limited matting-model accuracy
@@ -120,9 +169,11 @@ def read_modnet_image(input_image, ref_size=512):
 
 
 def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
-    # global sess
-    # if sess is None:
-    sess = onnxruntime.InferenceSession(checkpoint_path)
+    if not os.path.exists(checkpoint_path):
+        print(f"Checkpoint file not found: {checkpoint_path}")
+        return None
+
+    sess = load_onnx_model(checkpoint_path)
 
     input_name = sess.get_inputs()[0].name
     output_name = sess.get_outputs()[0].name
@@ -138,3 +189,85 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
     output_image = cv2.merge((b, g, r, mask))
 
     return output_image
+
+
+def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
+    if not os.path.exists(checkpoint_path):
+        print(f"Checkpoint file not found: {checkpoint_path}")
+        return None
+
+    def resize_rmbg_image(image):
+        image = image.convert("RGB")
+        model_input_size = (ref_size, ref_size)
+        image = image.resize(model_input_size, Image.BILINEAR)
+        return image
+
+    sess = load_onnx_model(checkpoint_path)
+
+    orig_image = Image.fromarray(input_image)
+    image = resize_rmbg_image(orig_image)
+    im_np = np.array(image).astype(np.float32)
+    im_np = im_np.transpose(2, 0, 1)  # Change to CxHxW format
+    im_np = np.expand_dims(im_np, axis=0)  # Add batch dimension
+    im_np = im_np / 255.0  # Normalize to [0, 1]
+    im_np = (im_np - 0.5) / 0.5  # Normalize to [-1, 1]
+
+    # Inference
+    result = sess.run(None, {sess.get_inputs()[0].name: im_np})[0]
+
+    # Post process
+    result = np.squeeze(result)
+    ma = np.max(result)
+    mi = np.min(result)
+    result = (result - mi) / (ma - mi)  # Normalize to [0, 1]
+
+    # Convert to PIL image
+    im_array = (result * 255).astype(np.uint8)
+    pil_im = Image.fromarray(
+        im_array, mode="L"
+    )  # Ensure mask is single channel (L mode)
+
+    # Resize the mask to match the original image size
+    pil_im = pil_im.resize(orig_image.size, Image.BILINEAR)
+
+    # Paste the mask on the original image
+    new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
+    new_im.paste(orig_image, mask=pil_im)
+
+    return np.array(new_im)
+
+
+def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
+    if not os.path.exists(checkpoint_path):
+        print(f"Checkpoint file not found: {checkpoint_path}")
+        return None
+
+    try:
+        import MNN.expr as expr
+        import MNN.nn as nn
+    except ImportError as e:
+        raise ImportError(
+            "The MNN module is not installed or there was an import error. Please ensure that the MNN library is installed by using the command 'pip install mnn'."
+        ) from e
+
+    config = {}
+    config["precision"] = "low"  # use fp16 inference when the hardware (armv8.2) supports it
+    config["backend"] = 0  # CPU
+    config["numThread"] = 4  # number of threads
+    im, width, length = read_modnet_image(input_image, ref_size=512)
+    rt = nn.create_runtime_manager((config,))
+    net = nn.load_module_from_file(
+        checkpoint_path, ["input1"], ["output1"], runtime_manager=rt
+    )
+    input_var = expr.convert(im, expr.NCHW)
+    output_var = net.forward(input_var)
+    matte = expr.convert(output_var, expr.NCHW)
+    matte = matte.read()  # convert the MNN var to a numpy array
+    matte = (matte * 255).astype("uint8")
+    matte = np.squeeze(matte)
+    mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
+    b, g, r = cv2.split(np.uint8(input_image))
+
+    output_image = cv2.merge((b, g, r, mask))
+
+    return output_image
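The new load_onnx_model helper prefers CUDAExecutionProvider when onnxruntime reports a GPU, and prints a warning and retries on CPUExecutionProvider if the CUDA session fails. A standalone sketch (not part of the commit) of the new RMBG-1.4 path; it assumes the weight file has been pulled via Git LFS, and the input image path is hypothetical:

    import cv2
    from hivision.creator.human_matting import WEIGHTS, get_rmbg_matting, load_onnx_model

    # Session creation with the CUDA-then-CPU provider fallback described above
    sess = load_onnx_model(WEIGHTS["rmbg-1.4"])
    print([i.name for i in sess.get_inputs()])           # model input names

    bgr = cv2.imread("test/temp/photo.jpg")               # hypothetical input path
    rgba = get_rmbg_matting(bgr, WEIGHTS["rmbg-1.4"])      # 4-channel result (color + alpha mask)
    if rgba is not None:
        cv2.imwrite("test/temp/photo_matted.png", rgba)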
hivision/creator/weights/rmbg-1.4.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cafcf770b06757c4eaced21b1a88e57fd2b66de01b8045f35f01535ba742e0f
+size 176153355
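The weight is committed as a Git LFS pointer (matching the new .gitattributes rule), so the repository stores only the oid and size above. A small sketch (not part of the commit) to verify a locally downloaded copy against that pointer; the local path mirrors the repository layout:

    import hashlib
    import os

    path = "hivision/creator/weights/rmbg-1.4.onnx"
    expected_oid = "8cafcf770b06757c4eaced21b1a88e57fd2b66de01b8045f35f01535ba742e0f"
    expected_size = 176153355

    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)

    print("size ok:", os.path.getsize(path) == expected_size)
    print("oid ok:", h.hexdigest() == expected_oid)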