# Web-page residue captured with this file (not part of the config):
#   monai / medical / katielink's picture
#   re-train model with updated dints implementation (49f392e)
#   raw / history blame / 2.04 kB
---
# Per-process CUDA device, indexed by this process's rank in the process group.
# NOTE(review): dist.get_rank() is the rank within the whole group, so this
# assumes rank == local GPU index (single node) — confirm before multi-node use.
device: "$torch.device(f'cuda:{dist.get_rank()}')"

# Wrap the network definition in DistributedDataParallel on that device.
# Every key below `_target_` is an argument of that class, so all of them must
# stay nested under `network`.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  # Move the referenced network definition to this process's device first.
  module: "$@network_def.to(@device)"
  find_unused_parameters: true
  device_ids: ["@device"]
# Learning-rate overrides: the base rate (0.025) and the scheduler step size
# (80) are each multiplied by the number of processes in the group.
optimizer#lr: '$0.025*dist.get_world_size()'
lr_scheduler#step_size: '$80*dist.get_world_size()'
# Training handlers, in order. The last two entries are the logging handlers;
# `train#trainer#train_handlers` slices them off with `[: -2]` on non-zero
# ranks, so they must remain the final two items of this list.
train#handlers:
  # Learning-rate scheduler hook; `print_lr` asks it to report the rate.
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true

  # Triggers the validation evaluator; the interval is 10 scaled by the
  # number of processes and is counted at epoch level.
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: "$10*dist.get_world_size()"

  # Logs the first 'loss' entry of the engine output under `train_loss`.
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"

  # Same loss value, written to TensorBoard under the bundle's output dir.
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
# Epoch budget: 400 multiplied by the number of processes.
# NOTE(review): presumably compensates for each rank iterating only its own
# data partition per epoch — confirm against the single-GPU config.
train#trainer#max_epochs: '$400*dist.get_world_size()'

# Rank 0 keeps every entry of `train#handlers`; all other ranks drop the last
# two entries (the StatsHandler and TensorBoardStatsHandler loggers).
train#trainer#train_handlers: '$@train#handlers[: -2 if dist.get_rank() > 0 else None]'

# Only rank 0 receives `validate#handlers`; other ranks get None.
validate#evaluator#val_handlers: '$None if dist.get_rank() > 0 else @validate#handlers'
# Start-up expressions, listed in the order they are meant to run.
initialize:
  # Make `dist` available to every other expression in this file.
  - "$import torch.distributed as dist"
  # Create the NCCL process group unless one already exists.
  - "$dist.is_initialized() or dist.init_process_group(backend='nccl')"
  # Bind this process to its own CUDA device.
  - "$torch.cuda.set_device(@device)"
  # Fixed seed for reproducibility.
  - "$monai.utils.set_determinism(seed=123)"
  # NOTE(review): listed after set_determinism, so this benchmark setting is
  # the one left in effect — confirm that trade-off is intended.
  - "$setattr(torch.backends.cudnn, 'benchmark', True)"
# Main entry point: launch the trainer.
run:
  - "$@train#trainer.run()"

# Tear-down: destroy the process group, but only if one was initialised.
finalize:
  - "$dist.is_initialized() and dist.destroy_process_group()"
# This rank's share of the training datalist: split into one partition per
# process (shuffled, even_divisible=True) and indexed by the rank.
# Folded scalar: the line breaks below become single spaces.
train_data_partition: >-
  $monai.data.partition_dataset(data=@train_datalist,
  num_partitions=dist.get_world_size(), shuffle=True,
  even_divisible=True,)[dist.get_rank()]

# Training dataset built from that partition only. The keys below are
# arguments of CacheDataset and must stay nested under `train#dataset`.
train#dataset:
  _target_: CacheDataset
  data: "@train_data_partition"
  transform: "@train#preprocessing"
  cache_rate: 1
  num_workers: 4
# This rank's share of the validation datalist: split into one partition per
# process without shuffling. even_divisible=False, so ranks may end up with
# different item counts.
# Folded scalar: the line breaks below become single spaces.
val_data_partition: >-
  $monai.data.partition_dataset(data=@val_datalist,
  num_partitions=dist.get_world_size(), shuffle=False,
  even_divisible=False,)[dist.get_rank()]

# Validation dataset built from that partition only. The keys below are
# arguments of CacheDataset and must stay nested under `validate#dataset`.
validate#dataset:
  _target_: CacheDataset
  data: "@val_data_partition"
  transform: "@validate#preprocessing"
  cache_rate: 1
  num_workers: 4