monai
medical
File size: 2,043 Bytes
618f7d3
 
 
 
 
 
 
 
cf880b5
618f7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49f392e
618f7d3
49f392e
618f7d3
 
 
49f392e
618f7d3
49f392e
 
618f7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
---
# Multi-GPU (DistributedDataParallel) overrides.
# NOTE(review): network_def, optimizer, lr_scheduler, train#*, validate#* and
# output_dir are referenced here but not defined in this section — presumably
# they come from the base training config this file is layered over.

# One CUDA device per process. LOCAL_RANK (exported by torchrun) is the device
# index on the current node; the global rank is only a valid device index on a
# single node, so it is kept as the fallback when LOCAL_RANK is not set.
device: "$torch.device('cuda', int(os.environ.get('LOCAL_RANK', dist.get_rank())))"

# Wrap the base network in DDP, placed on this process's device.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: "$@network_def.to(@device)"
  find_unused_parameters: true
  device_ids:
    - "@device"

# Learning rate, scheduler step size, validation interval and max_epochs are
# all multiplied by the number of processes (world size).
# NOTE(review): presumably intended to keep the schedule comparable to the
# single-GPU run while each rank trains on a data partition — confirm.
optimizer#lr: "$0.025*dist.get_world_size()"
lr_scheduler#step_size: "$80*dist.get_world_size()"
train#handlers:
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: "$10*dist.get_world_size()"
  # The two logging handlers below must stay LAST in this list: non-zero ranks
  # slice them off (see train#trainer#train_handlers).
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
train#trainer#max_epochs: "$400*dist.get_world_size()"
# Rank 0 gets every train handler; other ranks drop the last two so that only
# rank 0 writes stats / TensorBoard logs.
train#trainer#train_handlers: "$@train#handlers[: -2 if dist.get_rank() > 0 else None]"
# Only rank 0 runs the validation handlers; other ranks get none.
validate#evaluator#val_handlers: "$None if dist.get_rank() > 0 else @validate#handlers"

initialize:
  - "$import torch.distributed as dist"
  # Needed by `device` to read LOCAL_RANK.
  - "$import os"
  # Reuse an existing process group if the launcher already created one.
  - "$dist.is_initialized() or dist.init_process_group(backend='nccl')"
  - "$torch.cuda.set_device(@device)"
  - "$monai.utils.set_determinism(seed=123)"
  # NOTE(review): set_determinism(seed=...) switches cudnn.benchmark off; this
  # turns it back on, trading run-to-run reproducibility for speed — confirm
  # that is intended.
  - "$setattr(torch.backends.cudnn, 'benchmark', True)"
run:
  - "$@train#trainer.run()"
finalize:
  # Tear down the process group only if one is active.
  - "$dist.is_initialized() and dist.destroy_process_group()"
# Per-rank data: each datalist is split into world_size partitions and this
# process keeps the one at its own rank index.
# Training: shuffled split, even_divisible=True requests equal-sized partitions.
train_data_partition: >-
  $monai.data.partition_dataset(data=@train_datalist, num_partitions=dist.get_world_size(),
  shuffle=True, even_divisible=True,)[dist.get_rank()]
# Cache this rank's whole training partition (cache_rate 1) with 4 workers.
train#dataset:
  _target_: CacheDataset
  data: "@train_data_partition"
  transform: "@train#preprocessing"
  cache_rate: 1
  num_workers: 4
# Validation: unshuffled split, partitions are not forced to equal size.
val_data_partition: >-
  $monai.data.partition_dataset(data=@val_datalist, num_partitions=dist.get_world_size(),
  shuffle=False, even_divisible=False,)[dist.get_rank()]
# Cache this rank's whole validation partition (cache_rate 1) with 4 workers.
validate#dataset:
  _target_: CacheDataset
  data: "@val_data_partition"
  transform: "@validate#preprocessing"
  cache_rate: 1
  num_workers: 4