from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing
import torch
import torch.distributed as dist
import os

# Make runs reproducible when requested in the config.
if cfg.fix_random:
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train(cfg, network):
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Resume from the latest checkpoint (if any) and restore the scheduler state.
    begin_epoch = load_model(network, optimizer, scheduler, recorder,
                             cfg.trained_model_dir, resume=cfg.resume)
    set_lr_scheduler(cfg, scheduler)

    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    is_distributed=cfg.distributed,
                                    max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        if cfg.distributed:
            # Reshuffle the distributed sampler so each epoch sees a new ordering.
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        # Only the main process writes checkpoints.
        if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch)

        if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch, last=True)

        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network


def test(cfg, network):
    trainer = make_trainer(cfg, network)
    val_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    # Load the weights of the requested test epoch, then run validation once.
    epoch = load_network(network, cfg.trained_model_dir,
                         resume=cfg.resume, epoch=cfg.test.epoch)
    trainer.val(epoch, val_loader, evaluator)


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def main():
    if cfg.distributed:
        # Derive the local device index from the global rank set by the launcher.
        cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(cfg.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    network = make_network(cfg)
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)


if __name__ == "__main__":
    main()