Spaces:
Running
on
T4
Running
on
T4
File size: 5,401 Bytes
186701e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import argparse
import os
import sys
import warnings
from io import BytesIO
from pathlib import Path
import onnx
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist
# Add MMYOLO ROOT to sys.path
sys.path.append(str(Path(__file__).resolve().parents[3]))
from projects.easydeploy.model import DeployModel, MMYOLOBackend # noqa E402
# Silence the warning categories below for the whole process. The tuple order
# is the registration order, so the resulting filter list is unchanged.
_IGNORED_WARNING_CATEGORIES = (
    torch.jit.TracerWarning,
    torch.jit.ScriptWarning,
    UserWarning,
    FutureWarning,
    ResourceWarning,
)
for _ignored_category in _IGNORED_WARNING_CATEGORIES:
    warnings.filterwarnings(action='ignore', category=_ignored_category)
def parse_args():
    """Parse the command-line options of the ONNX export script.

    Returns:
        argparse.Namespace: Parsed options. ``img_size`` is always a list of
        exactly two ints: a single value given on the command line is
        duplicated so that height and width are equal.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including when
            ``--img-size`` receives more than two values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--model-only', action='store_true', help='Export model only')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend',
        type=str,
        default='onnxruntime',
        help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args()

    num_sizes = len(args.img_size)
    if num_sizes == 1:
        # A single value means a square input: reuse it for both dimensions.
        args.img_size = args.img_size * 2
    elif num_sizes != 2:
        # ``nargs='+'`` accepts any count; reject it here instead of letting
        # a tensor of the wrong rank reach the model later.
        parser.error('--img-size expects 1 or 2 integers (height [width]), '
                     f'got {num_sizes}')
    return args
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create a detector with ``init_detector`` and put it in eval mode.

    Args:
        config_path: Config file, forwarded to ``init_detector``.
        checkpoint_path: Checkpoint file, forwarded to ``init_detector``.
        device: Forwarded to ``init_detector`` as the ``device`` keyword.

    Returns:
        The object returned by ``init_detector``, after ``eval()`` has been
        called on it.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    detector.eval()
    return detector
def main():
    """Export the configured detector to an ONNX file in ``--work-dir``.

    Steps: parse the command line, build the detector, wrap it in
    ``DeployModel``, trace it with a random input through
    ``torch.onnx.export``, validate the graph with ``onnx.checker``,
    optionally simplify it with onnx-sim, and save it.

    The output file is named after the checkpoint, with its extension
    replaced by ``.onnx``.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)

    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Any other backend gets the bare model without postprocessing.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(
            pre_top_k=args.pre_topk,
            keep_top_k=args.keep_topk,
            iou_threshold=args.iou_threshold,
            score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']

    baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
    deploy_model = DeployModel(
        baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    # Swap only the real file extension. A plain ``replace('pth', 'onnx')``
    # would also rewrite 'pth' inside the stem (e.g. 'depth.pth') and would
    # leave checkpoints with another extension (e.g. '.pt') unchanged.
    save_onnx_path = os.path.join(
        args.work_dir,
        Path(args.checkpoint).with_suffix('.onnx').name)

    # export onnx
    with BytesIO() as f:
        torch.onnx.export(
            deploy_model,
            fake_input,
            f,
            input_names=['images'],
            output_names=output_names,
            opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)

    # Fix tensorrt onnx output shape, just for view
    if not args.model_only and backend in (MMYOLOBackend.TENSORRT8,
                                           MMYOLOBackend.TENSORRT7):
        # Flattened dims of the four outputs, in graph order:
        # num_dets (2 dims), boxes (3 dims), scores (2 dims), labels (2 dims).
        shapes = iter([
            args.batch_size, 1, args.batch_size, args.keep_topk, 4,
            args.batch_size, args.keep_topk, args.batch_size,
            args.keep_topk
        ])
        for output in onnx_model.graph.output:
            for dim in output.type.tensor_type.shape.dim:
                dim.dim_param = str(next(shapes))

    if args.simplify:
        # Best effort: on any failure the unsimplified model is saved.
        try:
            import onnxsim
            simplified_model, check = onnxsim.simplify(onnx_model)
            if not check:
                # Explicit raise instead of ``assert`` so the check still
                # runs under ``python -O``.
                raise RuntimeError('assert check failed')
            # Adopt the simplified graph only after it passed validation.
            onnx_model = simplified_model
        except Exception as e:
            print_log(f'Simplify failure: {e}')

    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|