File size: 3,823 Bytes
186701e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) OpenMMLab. All rights reserved.
"""Deploy demo for mmdeploy.

This script helps users run the mmdeploy demo after converting the
checkpoint to a backend model.

Usage:
    python deploy_demo.py img \
                          config \
                          checkpoint \
                          [--deploy-cfg DEPLOY_CFG] \
                          [--device DEVICE] \
                          [--out-dir OUT_DIR] \
                          [--show] \
                          [--score-thr SCORE_THR]

Example:
    python deploy_demo.py \
        ${MMYOLO_PATH}/data/cat/images \
        ./yolov5_s-v61_syncbn_fast_1xb32-100e_cat.py \
        ./end2end.engine \
        --deploy-cfg ./detection_tensorrt-fp16_dynamic-192x192-960x960.py \
        --out-dir ${MMYOLO_PATH}/work_dirs/deploy_predict_out \
        --device cuda:0 \
        --score-thr 0.5
"""
import argparse
import os

import torch
from mmengine import ProgressBar

from mmyolo.utils.misc import get_file_list

try:
    from mmdeploy.apis.utils import build_task_processor
    from mmdeploy.utils import get_input_shape, load_config
except ImportError:
    raise ImportError(
        'mmdeploy is not installed, please see '
        'https://mmdeploy.readthedocs.io/en/1.x/01-how-to-build/build_from_source.html'  # noqa
    )


def parse_args():
    """Parse the command-line options of the deploy demo.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``img``,
        ``config``, ``checkpoint``, ``deploy_cfg``, ``device``,
        ``out_dir``, ``show`` and ``score_thr``.
    """
    cli = argparse.ArgumentParser(description='For mmdeploy predict')

    # Positional arguments, registered in the order they must be given.
    positionals = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'model config root'),
        ('checkpoint', 'checkpoint backend model path'),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)

    # Optional flags.
    cli.add_argument('--deploy-cfg', help='deploy config path')
    cli.add_argument(
        '--device', help='device used for conversion', default='cuda:0')
    cli.add_argument(
        '--out-dir', help='Path to output file', default='./output')
    cli.add_argument(
        '--show', help='Show the detection results', action='store_true')
    cli.add_argument(
        '--score-thr', help='Bbox score threshold', default=0.3, type=float)

    return cli.parse_args()


# TODO Still need to refactor to not building dataset.
def main():
    """Run a converted backend model on the input images and visualize.

    Reads the command-line options, loads the deploy and model configs,
    builds the mmdeploy task processor and the backend model, then for
    every input file: runs inference, keeps only the predictions whose
    score is above ``--score-thr`` and either shows the result
    (``--show``) or writes the visualization into ``--out-dir``.
    """
    args = parse_args()

    # The output directory is only needed when results are saved.
    # ``makedirs`` also creates missing parent directories (e.g.
    # ``work_dirs/deploy_predict_out``) and tolerates an existing one.
    if not args.show:
        os.makedirs(args.out_dir, exist_ok=True)

    # read deploy_cfg and config
    deploy_cfg, model_cfg = load_config(args.deploy_cfg, args.config)

    # build task and backend model
    task_processor = build_task_processor(model_cfg, deploy_cfg, args.device)
    model = task_processor.build_backend_model([args.checkpoint])

    # get model input shape
    input_shape = get_input_shape(deploy_cfg)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for file in files:
        # process input image
        model_inputs, _ = task_processor.create_input(file, input_shape)

        # do model inference
        with torch.no_grad():
            result = model.test_step(model_inputs)

        if source_type['is_dir']:
            # Flatten the path relative to the input directory into a
            # single file name. ``os.sep`` is used because ``relpath``
            # returns the platform's separator, which is not '/' on
            # Windows.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # filter score
        result = result[0]
        result.pred_instances = result.pred_instances[
            result.pred_instances.scores > args.score_thr]

        # visualize results
        task_processor.visualize(
            image=file,
            model=model,
            result=result,
            show_result=args.show,
            window_name=os.path.basename(filename),
            output_file=out_file)

        progress_bar.update()

    print('All done!')


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()