CapeX / test.py
matanru's picture
initial commit
93b49a4
raw
history blame
5.46 kB
import argparse
import os
import os.path as osp
import random
import uuid
import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from models import * # noqa
from models.datasets import build_dataset
from mmpose.apis import multi_gpu_test, single_gpu_test
from mmpose.core import wrap_fp16_model
from mmpose.datasets import build_dataloader
from mmpose.models import build_posenet
def parse_args():
    """Parse the command line of the test script.

    Two positional arguments are expected (config file, checkpoint file);
    everything else is optional: result dumping, conv/bn fusion, evaluation
    metrics, keypoint permutation, GPU result collection, a temp dir,
    config overrides, the job launcher and the local rank.

    Side effect: if ``LOCAL_RANK`` is not already present in ``os.environ``
    it is set to the string form of ``--local_rank``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    cli = argparse.ArgumentParser(description='mmpose test model')

    # Positional arguments.
    cli.add_argument('config', default=None, help='test config file path')
    cli.add_argument('checkpoint', default=None, help='checkpoint file')

    # Optional arguments; registration order is kept because it determines
    # the order of entries in the generated ``--help`` text.
    cli.add_argument('--out', help='output result file')

    fuse_help = ('Whether to fuse conv and bn, this will slightly increase '
                 'the inference speed')
    cli.add_argument('--fuse-conv-bn', action='store_true', help=fuse_help)

    eval_help = ('evaluation metric, which depends on the dataset,'
                 ' e.g., "mAP" for MSCOCO')
    cli.add_argument('--eval', default=None, nargs='+', help=eval_help)

    cli.add_argument('--permute_keypoints',
                     action='store_true',
                     help='whether to randomly permute keypoints')
    cli.add_argument('--gpu_collect',
                     action='store_true',
                     help='whether to use gpu to collect results')
    cli.add_argument('--tmpdir', help='tmp dir for writing some results')

    override_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    cli.add_argument('--cfg-options',
                     nargs='+',
                     action=DictAction,
                     default={},
                     help=override_help)

    cli.add_argument('--launcher',
                     choices=['none', 'pytorch', 'slurm', 'mpi'],
                     default='none',
                     help='job launcher')
    cli.add_argument('--local_rank', type=int, default=0)

    parsed = cli.parse_args()

    # Publish the local rank through the environment unless a value is
    # already there; an existing value always wins.
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
def merge_configs(cfg1, cfg2):
    """Merge ``cfg2`` into a copy of ``cfg1``.

    Keys present in both take the value from ``cfg2``. Entries of ``cfg2``
    whose value is ``None`` are treated as "not provided" and are ignored.

    Args:
        cfg1: Base mapping, or ``None`` (treated as empty). It is copied
            with its own ``copy()`` method and is not modified.
        cfg2: Mapping of overrides, or ``None`` (treated as empty).

    Returns:
        The copy of ``cfg1`` (a new ``dict`` when ``cfg1`` is ``None``) with
        the non-``None`` entries of ``cfg2`` applied.
    """
    merged = {} if cfg1 is None else cfg1.copy()
    if cfg2 is None:
        return merged

    for key, value in cfg2.items():
        # Only ``None`` means "no override". A plain truthiness test would
        # also drop legitimate falsy overrides such as 0, False or ''.
        if value is None:
            continue
        # Assign per key (not ``update``) so the value replaces the old one
        # as-is, whatever mapping type ``merged`` happens to be.
        merged[key] = value
    return merged
def main():
    """Run a checkpoint over the test split and report its metrics.

    Steps, in order:

    1. Seed the Python, NumPy and torch RNGs with 0 and parse the CLI.
    2. Load the config, apply ``--cfg-options`` overrides, force
       ``cfg.data.test.test_mode = True`` and create the work dir
       ``./work_dirs/<config file stem>``.
    3. Initialise the distributed environment unless ``--launcher none``.
    4. Build the test dataset/dataloader and the model, load the checkpoint
       (optionally wrapping for fp16 and fusing conv/bn).
    5. Run single-GPU or multi-GPU inference.
    6. On rank 0 only: optionally dump the raw outputs to ``--out``,
       evaluate them, print the metrics and append them to
       ``./work_dirs/testing_log.txt``.

    NOTE(review): ``--permute_keypoints`` is parsed but not read anywhere in
    this function — confirm whether it is meant to be consumed here.
    """
    # Seed every RNG that has a global seed. ``uuid`` has none, and merely
    # constructing a ``uuid.UUID`` object has no effect, so nothing is done
    # for it here.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    config_stem = osp.splitext(osp.basename(args.config))[0]
    args.work_dir = osp.join('./work_dirs', config_stem)
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader; entries of cfg.data.test_dataloader (if any)
    # override the defaults below.
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    dataloader_setting = dict(
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.get('workers_per_gpu', 12),
        dist=distributed,
        shuffle=False,
        drop_last=False)
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('test_dataloader', {}))
    data_loader = build_dataloader(dataset, **dataloader_setting)

    # build the model and load checkpoint
    model = build_posenet(cfg.model)
    if cfg.get('fp16', None) is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    if distributed:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)
    else:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)

    rank, _ = get_dist_info()
    # A metric given via --eval takes precedence over the config's
    # ``evaluation`` section; without --eval the config value is kept.
    eval_config = merge_configs(cfg.get('evaluation', {}),
                                dict(metric=args.eval))

    # Dumping, evaluation and logging happen on rank 0 only.
    if rank != 0:
        return

    if args.out:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(outputs, args.out)

    results = dataset.evaluate(outputs, **eval_config)
    metrics = sorted(results.items())

    print('\n')
    for name, value in metrics:
        print(f'{name}: {value}')

    # save testing log; append mode keeps the records of earlier runs.
    test_log = "./work_dirs/testing_log.txt"
    record = ["** config_file: " + args.config + "\t checkpoint: "
              + args.checkpoint + "\t \n"]
    record.extend(f'\t {name}: {value}' + '\n' for name, value in metrics)
    record.append("********************************************************************\n")
    # Explicit encoding so the log does not depend on the platform locale.
    with open(test_log, 'a', encoding='utf-8') as f:
        f.write(''.join(record))
# Script entry point: run the test pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()