TheEeeeLin commited on
Commit
1d213d9
·
1 Parent(s): be33c96

update app.py

Browse files
app.py CHANGED
@@ -15,10 +15,17 @@ HUMAN_MATTING_MODELS_EXIST = [
15
  if file.endswith(".onnx") or file.endswith(".mnn")
16
  ]
17
  # 在HUMAN_MATTING_MODELS中的模型才会被加载到Gradio中显示
18
- HUMAN_MATTING_MODELS = [
19
  model for model in HUMAN_MATTING_MODELS if model in HUMAN_MATTING_MODELS_EXIST
20
  ]
21
 
 
 
 
 
 
 
 
22
  FACE_DETECT_MODELS = ["face++ (联网Online API)", "mtcnn"]
23
  FACE_DETECT_MODELS_EXPAND = (
24
  ["retinaface-resnet50"]
@@ -29,7 +36,7 @@ FACE_DETECT_MODELS_EXPAND = (
29
  )
30
  else []
31
  )
32
- FACE_DETECT_MODELS += FACE_DETECT_MODELS_EXPAND
33
 
34
  LANGUAGE = ["zh", "en", "ko", "ja"]
35
 
@@ -54,8 +61,8 @@ if __name__ == "__main__":
54
  demo = create_ui(
55
  processor,
56
  root_dir,
57
- HUMAN_MATTING_MODELS_EXIST,
58
- FACE_DETECT_MODELS,
59
  LANGUAGE,
60
  )
61
  demo.launch(
 
15
  if file.endswith(".onnx") or file.endswith(".mnn")
16
  ]
17
  # 在HUMAN_MATTING_MODELS中的模型才会被加载到Gradio中显示
18
+ HUMAN_MATTING_MODELS_CHOICE = [
19
  model for model in HUMAN_MATTING_MODELS if model in HUMAN_MATTING_MODELS_EXIST
20
  ]
21
 
22
+ if len(HUMAN_MATTING_MODELS_CHOICE) == 0:
23
+ raise ValueError(
24
+ "未找到任何存在的人像分割模型,请检查 hivision/creator/weights 目录下的文件"
25
+ + "\n"
26
+ + "No existing portrait segmentation model was found, please check the files in the hivision/creator/weights directory."
27
+ )
28
+
29
  FACE_DETECT_MODELS = ["face++ (联网Online API)", "mtcnn"]
30
  FACE_DETECT_MODELS_EXPAND = (
31
  ["retinaface-resnet50"]
 
36
  )
37
  else []
38
  )
39
+ FACE_DETECT_MODELS_CHOICE = FACE_DETECT_MODELS + FACE_DETECT_MODELS_EXPAND
40
 
41
  LANGUAGE = ["zh", "en", "ko", "ja"]
42
 
 
61
  demo = create_ui(
62
  processor,
63
  root_dir,
64
+ HUMAN_MATTING_MODELS_CHOICE,
65
+ FACE_DETECT_MODELS_CHOICE,
66
  LANGUAGE,
67
  )
68
  demo.launch(
demo/{locals.py → locales.py} RENAMED
File without changes
demo/processor.py CHANGED
@@ -16,7 +16,7 @@ from demo.utils import range_check
16
  import gradio as gr
17
  import os
18
  import time
19
- from demo.locals import LOCALES
20
 
21
 
22
  class IDPhotoProcessor:
@@ -261,7 +261,7 @@ class IDPhotoProcessor:
261
  )
262
 
263
  # 生成排版照片
264
- result_layout_image = self._generate_layout_image(
265
  idphoto_json,
266
  result_image_standard,
267
  language,
@@ -289,7 +289,10 @@ class IDPhotoProcessor:
289
 
290
  # 调整图片大小
291
  output_image_path = self._resize_image_if_needed(
292
- result_image_standard, idphoto_json
 
 
 
293
  )
294
 
295
  return self._create_response(
@@ -297,7 +300,7 @@ class IDPhotoProcessor:
297
  result_image_hd,
298
  result_image_standard_png,
299
  result_image_hd_png,
300
- result_layout_image,
301
  output_image_path,
302
  )
303
 
@@ -319,7 +322,7 @@ class IDPhotoProcessor:
319
 
320
  return result_image_standard, result_image_hd
321
 
322
- def _generate_layout_image(
323
  self,
324
  idphoto_json,
325
  result_image_standard,
@@ -353,14 +356,16 @@ class IDPhotoProcessor:
353
  color=watermark_text_color,
354
  )
355
 
356
- return gr.update(
357
- value=generate_layout_image(
358
- image,
359
- typography_arr,
360
- typography_rotate,
361
- height=idphoto_json["size"][0],
362
- width=idphoto_json["size"][1],
363
- ),
 
 
364
  visible=True,
365
  )
366
 
@@ -390,31 +395,79 @@ class IDPhotoProcessor:
390
  result_image_hd = add_watermark(image=result_image_hd, **watermark_params)
391
  return result_image_standard, result_image_hd
392
 
393
- def _resize_image_if_needed(self, result_image_standard, idphoto_json):
 
 
 
 
 
 
394
  """如果需要,调整图片大小"""
395
- output_image_path = f"{os.path.join(os.path.dirname(os.path.dirname(__file__)), 'demo/kb_output')}/{int(time.time())}.jpg"
396
- # 如果设置了自定义KB大小
397
- if idphoto_json["custom_image_kb"]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  resize_image_to_kb(
399
  result_image_standard,
400
- output_image_path,
401
- idphoto_json["custom_image_kb"],
402
- dpi=(
403
- idphoto_json["custom_image_dpi"]
404
- if idphoto_json["custom_image_dpi"]
405
- else 300
406
- ),
407
  )
408
- return output_image_path
409
- # 如果只设置了dpi
410
- elif idphoto_json["custom_image_dpi"]:
411
  save_image_dpi_to_bytes(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  result_image_standard,
413
- output_image_path,
414
- dpi=idphoto_json["custom_image_dpi"],
 
415
  )
416
- return output_image_path
417
 
 
 
 
418
  return None
419
 
420
  def _create_response(
 
16
  import gradio as gr
17
  import os
18
  import time
19
+ from demo.locales import LOCALES
20
 
21
 
22
  class IDPhotoProcessor:
 
261
  )
262
 
263
  # 生成排版照片
264
+ result_image_layout, result_image_layout_gr = self._generate_image_layout(
265
  idphoto_json,
266
  result_image_standard,
267
  language,
 
289
 
290
  # 调整图片大小
291
  output_image_path = self._resize_image_if_needed(
292
+ result_image_standard,
293
+ result_image_hd,
294
+ result_image_layout,
295
+ idphoto_json,
296
  )
297
 
298
  return self._create_response(
 
300
  result_image_hd,
301
  result_image_standard_png,
302
  result_image_hd_png,
303
+ result_image_layout_gr,
304
  output_image_path,
305
  )
306
 
 
322
 
323
  return result_image_standard, result_image_hd
324
 
325
+ def _generate_image_layout(
326
  self,
327
  idphoto_json,
328
  result_image_standard,
 
356
  color=watermark_text_color,
357
  )
358
 
359
+ result_image_layout = generate_layout_image(
360
+ image,
361
+ typography_arr,
362
+ typography_rotate,
363
+ height=idphoto_json["size"][0],
364
+ width=idphoto_json["size"][1],
365
+ )
366
+
367
+ return result_image_layout, gr.update(
368
+ value=result_image_layout,
369
  visible=True,
370
  )
371
 
 
395
  result_image_hd = add_watermark(image=result_image_hd, **watermark_params)
396
  return result_image_standard, result_image_hd
397
 
398
+ def _resize_image_if_needed(
399
+ self,
400
+ result_image_standard,
401
+ result_image_hd,
402
+ result_image_layout,
403
+ idphoto_json,
404
+ ):
405
  """如果需要,调整图片大小"""
406
+ # 设置输出路径
407
+ base_path = os.path.join(
408
+ os.path.dirname(os.path.dirname(__file__)), "demo/kb_output"
409
+ )
410
+ timestamp = int(time.time())
411
+ output_paths = {
412
+ "standard": f"{base_path}/{timestamp}_standard",
413
+ "hd": f"{base_path}/{timestamp}_hd",
414
+ "layout": f"{base_path}/{timestamp}_layout",
415
+ }
416
+
417
+ # 获取自定义的KB和DPI值
418
+ custom_kb = idphoto_json.get("custom_image_kb")
419
+ custom_dpi = idphoto_json.get("custom_image_dpi", 300)
420
+
421
+ # 处理同时有自定义KB和DPI的情况
422
+ if custom_kb and custom_dpi:
423
+ # 为所有输出路径添加DPI信息
424
+ for key in output_paths:
425
+ output_paths[key] += f"_{custom_dpi}dpi.jpg"
426
+ # 为标准图像添加KB信息
427
+ output_paths["standard"] = output_paths["standard"].replace(
428
+ ".jpg", f"_{custom_kb}kb.jpg"
429
+ )
430
+
431
+ # 调整标准图像大小并保存
432
  resize_image_to_kb(
433
  result_image_standard,
434
+ output_paths["standard"],
435
+ custom_kb,
436
+ dpi=custom_dpi,
 
 
 
 
437
  )
438
+ # 保存高清图像和排版图像
439
+ save_image_dpi_to_bytes(result_image_hd, output_paths["hd"], dpi=custom_dpi)
 
440
  save_image_dpi_to_bytes(
441
+ result_image_layout, output_paths["layout"], dpi=custom_dpi
442
+ )
443
+
444
+ return list(output_paths.values())
445
+
446
+ # 只有自定义DPI的情况
447
+ elif custom_dpi:
448
+ for key in output_paths:
449
+ output_paths[key] += f"_{custom_dpi}dpi.jpg"
450
+ # 保存所有图像,使用自定义DPI
451
+ save_image_dpi_to_bytes(
452
+ locals()[f"result_image_{key}"], output_paths[key], dpi=custom_dpi
453
+ )
454
+
455
+ return list(output_paths.values())
456
+
457
+ # 只有自定义KB的情况
458
+ elif custom_kb:
459
+ output_paths["standard"] += f"_{custom_kb}kb.jpg"
460
+ # 只调整标准图像大小并保存
461
+ resize_image_to_kb(
462
  result_image_standard,
463
+ output_paths["standard"],
464
+ custom_kb,
465
+ dpi=300,
466
  )
 
467
 
468
+ return [output_paths["standard"]]
469
+
470
+ # 如果没有自定义设置,返回None
471
  return None
472
 
473
  def _create_response(
demo/ui.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import os
3
  import pathlib
4
- from demo.locals import LOCALES
5
  from demo.processor import IDPhotoProcessor
6
 
7
  """
@@ -23,7 +23,7 @@ def create_ui(
23
  face_detect_models: list,
24
  language: list,
25
  ):
26
- DEFAULT_LANG = language[0]
27
  DEFAULT_HUMAN_MATTING_MODEL = "modnet_photographic_portrait_matting"
28
  DEFAULT_FACE_DETECT_MODEL = "retinaface-resnet50"
29
 
 
1
  import gradio as gr
2
  import os
3
  import pathlib
4
+ from demo.locales import LOCALES
5
  from demo.processor import IDPhotoProcessor
6
 
7
  """
 
23
  face_detect_models: list,
24
  language: list,
25
  ):
26
+ DEFAULT_LANG = "en"
27
  DEFAULT_HUMAN_MATTING_MODEL = "modnet_photographic_portrait_matting"
28
  DEFAULT_FACE_DETECT_MODEL = "retinaface-resnet50"
29
 
demo/utils.py CHANGED
@@ -1,7 +1,4 @@
1
  import csv
2
- import numpy as np
3
- from PIL import Image
4
- from hivision.plugin.watermark import Watermarker, WatermarkerStyles
5
 
6
 
7
  def csv_to_size_list(csv_file: str) -> dict:
 
1
  import csv
 
 
 
2
 
3
 
4
  def csv_to_size_list(csv_file: str) -> dict:
hivision/creator/human_matting.py CHANGED
@@ -37,10 +37,9 @@ WEIGHTS = {
37
  ),
38
  }
39
 
40
- ONNX_DEVICE = (
41
- "CUDAExecutionProvider"
42
- if onnxruntime.get_device() == "GPU"
43
- else "CPUExecutionProvider"
44
  )
45
 
46
  HIVISION_MODNET_SESS = None
@@ -52,7 +51,7 @@ BIREFNET_V1_LITE_SESS = None
52
  def load_onnx_model(checkpoint_path, set_cpu=False):
53
  providers = (
54
  ["CUDAExecutionProvider", "CPUExecutionProvider"]
55
- if ONNX_DEVICE == "CUDAExecutionProvider"
56
  else ["CPUExecutionProvider"]
57
  )
58
 
@@ -365,7 +364,17 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
365
 
366
  if BIREFNET_V1_LITE_SESS is None:
367
  print("首次加载birefnet-v1-lite模型...")
368
- BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
 
 
 
 
 
 
 
 
 
 
369
 
370
  # 记录加载onnx模型的结束时间
371
  load_end_time = time()
 
37
  ),
38
  }
39
 
40
+ ONNX_DEVICE = onnxruntime.get_device()
41
+ ONNX_PROVIDER = (
42
+ "CUDAExecutionProvider" if ONNX_DEVICE == "GPU" else "CPUExecutionProvider"
 
43
  )
44
 
45
  HIVISION_MODNET_SESS = None
 
51
  def load_onnx_model(checkpoint_path, set_cpu=False):
52
  providers = (
53
  ["CUDAExecutionProvider", "CPUExecutionProvider"]
54
+ if ONNX_PROVIDER == "CUDAExecutionProvider"
55
  else ["CPUExecutionProvider"]
56
  )
57
 
 
364
 
365
  if BIREFNET_V1_LITE_SESS is None:
366
  print("首次加载birefnet-v1-lite模型...")
367
+ if ONNX_DEVICE == "GPU":
368
+ print("onnxruntime-gpu已安装,尝试使用CUDA加载模型")
369
+ try:
370
+ import torch
371
+ except ImportError:
372
+ print(
373
+ "torch未安装,尝试直接使用onnxruntime-gpu加载模型,这需要配置好CUDA和cuDNN"
374
+ )
375
+ BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
376
+ else:
377
+ BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
378
 
379
  # 记录加载onnx模型的结束时间
380
  load_end_time = time()