File size: 1,574 Bytes
650e244 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
{
  "device": "$torch.device(f'cuda:{dist.get_rank()}')",
  "network": {
    "_target_": "torch.nn.parallel.DistributedDataParallel",
    "module": "$@network_def.to(@device)",
    "device_ids": ["@device"]
  },
  "train#sampler": {
    "_target_": "DistributedSampler",
    "dataset": "@train#dataset",
    "even_divisible": true,
    "shuffle": true
  },
  "train#dataloader#sampler": "@train#sampler",
  "train#dataloader#shuffle": false,
  "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
  "validate#sampler": {
    "_target_": "DistributedSampler",
    "dataset": "@validate#dataset",
    "even_divisible": false,
    "shuffle": false
  },
  "validate#dataloader#sampler": "@validate#sampler",
  "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
  "training": [
    "$import sys",
    "$sys.path.append(@bundle_root)",
    "$import torch.distributed as dist",
    "$dist.init_process_group(backend='nccl')",
    "$torch.cuda.set_device(@device)",
    "$monai.utils.set_determinism(seed=123)",
    "$setattr(torch.backends.cudnn, 'benchmark', True)",
    "$import logging",
    "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
    "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
    "$@train#trainer.run()",
    "$dist.destroy_process_group()"
  ]
}
|