Spaces:
Runtime error
Runtime error
import os | |
import os.path as osp | |
import argparse | |
import torch | |
import time | |
from torch.utils.data import DataLoader | |
from mmengine.utils import mkdir_or_exist | |
from mmengine.config import Config, DictAction | |
from mmengine.logging import MMLogger | |
from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed | |
from estimator.models.builder import build_model | |
from estimator.datasets.builder import build_dataset | |
from estimator.tester import Tester | |
from estimator.models.patchfusion import PatchFusion | |
from mmengine import print_log | |
def parse_args():
    """Parse the command-line arguments of the evaluation script.

    As a side effect, ``LOCAL_RANK`` is exported into ``os.environ`` (taken
    from ``--local_rank`` / ``--local-rank``) when it is not already set.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Evaluate a model checkpoint')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--work-dir',
        help='the dir to save logs and models',
        default=None)
    parser.add_argument(
        '--test-type',
        type=str,
        default='normal',
        help='evaluation type')
    # Required: main() unconditionally derives the log file name (and possibly
    # the work dir) from this value, so a missing path must fail here with a
    # clear usage error rather than later with an AttributeError on None.
    parser.add_argument(
        '--ckp-path',
        type=str,
        required=True,
        help='ckp_path')
    # NOTE(review): this flag is parsed but not read by main() in this file.
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--save',
        action='store_true',
        default=False,
        help='save colored prediction & depth predictions')
    parser.add_argument(
        '--cai-mode',
        type=str,
        default='m1',
        help='m1, m2, or rx')
    parser.add_argument(
        '--process-num',
        type=int, default=4,
        help='batchsize number for inference')
    parser.add_argument(
        '--tag',
        type=str, default='',
        help='infer_infos')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, `torch.distributed.launch` passes
    # `--local-rank` to the launched script instead of `--local_rank`, so
    # both spellings are accepted.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()

    # Keep a LOCAL_RANK that the launcher already exported; otherwise fall
    # back to the command-line value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
def _resolve_work_dir(cfg, args):
    """Pick the output dir: CLI ``--work-dir`` > ``cfg.work_dir`` > derived from the checkpoint path."""
    if args.work_dir is not None:
        return args.work_dir

    cfg_work_dir = cfg.get('work_dir', None)
    if cfg_work_dir is not None:
        return cfg_work_dir

    if '.pth' in args.ckp_path:
        # Local checkpoint: write next to the checkpoint file.
        return osp.dirname(args.ckp_path)
    # Otherwise the path is treated as an "<owner>/<model>" repo name; use
    # the last component so a name without a slash does not raise IndexError.
    return osp.join('work_dir', args.ckp_path.split('/')[-1])


def _build_eval_dataset(cfg, test_type):
    """Build the dataset of the dataloader config that matches ``test_type``."""
    loader_keys = {
        'consistency': 'val_consistency_dataloader',
        'normal': 'val_dataloader',
        'test_in': 'test_in_dataloader',
        'test_out': 'test_out_dataloader',
        'general': 'general_dataloader',
    }
    # Unknown test types fall back to the regular validation dataloader.
    loader_key = loader_keys.get(test_type, 'val_dataloader')
    return build_dataset(getattr(cfg, loader_key).dataset)


def _load_model(cfg):
    """Create the model and load weights from a local ``.pth`` file or a huggingface repo."""
    if '.pth' in cfg.ckp_path:
        model = build_model(cfg.model)
        print_log('Checkpoint Path: {}. Loading from a local file'.format(cfg.ckp_path), logger='current')
        # Load onto the CPU: without map_location, tensors are restored to
        # the device they were saved from, regardless of the current rank.
        state_dict = torch.load(cfg.ckp_path, map_location='cpu')['model_state_dict']
        if hasattr(model, 'load_dict'):
            load_result = model.load_dict(state_dict)
        else:
            load_result = model.load_state_dict(state_dict, strict=True)
        print_log(load_result, logger='current')
        return model

    print_log('Checkpoint Path: {}. Loading from the huggingface repo'.format(cfg.ckp_path), logger='current')
    supported_repos = (
        'Zhyever/patchfusion_depth_anything_vits14',
        'Zhyever/patchfusion_depth_anything_vitb14',
        'Zhyever/patchfusion_depth_anything_vitl14',
        'Zhyever/patchfusion_zoedepth',
    )
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if cfg.ckp_path not in supported_repos:
        raise ValueError('Invalid model name: {!r}'.format(cfg.ckp_path))
    return PatchFusion.from_pretrained(cfg.ckp_path)


def main():
    """Evaluate a checkpoint on the dataset selected by ``--test-type``.

    Steps: load and override the config, resolve the work dir, fix the random
    seed, set up the (optional) distributed environment, build the dataset,
    logger, model and dataloader, then run the ``Tester``.
    """
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # The resolved directory is stored on cfg and used everywhere below;
    # args.work_dir is None whenever the directory comes from the config.
    cfg.work_dir = _resolve_work_dir(cfg, args)
    mkdir_or_exist(cfg.work_dir)
    cfg.ckp_path = args.ckp_path

    # fix seed
    seed = cfg.get('seed', 5621)
    fix_random_seed(seed)

    # set up the (optional) distributed environment
    if cfg.launcher == 'none':
        distributed = False
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        rank = 0
        world_size = 1
        env_cfg = cfg.get('env_cfg')
    else:
        distributed = True
        env_cfg = cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl')))
        rank, world_size, timestamp = setup_env(env_cfg, distributed, cfg.launcher)

    # build dataset
    dataset = _build_eval_dataset(cfg, args.test_type)

    # compose the log file name from config name, tag, checkpoint and dataset
    exp_cfg_filename = args.config.split('/')[-1].split('.')[0]
    ckp_name = args.ckp_path.replace('/', '_').replace('.pth', '')
    dataset_name = dataset.dataset_name
    log_filename = 'eval_{}_{}_{}_{}_{}.log'.format(
        exp_cfg_filename, args.tag, ckp_name, dataset_name, timestamp)

    # prepare basic text logger
    # `torch.compile` in PyTorch 2.0 could close all user defined handlers
    # unexpectedly. Using file mode 'a' can help prevent abnormal
    # termination of the FileHandler and ensure that the log file could
    # be continuously updated during the lifespan of the runner.
    logger = MMLogger.get_instance(
        name=timestamp,
        logger_name='patchstitcher',
        log_level='INFO',
        log_file=osp.join(cfg.work_dir, log_filename),
        file_mode='a')

    # save some information useful during the run
    runner_info = RunnerInfo()
    runner_info.config = cfg  # ideally, cfg should not be changed during process. information should be temp saved in runner_info
    runner_info.logger = logger  # easier way: use print_log("infos", logger='current')
    runner_info.rank = rank
    runner_info.distributed = distributed
    runner_info.launcher = cfg.launcher
    runner_info.seed = seed
    runner_info.world_size = world_size
    # cfg.work_dir already exists (created above), also when --save is set.
    runner_info.work_dir = cfg.work_dir
    runner_info.timestamp = timestamp
    runner_info.save = args.save
    runner_info.log_filename = log_filename

    # build model
    model = _load_model(cfg)
    model.eval()

    if runner_info.distributed:
        # NOTE(review): the global rank is used as the CUDA device index;
        # confirm this is intended for multi-node runs.
        torch.cuda.set_device(runner_info.rank)
        model.cuda(runner_info.rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[runner_info.rank],
            output_device=runner_info.rank,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
        logger.info(model)
        val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        model.cuda()
        val_sampler = None

    # build dataloader
    num_workers = cfg.val_dataloader.num_workers
    val_dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
        sampler=val_sampler)

    # build tester
    tester = Tester(
        config=cfg,
        runner_info=runner_info,
        dataloader=val_dataloader,
        model=model)

    if args.test_type == 'consistency':
        tester.run_consistency()
    else:
        tester.run(args.cai_mode, process_num=args.process_num)
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()