from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing
import torch
import torch.distributed as dist
import os
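
# Fix the seed and disable cuDNN autotuning so repeated runs are reproducible.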
if cfg.fix_random:
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
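
# Full training loop: build the trainer, optimizer, scheduler, recorder and
# evaluator, resume from a checkpoint if one exists, then alternate training
# epochs with periodic checkpointing and validation.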
def train(cfg, network):
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Restore network/optimizer/scheduler/recorder state from disk; returns
    # the epoch to resume from (0 on a fresh run).
    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.trained_model_dir,
                             resume=cfg.resume)
    set_lr_scheduler(cfg, scheduler)

    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    is_distributed=cfg.distributed,
                                    max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        if cfg.distributed:
            # Reseed the distributed sampler so each epoch sees a different
            # shard ordering across processes.
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        # Only rank 0 writes checkpoints, to avoid concurrent writes in
        # distributed mode.
        if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch)

        if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
            save_model(network,
                       optimizer,
                       scheduler,
                       recorder,
                       cfg.trained_model_dir,
                       epoch,
                       last=True)

        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
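
# Evaluation only: load a trained checkpoint and run a single validation pass.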
def test(cfg, network):
    trainer = make_trainer(cfg, network)
    val_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    epoch = load_network(network,
                         cfg.trained_model_dir,
                         resume=cfg.resume,
                         epoch=cfg.test.epoch)
    trainer.val(epoch, val_loader, evaluator)

def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
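
# Entry point. In distributed mode each launcher-spawned process derives its
# local GPU index from the global RANK, binds to that device, and joins the
# NCCL process group, reading the rendezvous address from the environment
# (init_method="env://").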
def main():
    if cfg.distributed:
        cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(cfg.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    network = make_network(cfg)
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)


if __name__ == "__main__":
    main()
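
# Usage sketch. The file name (train_net.py), the --cfg_file / --test flags,
# and the config path are assumptions inferred from args.test and common
# layouts of this kind of repo; substitute the actual entries from lib.config
# and the configs/ directory.
#
#   # single-GPU training
#   python train_net.py --cfg_file configs/example.yaml
#
#   # multi-GPU training: torchrun exports RANK / WORLD_SIZE / MASTER_ADDR,
#   # which init_process_group(init_method="env://") reads at startup
#   torchrun --nproc_per_node=4 train_net.py --cfg_file configs/example.yaml distributed True
#
#   # evaluation of a trained checkpoint
#   python train_net.py --test --cfg_file configs/example.yaml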