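"""Training entry script.

Parses command-line arguments, loads an mmengine config, sets up (optionally
distributed) training, logging and wandb, builds the model and dataloaders,
and hands everything to the Trainer.
"""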
import os
import os.path as osp
import argparse

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader

from mmengine.utils import mkdir_or_exist
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger

from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed
from estimator.models.builder import build_model
from estimator.datasets.builder import build_dataset
from estimator.trainer import Trainer


def parse_args():
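    """Parse command-line arguments for training."""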
    parser = argparse.ArgumentParser(description='Train an estimator')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help='resume from the latest checkpoint in the work_dir automatically')
    parser.add_argument(
        '--debug',
        action='store_true',
        default=False,
        help='debug mode')
    parser.add_argument(
        '--log-name',
        type=str, default='',
        help='log_name for wandb')
    parser.add_argument(
        '--tags',
        type=str, default='',
        help='tags for wandb')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic mixed-precision training')
    parser.add_argument(
        '--seed',
        type=int, default=621,
        help='random seed')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, `torch.distributed.launch`
    # passes the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
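    """Set up the environment, build the model and dataloaders, and start training."""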
    args = parse_args()

    # if args.debug:
    #     torch.autograd.set_detect_anomaly(True)  # for debug

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir comes from the CLI; the wandb log name is appended as a subdirectory
    cfg.work_dir = args.work_dir
    cfg.work_dir = osp.join(cfg.work_dir, args.log_name)
    mkdir_or_exist(cfg.work_dir)
    cfg.debug = args.debug
    cfg.log_name = args.log_name
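    # wandb expects a list of tags; split a comma-separated string into one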
    tags = args.tags
    if ',' in tags:
        tag_list = tags.split(',')
    else:
        tag_list = [tags]
    cfg.tags = tag_list

    # fix seed
    seed = args.seed
    fix_random_seed(seed)

    # start dist training
    if cfg.launcher == 'none':
        distributed = False
    else:
        distributed = True
    env_cfg = cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl')))
    rank, world_size, timestamp = setup_env(env_cfg, distributed, cfg.launcher)

    # prepare basic text logger
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    log_cfg = dict(log_level='INFO', log_file=log_file)
    log_cfg.setdefault('name', timestamp)
    log_cfg.setdefault('logger_name', 'patchstitcher')
    # `torch.compile` in PyTorch 2.0 could close all user defined handlers
    # unexpectedly. Using file mode 'a' can help prevent abnormal
    # termination of the FileHandler and ensure that the log file could
    # be continuously updated during the lifespan of the runner.
    log_cfg.setdefault('file_mode', 'a')
    logger = MMLogger.get_instance(**log_cfg)

    # save some information useful during training
    runner_info = RunnerInfo()
    runner_info.config = cfg  # ideally, cfg should not change during the run; transient information belongs in runner_info
    runner_info.logger = logger  # easier way: use print_log("infos", logger='current')
    runner_info.rank = rank
    runner_info.distributed = distributed
    runner_info.launcher = cfg.launcher
    runner_info.seed = seed
    runner_info.world_size = world_size
    runner_info.work_dir = cfg.work_dir
    runner_info.timestamp = timestamp
    # start wandb
    if runner_info.rank == 0 and not cfg.debug:
        wandb.init(
            project=cfg.project,
            name=cfg.log_name + "_" + runner_info.timestamp,
            tags=cfg.tags,
            dir=runner_info.work_dir,
            config=cfg,
            settings=wandb.Settings(start_method="fork"))
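        # use custom step metrics so Train/* and Val/* logs advance on independent counters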
        wandb.define_metric("Val/step")
        wandb.define_metric("Val/*", step_metric="Val/step")
        wandb.define_metric("Train/step")
        wandb.define_metric("Train/*", step_metric="Train/step")

    log_env(cfg, env_cfg, runner_info, logger)

    # resume training (future)
    cfg.resume = args.resume
    # build model
    model = build_model(cfg.model)
    if runner_info.distributed:
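        # pin this process to its GPU, optionally convert BatchNorm to SyncBatchNorm, then wrap in DDP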
        torch.cuda.set_device(runner_info.rank)
        if cfg.get('convert_syncbn', False):
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(runner_info.rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[runner_info.rank],
            output_device=runner_info.rank,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
        logger.info(model)
    else:
        model = model.cuda(runner_info.rank)
        logger.info(model)
    # build dataloader
    dataset = build_dataset(cfg.train_dataloader.dataset)
    if runner_info.distributed:
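        # DistributedSampler gives each rank a distinct shard of the training set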
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None
    train_dataloader = DataLoader(
        dataset,
        batch_size=cfg.train_dataloader.batch_size,
        shuffle=(train_sampler is None),
        num_workers=cfg.train_dataloader.num_workers,
        pin_memory=True,
        persistent_workers=True,
        sampler=train_sampler)
    dataset = build_dataset(cfg.val_dataloader.dataset)
    if runner_info.distributed:
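        # keep validation order fixed; note that DistributedSampler may pad samples so every rank gets an equal share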
        val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        val_sampler = None
    val_dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=cfg.val_dataloader.num_workers,
        pin_memory=True,
        persistent_workers=True,
        sampler=val_sampler)
    # everything is ready, start training. But before that, save your config!
    cfg.dump(osp.join(cfg.work_dir, 'config.py'))

    # build trainer
    trainer = Trainer(
        config=cfg,
        runner_info=runner_info,
        train_sampler=train_sampler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        model=model)
    trainer.run()

    wandb.finish()


if __name__ == '__main__':
    main()