import argparse import yaml import gdown import os def load_config(path='configs/model.yaml'): with open(path, 'r', encoding='utf-8') as f: return yaml.load(f, Loader=yaml.FullLoader) def str2bool(v): return v.lower() in ("true", "yes", "t", "y", "1") def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=False) parser.add_argument("--use_xpu", type=str2bool, default=False) parser.add_argument("--use_npu", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--min_subgraph_size", type=int, default=15) parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--gpu_mem", type=int, default=500) parser.add_argument("--gpu_id", type=int, default=0) # params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--page_num", type=int, default=0) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_type", type=str, default='max') parser.add_argument("--det_box_type", type=str, default='quad') # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_box_thresh", type=float, default=0.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5) parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) # PSE parmas parser.add_argument("--det_pse_thresh", type=float, default=0) parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) parser.add_argument("--det_pse_min_area", type=float, default=16) parser.add_argument("--det_pse_scale", type=int, default=1) # FCE parmas parser.add_argument("--scales", type=list, default=[8, 16, 32]) parser.add_argument("--alpha", type=float, default=1.0) parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--fourier_degree", type=int, default=5) # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_image_inverse", type=str2bool, default=True) parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument( "--rec_char_dict_path", type=str, default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=str2bool, default=True) parser.add_argument( "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") parser.add_argument("--drop_score", type=float, default=0.5) # params for e2e parser.add_argument("--e2e_algorithm", type=str, default='PGNet') parser.add_argument("--e2e_model_dir", type=str) parser.add_argument("--e2e_limit_side_len", type=float, default=768) parser.add_argument("--e2e_limit_type", type=str, default='max') # PGNet parmas parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) parser.add_argument( "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') # params for text classifier parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--cls_model_dir", type=str) parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--cls_batch_num", type=int, default=6) parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--warmup", type=str2bool, default=False) # SR parmas parser.add_argument("--sr_model_dir", type=str) parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128") parser.add_argument("--sr_batch_num", type=int, default=1) # parser.add_argument( "--draw_img_save_dir", type=str, default="./inference_results") parser.add_argument("--save_crop_res", type=str2bool, default=False) parser.add_argument("--crop_res_save_dir", type=str, default="./output") # multi-process parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) parser.add_argument("--benchmark", type=str2bool, default=False) parser.add_argument("--save_log_path", type=str, default="./log_output/") parser.add_argument("--show_log", type=str2bool, default=True) parser.add_argument("--use_onnx", type=str2bool, default=False) return parser def get_args(model_params): print(model_params) args, _ = init_args().parse_known_args() for key, val in model_params.items(): setattr(args, key, val) return args def download_model(filename, url, save_path='weights'): gdown.download(url=url, output=os.path.join(save_path, filename), quiet=False, fuzzy=True, use_cookies=False)