# PatchFusion/tools/test.py
import os
import os.path as osp
import argparse
import torch
import time
from torch.utils.data import DataLoader
from mmengine.utils import mkdir_or_exist
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed
from estimator.models.builder import build_model
from estimator.datasets.builder import build_dataset
from estimator.tester import Tester
from estimator.models.patchfusion import PatchFusion
from mmengine import print_log


def parse_args():
    parser = argparse.ArgumentParser(description='Test a depth estimation model')
    parser.add_argument('config', help='test config file path')
parser.add_argument(
'--work-dir',
help='the dir to save logs and models',
default=None)
parser.add_argument(
'--test-type',
type=str,
default='normal',
help='evaluation type')
parser.add_argument(
'--ckp-path',
type=str,
        help='path to a local checkpoint (.pth) or a supported Hugging Face repo id')
parser.add_argument(
'--amp',
action='store_true',
default=False,
        help='enable automatic mixed-precision inference')
parser.add_argument(
'--save',
action='store_true',
default=False,
help='save colored prediction & depth predictions')
parser.add_argument(
'--cai-mode',
type=str,
default='m1',
        help='inference mode: m1, m2, or rn (e.g. r32, r128)')
parser.add_argument(
'--process-num',
type=int, default=4,
        help='number of patches processed per batch during inference')
parser.add_argument(
'--tag',
type=str, default='',
        help='extra tag appended to the evaluation log filename')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
        help='override some settings in the used config; the key-value pairs '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. Nested list/tuple values are also allowed, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
    # When using PyTorch >= 2.0.0, `torch.distributed.launch` passes the
    # `--local-rank` parameter to `tools/test.py` instead of `--local_rank`,
    # so both spellings are accepted here.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
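

# Example invocations (the config path and checkpoint below are illustrative
# placeholders, not guaranteed to match this repo's layout):
#   python tools/test.py <config.py> --ckp-path <checkpoint.pth> --cai-mode m1 --save
#   torchrun --nproc_per_node=4 tools/test.py <config.py> \
#       --ckp-path Zhyever/patchfusion_depth_anything_vitl14 --launcher pytorch
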
def main():
args = parse_args()
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use ckp path as default work_dir if cfg.work_dir is None
if '.pth' in args.ckp_path:
args.work_dir = osp.dirname(args.ckp_path)
else:
args.work_dir = osp.join('work_dir', args.ckp_path.split('/')[1])
cfg.work_dir = args.work_dir
mkdir_or_exist(cfg.work_dir)
cfg.ckp_path = args.ckp_path
# fix seed
seed = cfg.get('seed', 5621)
fix_random_seed(seed)
# start dist training
if cfg.launcher == 'none':
distributed = False
        # single-process run: format the local wall-clock time as the timestamp
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
rank = 0
world_size = 1
env_cfg = cfg.get('env_cfg')
else:
distributed = True
env_cfg = cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl')))
rank, world_size, timestamp = setup_env(env_cfg, distributed, cfg.launcher)
# build dataloader
if args.test_type == 'consistency':
dataset = build_dataset(cfg.val_consistency_dataloader.dataset)
elif args.test_type == 'normal':
dataset = build_dataset(cfg.val_dataloader.dataset)
elif args.test_type == 'test_in':
dataset = build_dataset(cfg.test_in_dataloader.dataset)
elif args.test_type == 'test_out':
dataset = build_dataset(cfg.test_out_dataloader.dataset)
elif args.test_type == 'general':
dataset = build_dataset(cfg.general_dataloader.dataset)
else:
dataset = build_dataset(cfg.val_dataloader.dataset)
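    # each --test-type above expects the matching dataloader section to exist in the
    # config (val_consistency_dataloader, val_dataloader, test_in_dataloader,
    # test_out_dataloader, general_dataloader); unknown values fall back to val_dataloader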
# extract experiment name from cmd
config_path = args.config
exp_cfg_filename = config_path.split('/')[-1].split('.')[0]
ckp_name = args.ckp_path.replace('/', '_').replace('.pth', '')
dataset_name = dataset.dataset_name
# log_filename = 'eval_{}_{}_{}_{}.log'.format(timestamp, exp_cfg_filename, ckp_name, dataset_name)
log_filename = 'eval_{}_{}_{}_{}_{}.log'.format(exp_cfg_filename, args.tag, ckp_name, dataset_name, timestamp)
    # prepare basic text logger; use cfg.work_dir, since args.work_dir may still be
    # None when the work dir comes from the config file
    log_file = osp.join(cfg.work_dir, log_filename)
log_cfg = dict(log_level='INFO', log_file=log_file)
log_cfg.setdefault('name', timestamp)
log_cfg.setdefault('logger_name', 'patchstitcher')
# `torch.compile` in PyTorch 2.0 could close all user defined handlers
# unexpectedly. Using file mode 'a' can help prevent abnormal
# termination of the FileHandler and ensure that the log file could
# be continuously updated during the lifespan of the runner.
log_cfg.setdefault('file_mode', 'a')
logger = MMLogger.get_instance(**log_cfg)
# save some information useful during the training
runner_info = RunnerInfo()
    runner_info.config = cfg  # ideally, cfg should not be modified during the run; transient information should be stored in runner_info instead
    runner_info.logger = logger  # shortcut: print_log("info", logger='current') can be used anywhere
runner_info.rank = rank
runner_info.distributed = distributed
runner_info.launcher = cfg.launcher
runner_info.seed = seed
runner_info.world_size = world_size
runner_info.work_dir = cfg.work_dir
runner_info.timestamp = timestamp
runner_info.save = args.save
runner_info.log_filename = log_filename
    if runner_info.save:
        # use cfg.work_dir (always set above); args.work_dir may be None when the
        # work dir is taken from the config file
        mkdir_or_exist(cfg.work_dir)
        runner_info.work_dir = cfg.work_dir
# log_env(cfg, env_cfg, runner_info, logger)
# build model
if '.pth' in cfg.ckp_path:
model = build_model(cfg.model)
print_log('Checkpoint Path: {}. Loading from a local file'.format(cfg.ckp_path), logger='current')
if hasattr(model, 'load_dict'):
print_log(model.load_dict(torch.load(cfg.ckp_path)['model_state_dict']), logger='current')
else:
print_log(model.load_state_dict(torch.load(cfg.ckp_path)['model_state_dict'], strict=True), logger='current')
else:
print_log('Checkpoint Path: {}. Loading from the huggingface repo'.format(cfg.ckp_path), logger='current')
assert cfg.ckp_path in \
['Zhyever/patchfusion_depth_anything_vits14',
'Zhyever/patchfusion_depth_anything_vitb14',
'Zhyever/patchfusion_depth_anything_vitl14',
'Zhyever/patchfusion_zoedepth'], 'Invalid model name'
model = PatchFusion.from_pretrained(cfg.ckp_path)
model.eval()
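    # cfg.ckp_path therefore accepts either a local *.pth file (loaded via
    # load_dict / load_state_dict above) or one of the Hugging Face repo ids
    # listed in the assert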
if runner_info.distributed:
torch.cuda.set_device(runner_info.rank)
model.cuda(runner_info.rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[runner_info.rank], output_device=runner_info.rank,
find_unused_parameters=cfg.get('find_unused_parameters', False))
logger.info(model)
else:
model.cuda()
if runner_info.distributed:
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
else:
val_sampler = None
val_dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=cfg.val_dataloader.num_workers,
pin_memory=True,
persistent_workers=True,
sampler=val_sampler)
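    # batch_size is fixed to 1 for evaluation; note that persistent_workers=True
    # requires cfg.val_dataloader.num_workers > 0, otherwise PyTorch raises a ValueError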
# build tester
tester = Tester(
config=cfg,
runner_info=runner_info,
dataloader=val_dataloader,
model=model)
if args.test_type == 'consistency':
tester.run_consistency()
else:
tester.run(args.cai_mode, process_num=args.process_num)


if __name__ == '__main__':
main()