monai
medical
katielink committed on
Commit 618f7d3
1 Parent(s): 2784532

complete the model package

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/model.ts filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ tags:
+ - monai
+ - medical
+ library_name: monai
+ license: unknown
+ ---
+ # Description
+ A neural architecture search algorithm for volumetric (3D) segmentation of the pancreas and pancreatic tumor from CT images.
+
+ # Model Overview
+ This model is trained with the state-of-the-art algorithm [1] on the Task07_Pancreas dataset of the "Medical Segmentation Decathlon Challenge 2018", using 196 training images, 56 validation images, and 28 testing images.
+
+ ## Data
+ The training dataset is Task07_Pancreas.tar from http://medicaldecathlon.com/. The data list/split can be created with the script `scripts/prepare_datalist.py`.
+
+ ## Training configuration
+ Training was performed on GPUs with at least 16 GB of memory.
+
+ Actual model input: 96 x 96 x 96
+
+ ## Input and output formats
+ Input: 1-channel CT image
+
+ Output: 3 channels: Label 2: pancreatic tumor; Label 1: pancreas; Label 0: everything else
+
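+ As a quick illustration of this output convention, a minimal sketch (the tensor below is a random stand-in, not a real prediction) of collapsing the 3-channel prediction back into a single-channel label map:
+
+ ```
+ import torch
+
+ # stand-in for a softmax output of shape (batch, channel, H, W, D)
+ pred = torch.rand(1, 3, 96, 96, 96)
+ # the channel index with the highest score becomes the voxel label:
+ # 0 = everything else, 1 = pancreas, 2 = pancreatic tumor
+ label_map = torch.argmax(pred, dim=1, keepdim=True)
+ ```
+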
+ ## Scores
+ This model achieves the following Dice score on the validation data (our own split from the training dataset):
+
+ Mean Dice = 0.72
+
+ ## Example commands
+ Create the data split (.json file):
+
+ ```
+ python scripts/prepare_datalist.py --path /path-to-Task07_Pancreas/ --output configs/dataset_0.json
+ ```
+
+ Execute model searching:
+
+ ```
+ python -m scripts.search run --config_file configs/search.yaml
+ ```
+
+ Execute multi-GPU model searching (recommended):
+
+ ```
+ torchrun --nnodes=1 --nproc_per_node=8 -m scripts.search run --config_file configs/search.yaml
+ ```
+
+ Execute training:
+
+ ```
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.yaml --logging_file configs/logging.conf
+ ```
+
+ Override the `train` config to execute multi-GPU training:
+
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']" --logging_file configs/logging.conf
+ ```
+
+ Override the `train` config to execute evaluation with the trained model:
+
+ ```
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/evaluate.yaml']" --logging_file configs/logging.conf
+ ```
+
+ Execute inference:
+
+ ```
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.yaml --logging_file configs/logging.conf
+ ```
+
+ # Disclaimer
+ This is an example, not to be used for diagnostic purposes.
+
+ # References
+ [1] He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. DiNTS: Differentiable neural network topology search for 3D medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850).
configs/evaluate.yaml ADDED
@@ -0,0 +1,64 @@
+ ---
+ validate#postprocessing:
+   _target_: Compose
+   transforms:
+   - _target_: Activationsd
+     keys: pred
+     softmax: true
+   - _target_: Invertd
+     keys:
+     - pred
+     - label
+     transform: "@validate#preprocessing"
+     orig_keys: image
+     meta_key_postfix: meta_dict
+     nearest_interp:
+     - false
+     - true
+     to_tensor: true
+   - _target_: AsDiscreted
+     keys:
+     - pred
+     - label
+     argmax:
+     - true
+     - false
+     to_onehot: 3
+   - _target_: CopyItemsd
+     keys: "pred"
+     times: 1
+     names: "pred_save"
+   - _target_: AsDiscreted
+     keys:
+     - pred_save
+     argmax:
+     - true
+   - _target_: SaveImaged
+     keys: pred_save
+     meta_keys: pred_meta_dict
+     output_dir: "@output_dir"
+     resample: false
+     squeeze_end_dims: true
+ validate#dataset:
+   _target_: Dataset
+   data: "@val_datalist"
+   transform: "@validate#preprocessing"
+ validate#handlers:
+ - _target_: CheckpointLoader
+   load_path: "$@ckpt_dir + '/model.pt'"
+   load_dict:
+     model: "@network"
+ - _target_: StatsHandler
+   iteration_log: false
+ - _target_: MetricsSaver
+   save_dir: "@output_dir"
+   metrics:
+   - val_mean_dice
+   - val_acc
+   metric_details:
+   - val_mean_dice
+   batch_transform: "$monai.handlers.from_engine(['image_meta_dict'])"
+   summary_ops: "*"
+ evaluating:
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ - "$@validate#evaluator.run()"
configs/inference.yaml ADDED
@@ -0,0 +1,116 @@
+ ---
+ imports:
+ - "$import glob"
+ - "$import os"
+ input_channels: 1
+ output_classes: 3
+ arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
+ arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"
+ bundle_root: "/workspace/MONAI/model-zoo/models/pancreas_ct_dints_segmentation"
+ output_dir: "$@bundle_root + '/eval'"
+ dataset_dir: "/workspace/data/msd/Task07_Pancreas"
+ data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"
+ datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='testing', base_dir=@dataset_dir)"
+ device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
+ dints_space:
+   _target_: monai.networks.nets.TopologyInstance
+   channel_mul: 1
+   num_blocks: 12
+   num_depths: 4
+   use_downsample: true
+   arch_code:
+   - "$@arch_ckpt['arch_code_a']"
+   - "$@arch_ckpt['arch_code_c']"
+   device: "$torch.device('cuda')"
+ network_def:
+   _target_: monai.networks.nets.DiNTS
+   dints_space: "@dints_space"
+   in_channels: "@input_channels"
+   num_classes: "@output_classes"
+   use_downsample: true
+   node_a: "$torch.from_numpy(@arch_ckpt['node_a'])"
+ network: "$@network_def.to(@device)"
+ preprocessing:
+   _target_: Compose
+   transforms:
+   - _target_: LoadImaged
+     keys: image
+   - _target_: EnsureChannelFirstd
+     keys: image
+   - _target_: Orientationd
+     keys: image
+     axcodes: RAS
+   - _target_: Spacingd
+     keys: image
+     pixdim:
+     - 1
+     - 1
+     - 1
+     mode: bilinear
+   - _target_: ScaleIntensityRanged
+     keys: image
+     a_min: -87
+     a_max: 199
+     b_min: 0
+     b_max: 1
+     clip: true
+   - _target_: EnsureTyped
+     keys: image
+ dataset:
+   _target_: Dataset
+   data: "@datalist"
+   transform: "@preprocessing"
+ dataloader:
+   _target_: DataLoader
+   dataset: "@dataset"
+   batch_size: 1
+   shuffle: false
+   num_workers: 4
+ inferer:
+   _target_: SlidingWindowInferer
+   roi_size:
+   - 96
+   - 96
+   - 96
+   sw_batch_size: 4
+   overlap: 0.625
+ postprocessing:
+   _target_: Compose
+   transforms:
+   - _target_: Activationsd
+     keys: pred
+     softmax: true
+   - _target_: Invertd
+     keys: pred
+     transform: "@preprocessing"
+     orig_keys: image
+     meta_key_postfix: meta_dict
+     nearest_interp: false
+     to_tensor: true
+   - _target_: AsDiscreted
+     keys: pred
+     argmax: true
+   - _target_: SaveImaged
+     keys: pred
+     meta_keys: pred_meta_dict
+     output_dir: "@output_dir"
+ handlers:
+ - _target_: CheckpointLoader
+   load_path: "$@bundle_root + '/models/model.pt'"
+   load_dict:
+     model: "@network"
+ - _target_: StatsHandler
+   iteration_log: false
+ evaluator:
+   _target_: SupervisedEvaluator
+   device: "@device"
+   val_data_loader: "@dataloader"
+   network: "@network"
+   inferer: "@inferer"
+   postprocessing: "@postprocessing"
+   val_handlers: "@handlers"
+   amp: true
+ evaluating:
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ - "$@evaluator.run()"
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=consoleHandler
+
+ [formatters]
+ keys=fullFormatter
+
+ [logger_root]
+ level=INFO
+ handlers=consoleHandler
+
+ [handler_consoleHandler]
+ class=StreamHandler
+ level=INFO
+ formatter=fullFormatter
+ args=(sys.stdout,)
+
+ [formatter_fullFormatter]
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
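Outside of `monai.bundle run` (which consumes this file via `--logging_file`), the same configuration loads with the standard library:

```
import logging
import logging.config

logging.config.fileConfig("configs/logging.conf", disable_existing_loggers=False)
logging.getLogger().info("logging configured from configs/logging.conf")
```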
configs/metadata.json ADDED
@@ -0,0 +1,80 @@
+ {
+     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
+     "version": "0.1.0",
+     "changelog": {
+         "0.1.0": "complete the model package",
+         "0.0.1": "initialize the model package structure"
+     },
+     "monai_version": "0.9.0",
+     "pytorch_version": "1.12.0",
+     "numpy_version": "1.21.2",
+     "optional_packages_version": {
+         "fire": "0.4.0",
+         "nibabel": "3.2.1",
+         "pytorch-ignite": "0.4.9"
+     },
+     "task": "Neural architecture search on pancreas CT segmentation",
+     "description": "Searched architectures for volumetric (3D) segmentation of the pancreas from CT image",
+     "authors": "MONAI team",
+     "copyright": "Copyright (c) MONAI Consortium",
+     "data_source": "Task07_Pancreas.tar from http://medicaldecathlon.com/",
+     "data_type": "nibabel",
+     "image_classes": "single channel data, intensity scaled to [0, 1]",
+     "label_classes": "single channel data, 1 is pancreas, 2 is pancreatic tumor, 0 is everything else",
+     "pred_classes": "3 channels OneHot data, channel 1 is pancreas, channel 2 is pancreatic tumor, channel 0 is background",
+     "eval_metrics": {
+         "mean_dice": 0.72
+     },
+     "intended_use": "This is an example, not to be used for diagnostic purposes",
+     "references": [
+         "He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. DiNTS: Differentiable neural network topology search for 3D medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850)."
+     ],
+     "network_data_format": {
+         "inputs": {
+             "image": {
+                 "type": "image",
+                 "format": "hounsfield",
+                 "modality": "CT",
+                 "num_channels": 1,
+                 "spatial_shape": [
+                     96,
+                     96,
+                     96
+                 ],
+                 "dtype": "float32",
+                 "value_range": [
+                     0,
+                     1
+                 ],
+                 "is_patch_data": true,
+                 "channel_def": {
+                     "0": "image"
+                 }
+             }
+         },
+         "outputs": {
+             "pred": {
+                 "type": "image",
+                 "format": "segmentation",
+                 "num_channels": 3,
+                 "spatial_shape": [
+                     96,
+                     96,
+                     96
+                 ],
+                 "dtype": "float32",
+                 "value_range": [
+                     0,
+                     1,
+                     2
+                 ],
+                 "is_patch_data": true,
+                 "channel_def": {
+                     "0": "background",
+                     "1": "pancreas",
+                     "2": "pancreatic tumor"
+                 }
+             }
+         }
+     }
+ }
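A small sanity-check sketch (plain stdlib, not part of the bundle) confirming that the declared channel definitions match the number of output channels:

```
import json

with open("configs/metadata.json") as f:
    meta = json.load(f)

pred = meta["network_data_format"]["outputs"]["pred"]
assert pred["num_channels"] == len(pred["channel_def"])  # 3: background, pancreas, tumor
print(meta["task"], "-", meta["eval_metrics"])
```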
configs/multi_gpu_train.yaml ADDED
@@ -0,0 +1,52 @@
+ ---
+ device: "$torch.device(f'cuda:{dist.get_rank()}')"
+ network:
+   _target_: torch.nn.parallel.DistributedDataParallel
+   module: "$@network_def.to(@device)"
+   find_unused_parameters: true
+   device_ids:
+   - "@device"
+ optimizer#lr: "$0.0125*dist.get_world_size()"
+ lr_scheduler#step_size: "$80*dist.get_world_size()"
+ train#handlers:
+ - _target_: LrScheduleHandler
+   lr_scheduler: "@lr_scheduler"
+   print_lr: true
+ - _target_: ValidationHandler
+   validator: "@validate#evaluator"
+   epoch_level: true
+   interval: "$10*dist.get_world_size()"
+ - _target_: StatsHandler
+   tag_name: train_loss
+   output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
+ - _target_: TensorBoardStatsHandler
+   log_dir: "@output_dir"
+   tag_name: train_loss
+   output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
+ train#trainer#max_epochs: "$400*dist.get_world_size()"
+ train#trainer#train_handlers: "$@train#handlers[: -2 if dist.get_rank() > 0 else None]"
+ validate#evaluator#val_handlers: "$None if dist.get_rank() > 0 else @validate#handlers"
+ training:
+ - "$import torch.distributed as dist"
+ - "$dist.init_process_group(backend='nccl')"
+ - "$torch.cuda.set_device(@device)"
+ - "$monai.utils.set_determinism(seed=123)"
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ - "$@train#trainer.run()"
+ - "$dist.destroy_process_group()"
+ train_data_partition: "$monai.data.partition_dataset(data=@train_datalist, num_partitions=dist.get_world_size(), shuffle=True, even_divisible=True,)[dist.get_rank()]"
+ train#dataset:
+   _target_: CacheDataset
+   data: "@train_data_partition"
+   transform: "@train#preprocessing"
+   cache_rate: 1
+   num_workers: 4
+ val_data_partition: "$monai.data.partition_dataset(data=@val_datalist, num_partitions=dist.get_world_size(), shuffle=False, even_divisible=False,)[dist.get_rank()]"
+ validate#dataset:
+   _target_: CacheDataset
+   data: "@val_data_partition"
+   transform: "@validate#preprocessing"
+   cache_rate: 1
+   num_workers: 4
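These overrides scale several hyperparameters linearly with the number of workers, a common heuristic for synchronized SGD. For example, under `torchrun --nproc_per_node=2` they resolve to:

```
world_size = 2                 # dist.get_world_size() with --nproc_per_node=2
lr = 0.0125 * world_size       # -> 0.025, matching the single-GPU base rate
step_size = 80 * world_size    # -> 160 epochs between LR halvings
interval = 10 * world_size     # -> validate every 20 epochs
max_epochs = 400 * world_size  # -> 800 epochs in total
print(lr, step_size, interval, max_epochs)
```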
configs/search.yaml ADDED
@@ -0,0 +1,275 @@
+ ---
+ imports:
+ - "$from scipy import ndimage"
+ arch_ckpt_path: models
+ amp: true
+ data_file_base_dir: /workspace/data/msd/Task07_Pancreas
+ data_list_file_path: configs/dataset_0.json
+ determ: true
+ input_channels: 1
+ learning_rate: 0.025
+ learning_rate_arch: 0.001
+ learning_rate_milestones:
+ - 0.4
+ - 0.8
+ num_images_per_batch: 1
+ num_epochs: 1430
+ num_epochs_per_validation: 100
+ num_epochs_warmup: 715
+ num_patches_per_image: 1
+ num_sw_batch_size: 6
+ output_classes: 3
+ overlap_ratio: 0.625
+ patch_size:
+ - 96
+ - 96
+ - 96
+ patch_size_valid:
+ - 96
+ - 96
+ - 96
+ ram_cost_factor: 0.8
+ image_key: image
+ label_key: label
+ transform_train:
+   _target_: Compose
+   transforms:
+   - _target_: LoadImaged
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: EnsureChannelFirstd
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: Orientationd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     axcodes: RAS
+   - _target_: Spacingd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     pixdim:
+     - 1
+     - 1
+     - 1
+     mode:
+     - bilinear
+     - nearest
+     align_corners:
+     - true
+     - true
+   - _target_: CastToTyped
+     keys: "@image_key"
+     dtype: "$torch.float32"
+   - _target_: ScaleIntensityRanged
+     keys: "@image_key"
+     a_min: -87
+     a_max: 199
+     b_min: 0
+     b_max: 1
+     clip: true
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$np.float16"
+     - "$np.uint8"
+   - _target_: CopyItemsd
+     keys: "@label_key"
+     times: 1
+     names:
+     - label4crop
+   - _target_: Lambdad
+     keys: label4crop
+     func: "$lambda x, s=@output_classes: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(s)]), axis=0)"
+     overwrite: true
+   - _target_: EnsureTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: CastToTyped
+     keys: "@image_key"
+     dtype: "$torch.float32"
+   - _target_: SpatialPadd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     - label4crop
+     spatial_size: "@patch_size"
+     mode:
+     - reflect
+     - constant
+     - constant
+   - _target_: RandCropByLabelClassesd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     label_key: label4crop
+     num_classes: "@output_classes"
+     ratios: "$[1,] * @output_classes"
+     spatial_size: "@patch_size"
+     num_samples: "@num_patches_per_image"
+   - _target_: Lambdad
+     keys: label4crop
+     func: "$lambda x: 0"
+   - _target_: RandRotated
+     keys:
+     - "@image_key"
+     - "@label_key"
+     range_x: 0.3
+     range_y: 0.3
+     range_z: 0.3
+     mode:
+     - bilinear
+     - nearest
+     prob: 0.2
+   - _target_: RandZoomd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     min_zoom: 0.8
+     max_zoom: 1.2
+     mode:
+     - trilinear
+     - nearest
+     align_corners:
+     - null
+     - null
+     prob: 0.16
+   - _target_: RandGaussianSmoothd
+     keys: "@image_key"
+     sigma_x:
+     - 0.5
+     - 1.15
+     sigma_y:
+     - 0.5
+     - 1.15
+     sigma_z:
+     - 0.5
+     - 1.15
+     prob: 0.15
+   - _target_: RandScaleIntensityd
+     keys: "@image_key"
+     factors: 0.3
+     prob: 0.5
+   - _target_: RandShiftIntensityd
+     keys: "@image_key"
+     offsets: 0.1
+     prob: 0.5
+   - _target_: RandGaussianNoised
+     keys: "@image_key"
+     std: 0.01
+     prob: 0.15
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 0
+     prob: 0.5
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 1
+     prob: 0.5
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 2
+     prob: 0.5
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$torch.float32"
+     - "$torch.uint8"
+   - _target_: ToTensord
+     keys:
+     - "@image_key"
+     - "@label_key"
+ transform_validation:
+   _target_: Compose
+   transforms:
+   - _target_: LoadImaged
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: EnsureChannelFirstd
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: Orientationd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     axcodes: RAS
+   - _target_: Spacingd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     pixdim:
+     - 1
+     - 1
+     - 1
+     mode:
+     - bilinear
+     - nearest
+     align_corners:
+     - true
+     - true
+   - _target_: CastToTyped
+     keys: "@image_key"
+     dtype: "$torch.float32"
+   - _target_: ScaleIntensityRanged
+     keys: "@image_key"
+     a_min: -87
+     a_max: 199
+     b_min: 0
+     b_max: 1
+     clip: true
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$np.float16"
+     - "$np.uint8"
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$torch.float32"
+     - "$torch.uint8"
+   - _target_: ToTensord
+     keys:
+     - "@image_key"
+     - "@label_key"
+ loss:
+   _target_: DiceCELoss
+   include_background: false
+   to_onehot_y: true
+   softmax: true
+   squared_pred: true
+   batch: true
+   smooth_nr: 0.00001
+   smooth_dr: 0.00001
+ dints_space:
+   _target_: monai.networks.nets.TopologySearch
+   channel_mul: 0.5
+   num_blocks: 12
+   num_depths: 4
+   use_downsample: true
+   device: "$torch.device('cuda')"
+ network:
+   _target_: monai.networks.nets.DiNTS
+   dints_space: "@dints_space"
+   in_channels: "@input_channels"
+   num_classes: "@output_classes"
+   use_downsample: true
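The `label4crop` Lambdad above turns the label volume into one dilated binary mask per class, so `RandCropByLabelClassesd` can draw patch centers near each structure rather than only inside it. A minimal single-voxel sketch of what the lambda computes (toy array; `iterations` reduced from 48 for readability):

```
import numpy as np
from scipy import ndimage

label = np.zeros((1, 8, 8, 8), dtype=np.uint8)
label[0, 4, 4, 4] = 1  # a single "pancreas" voxel
num_classes = 3
mask = np.concatenate(
    [ndimage.binary_dilation(label[0] == k, iterations=2).astype(np.uint8)[None] for k in range(num_classes)],
    axis=0,
)  # shape (3, 8, 8, 8): one dilated sampling mask per class
print(mask.sum(axis=(1, 2, 3)))
```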
configs/train.yaml ADDED
@@ -0,0 +1,353 @@
+ ---
+ imports:
+ - "$import glob"
+ - "$import json"
+ - "$import os"
+ - "$import ignite"
+ - "$from scipy import ndimage"
+ input_channels: 1
+ output_classes: 3
+ arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
+ arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"
+ bundle_root: "/workspace/MONAI/model-zoo/models/pancreas_ct_dints_segmentation"
+ ckpt_dir: "$@bundle_root + '/models'"
+ output_dir: "$@bundle_root + '/eval'"
+ dataset_dir: "/workspace/data/msd/Task07_Pancreas"
+ data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"
+ train_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)"
+ val_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)"
+ device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
+ dints_space:
+   _target_: monai.networks.nets.TopologyInstance
+   channel_mul: 1
+   num_blocks: 12
+   num_depths: 4
+   use_downsample: true
+   arch_code:
+   - "$@arch_ckpt['arch_code_a']"
+   - "$@arch_ckpt['arch_code_c']"
+   device: "$torch.device('cuda')"
+ network_def:
+   _target_: monai.networks.nets.DiNTS
+   dints_space: "@dints_space"
+   in_channels: "@input_channels"
+   num_classes: "@output_classes"
+   use_downsample: true
+   node_a: "$@arch_ckpt['node_a']"
+ network: "$@network_def.to(@device)"
+ loss:
+   _target_: DiceCELoss
+   include_background: false
+   to_onehot_y: true
+   softmax: true
+   squared_pred: true
+   batch: true
+   smooth_nr: 1.0e-05
+   smooth_dr: 1.0e-05
+ optimizer:
+   _target_: torch.optim.SGD
+   params: "$@network.parameters()"
+   momentum: 0.9
+   weight_decay: 4.0e-05
+   lr: 0.025
+ lr_scheduler:
+   _target_: torch.optim.lr_scheduler.StepLR
+   optimizer: "@optimizer"
+   step_size: 80
+   gamma: 0.5
+ image_key: image
+ label_key: label
+ train:
+   deterministic_transforms:
+   - _target_: LoadImaged
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: EnsureChannelFirstd
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: Orientationd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     axcodes: RAS
+   - _target_: Spacingd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     pixdim:
+     - 1
+     - 1
+     - 1
+     mode:
+     - bilinear
+     - nearest
+     align_corners:
+     - true
+     - true
+   - _target_: CastToTyped
+     keys: "@image_key"
+     dtype: "$torch.float32"
+   - _target_: ScaleIntensityRanged
+     keys: "@image_key"
+     a_min: -87
+     a_max: 199
+     b_min: 0
+     b_max: 1
+     clip: true
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$np.float16"
+     - "$np.uint8"
+   - _target_: CopyItemsd
+     keys: "@label_key"
+     times: 1
+     names:
+     - label4crop
+   - _target_: Lambdad
+     keys: label4crop
+     func: "$lambda x, s=@output_classes: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(s)]), axis=0)"
+     overwrite: true
+   - _target_: EnsureTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+   - _target_: CastToTyped
+     keys: "@image_key"
+     dtype: "$torch.float32"
+   - _target_: SpatialPadd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     - label4crop
+     spatial_size:
+     - 96
+     - 96
+     - 96
+     mode:
+     - reflect
+     - constant
+     - constant
+   random_transforms:
+   - _target_: RandCropByLabelClassesd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     label_key: label4crop
+     num_classes: "@output_classes"
+     ratios: "$[1,] * @output_classes"
+     spatial_size:
+     - 96
+     - 96
+     - 96
+     num_samples: 1
+   - _target_: Lambdad
+     keys: label4crop
+     func: "$lambda x: 0"
+   - _target_: RandRotated
+     keys:
+     - "@image_key"
+     - "@label_key"
+     range_x: 0.3
+     range_y: 0.3
+     range_z: 0.3
+     mode:
+     - bilinear
+     - nearest
+     prob: 0.2
+   - _target_: RandZoomd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     min_zoom: 0.8
+     max_zoom: 1.2
+     mode:
+     - trilinear
+     - nearest
+     align_corners:
+     - true
+     - null
+     prob: 0.16
+   - _target_: RandGaussianSmoothd
+     keys: "@image_key"
+     sigma_x:
+     - 0.5
+     - 1.15
+     sigma_y:
+     - 0.5
+     - 1.15
+     sigma_z:
+     - 0.5
+     - 1.15
+     prob: 0.15
+   - _target_: RandScaleIntensityd
+     keys: "@image_key"
+     factors: 0.3
+     prob: 0.5
+   - _target_: RandShiftIntensityd
+     keys: "@image_key"
+     offsets: 0.1
+     prob: 0.5
+   - _target_: RandGaussianNoised
+     keys: "@image_key"
+     std: 0.01
+     prob: 0.15
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 0
+     prob: 0.5
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 1
+     prob: 0.5
+   - _target_: RandFlipd
+     keys:
+     - "@image_key"
+     - "@label_key"
+     spatial_axis: 2
+     prob: 0.5
+   - _target_: CastToTyped
+     keys:
+     - "@image_key"
+     - "@label_key"
+     dtype:
+     - "$torch.float32"
+     - "$torch.uint8"
+   - _target_: ToTensord
+     keys:
+     - "@image_key"
+     - "@label_key"
+   preprocessing:
+     _target_: Compose
+     transforms: "$@train#deterministic_transforms + @train#random_transforms"
+   dataset:
+     _target_: CacheDataset
+     data: "@train_datalist"
+     transform: "@train#preprocessing"
+     cache_rate: 0.125
+     num_workers: 4
+   dataloader:
+     _target_: DataLoader
+     dataset: "@train#dataset"
+     batch_size: 2
+     shuffle: true
+     num_workers: 4
+   inferer:
+     _target_: SimpleInferer
+   postprocessing:
+     _target_: Compose
+     transforms:
+     - _target_: Activationsd
+       keys: pred
+       softmax: true
+     - _target_: AsDiscreted
+       keys:
+       - pred
+       - label
+       argmax:
+       - true
+       - false
+       to_onehot: "@output_classes"
+   handlers:
+   - _target_: LrScheduleHandler
+     lr_scheduler: "@lr_scheduler"
+     print_lr: true
+   - _target_: ValidationHandler
+     validator: "@validate#evaluator"
+     epoch_level: true
+     interval: 10
+   - _target_: StatsHandler
+     tag_name: train_loss
+     output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
+   - _target_: TensorBoardStatsHandler
+     log_dir: "@output_dir"
+     tag_name: train_loss
+     output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
+   key_metric:
+     train_accuracy:
+       _target_: ignite.metrics.Accuracy
+       output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
+   trainer:
+     _target_: SupervisedTrainer
+     max_epochs: 400
+     device: "@device"
+     train_data_loader: "@train#dataloader"
+     network: "@network"
+     loss_function: "@loss"
+     optimizer: "@optimizer"
+     inferer: "@train#inferer"
+     postprocessing: "@train#postprocessing"
+     key_train_metric: "@train#key_metric"
+     train_handlers: "@train#handlers"
+     amp: true
+ validate:
+   preprocessing:
+     _target_: Compose
+     transforms: "%train#deterministic_transforms"
+   dataset:
+     _target_: CacheDataset
+     data: "@val_datalist"
+     transform: "@validate#preprocessing"
+     cache_rate: 0.125
+   dataloader:
+     _target_: DataLoader
+     dataset: "@validate#dataset"
+     batch_size: 1
+     shuffle: false
+     num_workers: 4
+   inferer:
+     _target_: SlidingWindowInferer
+     roi_size:
+     - 96
+     - 96
+     - 96
+     sw_batch_size: 6
+     overlap: 0.625
+   postprocessing: "%train#postprocessing"
+   handlers:
+   - _target_: StatsHandler
+     iteration_log: false
+   - _target_: TensorBoardStatsHandler
+     log_dir: "@output_dir"
+     iteration_log: false
+   - _target_: CheckpointSaver
+     save_dir: "@ckpt_dir"
+     save_dict:
+       model: "@network"
+     save_key_metric: true
+     key_metric_filename: model.pt
+   key_metric:
+     val_mean_dice:
+       _target_: MeanDice
+       include_background: false
+       output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
+   additional_metrics:
+     val_accuracy:
+       _target_: ignite.metrics.Accuracy
+       output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
+   evaluator:
+     _target_: SupervisedEvaluator
+     device: "@device"
+     val_data_loader: "@validate#dataloader"
+     network: "@network"
+     inferer: "@validate#inferer"
+     postprocessing: "@validate#postprocessing"
+     key_val_metric: "@validate#key_metric"
+     additional_metrics: "@validate#additional_metrics"
+     val_handlers: "@validate#handlers"
+     amp: true
+ training:
+ - "$monai.utils.set_determinism(seed=123)"
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ - "$@train#trainer.run()"
docs/README.md ADDED
@@ -0,0 +1,72 @@
+ # Description
+ A neural architecture search algorithm for volumetric (3D) segmentation of the pancreas and pancreatic tumor from CT images.
+
+ # Model Overview
+ This model is trained with the state-of-the-art algorithm [1] on the Task07_Pancreas dataset of the "Medical Segmentation Decathlon Challenge 2018", using 196 training images, 56 validation images, and 28 testing images.
+
+ ## Data
+ The training dataset is Task07_Pancreas.tar from http://medicaldecathlon.com/. The data list/split can be created with the script `scripts/prepare_datalist.py`.
+
+ ## Training configuration
+ Training was performed on GPUs with at least 16 GB of memory.
+
+ Actual model input: 96 x 96 x 96
+
+ ## Input and output formats
+ Input: 1-channel CT image
+
+ Output: 3 channels: Label 2: pancreatic tumor; Label 1: pancreas; Label 0: everything else
+
+ ## Scores
+ This model achieves the following Dice score on the validation data (our own split from the training dataset):
+
+ Mean Dice = 0.72
+
+ ## Example commands
+ Create the data split (.json file):
+
+ ```
+ python scripts/prepare_datalist.py --path /path-to-Task07_Pancreas/ --output configs/dataset_0.json
+ ```
+
+ Execute model searching:
+
+ ```
+ python -m scripts.search run --config_file configs/search.yaml
+ ```
+
+ Execute multi-GPU model searching (recommended):
+
+ ```
+ torchrun --nnodes=1 --nproc_per_node=8 -m scripts.search run --config_file configs/search.yaml
+ ```
+
+ Execute training:
+
+ ```
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.yaml --logging_file configs/logging.conf
+ ```
+
+ Override the `train` config to execute multi-GPU training:
+
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']" --logging_file configs/logging.conf
+ ```
+
+ Override the `train` config to execute evaluation with the trained model:
+
+ ```
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/evaluate.yaml']" --logging_file configs/logging.conf
+ ```
+
+ Execute inference:
+
+ ```
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.yaml --logging_file configs/logging.conf
+ ```
+
+ # Disclaimer
+ This is an example, not to be used for diagnostic purposes.
+
+ # References
+ [1] He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. DiNTS: Differentiable neural network topology search for 3D medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850).
docs/license.txt ADDED
@@ -0,0 +1,6 @@
+ Third Party Licenses
+ -----------------------------------------------------------------------
+
+ /*********************************************************************/
+ i. Medical Segmentation Decathlon
+      http://medicaldecathlon.com/
models/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:975201057eb16225a1abfd047cb9b293f6d481dc604468d512710f3543f29066
+ size 616210421
models/model.ts ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:526d2bdb4d88f6f55f2d88eb0c79deeeb90b3ced182269e33cdc5da6e46ea5fb
+ size 616338455
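`models/model.ts` is the TorchScript export of the trained network; a minimal loading sketch, independent of the MONAI bundle machinery (the input tensor is a stand-in patch):

```
import torch

model = torch.jit.load("models/model.ts", map_location="cpu")
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 1, 96, 96, 96))  # one 96^3 patch, 1 channel in -> 3 class channels out
```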
models/search_code_18590.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01e361e9843e2f4e5ff1599da0abac77013ea38cab8fdd6c9286bb6572c9a32d
+ size 4335
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
scripts/prepare_datalist.py ADDED
@@ -0,0 +1,59 @@
+ import argparse
+ import glob
+ import json
+ import os
+
+ import monai
+ from sklearn.model_selection import train_test_split
+
+
+ def produce_sample_dict(line: str):
+     return {"label": line, "image": line.replace("labelsTr", "imagesTr")}
+
+
+ def produce_datalist(dataset_dir: str):
+     """
+     This function is used to split the dataset.
+     It will produce 196 samples for training, and the remaining samples are divided
+     into validation and test sets (roughly 2:1).
+     """
+
+     samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
+     samples = [_item.replace(os.path.join(dataset_dir, "labelsTr"), "labelsTr") for _item in samples]
+     datalist = []
+     for line in samples:
+         datalist.append(produce_sample_dict(line))
+     train_list, other_list = train_test_split(datalist, train_size=196)
+     val_list, test_list = train_test_split(other_list, train_size=0.66)
+
+     return {"training": train_list, "validation": val_list, "testing": test_list}
+
+
+ def main(args):
+     """
+     Split the dataset and output the data list into a json file.
+     """
+     data_file_base_dir = args.path
+     output_json = args.output
+     # produce deterministic data splits
+     monai.utils.set_determinism(seed=123)
+     datalist = produce_datalist(dataset_dir=data_file_base_dir)
+     with open(output_json, "w") as f:
+         json.dump(datalist, f, ensure_ascii=True, indent=4)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="")
+     parser.add_argument(
+         "--path",
+         type=str,
+         default="/workspace/data/msd/Task07_Pancreas",
+         help="root path of MSD Task07_Pancreas dataset.",
+     )
+     parser.add_argument(
+         "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
+     )
+     args = parser.parse_args()
+
+     main(args)
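The produced `dataset_0.json` follows the decathlon datalist layout consumed by `monai.data.load_decathlon_datalist`; a short sketch of inspecting it (the example path in the comment is illustrative):

```
import json

with open("configs/dataset_0.json") as f:
    datalist = json.load(f)

# each entry pairs a label with its image, relative to the dataset root, e.g.
# {"label": "labelsTr/pancreas_004.nii.gz", "image": "imagesTr/pancreas_004.nii.gz"}
print(len(datalist["training"]), len(datalist["validation"]), len(datalist["testing"]))
```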
scripts/search.py ADDED
@@ -0,0 +1,520 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import os
+ import random
+ import sys
+ import time
+ from datetime import datetime
+ from typing import Sequence, Union
+
+ import monai
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ import yaml
+ from monai import transforms
+ from monai.bundle import ConfigParser
+ from monai.data import ThreadDataLoader, partition_dataset
+ from monai.inferers import sliding_window_inference
+ from monai.metrics import compute_meandice
+ from monai.utils import set_determinism
+ from torch.nn.parallel import DistributedDataParallel
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ def run(config_file: Union[str, Sequence[str]]):
+     logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+     parser = ConfigParser()
+     parser.read_config(config_file)
+
+     arch_ckpt_path = parser["arch_ckpt_path"]
+     amp = parser["amp"]
+     data_file_base_dir = parser["data_file_base_dir"]
+     data_list_file_path = parser["data_list_file_path"]
+     determ = parser["determ"]
+     learning_rate = parser["learning_rate"]
+     learning_rate_arch = parser["learning_rate_arch"]
+     learning_rate_milestones = np.array(parser["learning_rate_milestones"])
+     num_images_per_batch = parser["num_images_per_batch"]
+     num_epochs = parser["num_epochs"]  # around 20k iterations
+     num_epochs_per_validation = parser["num_epochs_per_validation"]
+     num_epochs_warmup = parser["num_epochs_warmup"]
+     num_sw_batch_size = parser["num_sw_batch_size"]
+     output_classes = parser["output_classes"]
+     overlap_ratio = parser["overlap_ratio"]
+     patch_size_valid = parser["patch_size_valid"]
+     ram_cost_factor = parser["ram_cost_factor"]
+     print("[info] GPU RAM cost factor:", ram_cost_factor)
+
+     train_transforms = parser.get_parsed_content("transform_train")
+     val_transforms = parser.get_parsed_content("transform_validation")
+
+     # deterministic training
+     if determ:
+         set_determinism(seed=0)
+
+     print("[info] number of GPUs:", torch.cuda.device_count())
+     if torch.cuda.device_count() > 1:
+         # initialize the distributed training process, every GPU runs in a process
+         dist.init_process_group(backend="nccl", init_method="env://")
+         world_size = dist.get_world_size()
+     else:
+         world_size = 1
+     print("[info] world_size:", world_size)
+
+     with open(data_list_file_path, "r") as f:
+         json_data = json.load(f)
+
+     list_train = json_data["training"]
+     list_valid = json_data["validation"]
+
+     # training data
+     files = []
+     for _i in range(len(list_train)):
+         str_img = os.path.join(data_file_base_dir, list_train[_i]["image"])
+         str_seg = os.path.join(data_file_base_dir, list_train[_i]["label"])
+
+         if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
+             continue
+
+         files.append({"image": str_img, "label": str_seg})
+     train_files = files
+
+     random.shuffle(train_files)
+
+     train_files_w = train_files[: len(train_files) // 2]
+     if torch.cuda.device_count() > 1:
+         train_files_w = partition_dataset(
+             data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
+         )[dist.get_rank()]
+     print("train_files_w:", len(train_files_w))
+
+     train_files_a = train_files[len(train_files) // 2 :]
+     if torch.cuda.device_count() > 1:
+         train_files_a = partition_dataset(
+             data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
+         )[dist.get_rank()]
+     print("train_files_a:", len(train_files_a))
+
+     # validation data
+     files = []
+     for _i in range(len(list_valid)):
+         str_img = os.path.join(data_file_base_dir, list_valid[_i]["image"])
+         str_seg = os.path.join(data_file_base_dir, list_valid[_i]["label"])
+
+         if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
+             continue
+
+         files.append({"image": str_img, "label": str_seg})
+     val_files = files
+
+     if torch.cuda.device_count() > 1:
+         val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
+             dist.get_rank()
+         ]
+     print("val_files:", len(val_files))
+
+     # network architecture
+     if torch.cuda.device_count() > 1:
+         device = torch.device(f"cuda:{dist.get_rank()}")
+     else:
+         device = torch.device("cuda:0")
+     torch.cuda.set_device(device)
+
+     if torch.cuda.device_count() > 1:
+         train_ds_a = monai.data.CacheDataset(
+             data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8
+         )
+         train_ds_w = monai.data.CacheDataset(
+             data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8
+         )
+         val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
+     else:
+         train_ds_a = monai.data.CacheDataset(
+             data=train_files_a, transform=train_transforms, cache_rate=0.125, num_workers=8
+         )
+         train_ds_w = monai.data.CacheDataset(
+             data=train_files_w, transform=train_transforms, cache_rate=0.125, num_workers=8
+         )
+         val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.125, num_workers=2)
+
+     train_loader_a = ThreadDataLoader(train_ds_a, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
+     train_loader_w = ThreadDataLoader(train_ds_w, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
+     val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
+
+     model = parser.get_parsed_content("network")
+     dints_space = parser.get_parsed_content("dints_space")
+
+     model = model.to(device)
+     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+     post_pred = transforms.Compose(
+         [transforms.EnsureType(), transforms.AsDiscrete(argmax=True, to_onehot=output_classes)]
+     )
+     post_label = transforms.Compose([transforms.EnsureType(), transforms.AsDiscrete(to_onehot=output_classes)])
+
+     # loss function
+     loss_func = parser.get_parsed_content("loss")
+
+     # optimizer
+     optimizer = torch.optim.SGD(
+         model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
+     )
+     arch_optimizer_a = torch.optim.Adam(
+         [dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
+     )
+     arch_optimizer_c = torch.optim.Adam(
+         [dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
+     )
+
+     if torch.cuda.device_count() > 1:
+         model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)
+
+     # amp
+     if amp:
+         from torch.cuda.amp import GradScaler, autocast
+
+         scaler = GradScaler()
+         if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+             print("[info] amp enabled")
+
+     # start a typical PyTorch training
+     val_interval = num_epochs_per_validation
+     best_metric = -1
+     best_metric_epoch = -1
+     idx_iter = 0
+
+     if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+         writer = SummaryWriter(log_dir=os.path.join(arch_ckpt_path, "Events"))
+
+         with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
+             f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")
+
+     dataloader_a_iterator = iter(train_loader_a)
+
+     start_time = time.time()
+     for epoch in range(num_epochs):
+         decay = 0.5 ** np.sum(
+             [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
+         )
+         lr = learning_rate * decay * world_size
+         for param_group in optimizer.param_groups:
+             param_group["lr"] = lr
+
+         if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+             print("-" * 10)
+             print(f"epoch {epoch + 1}/{num_epochs}")
+             print("learning rate is set to {}".format(lr))
+
+         model.train()
+         epoch_loss = 0
+         loss_torch = torch.zeros(2, dtype=torch.float, device=device)
+         epoch_loss_arch = 0
+         loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
+         step = 0
+
+         for batch_data in train_loader_w:
+             step += 1
+             inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
+             if world_size == 1:
+                 for _ in model.weight_parameters():
+                     _.requires_grad = True
+             else:
+                 for _ in model.module.weight_parameters():
+                     _.requires_grad = True
+             dints_space.log_alpha_a.requires_grad = False
+             dints_space.log_alpha_c.requires_grad = False
+
+             optimizer.zero_grad()
+
+             if amp:
+                 with autocast():
+                     outputs = model(inputs)
+                     if output_classes == 2:
+                         loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
+                     else:
+                         loss = loss_func(outputs, labels)
+
+                 scaler.scale(loss).backward()
+                 scaler.step(optimizer)
+                 scaler.update()
+             else:
+                 outputs = model(inputs)
+                 if output_classes == 2:
+                     loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
+                 else:
+                     loss = loss_func(outputs, labels)
+                 loss.backward()
+                 optimizer.step()
+
+             epoch_loss += loss.item()
+             loss_torch[0] += loss.item()
+             loss_torch[1] += 1.0
+             epoch_len = len(train_loader_w)
+             idx_iter += 1
+
+             if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+                 print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
+                 writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
+
+             if epoch < num_epochs_warmup:
+                 continue
+
+             try:
+                 sample_a = next(dataloader_a_iterator)
+             except StopIteration:
+                 dataloader_a_iterator = iter(train_loader_a)
+                 sample_a = next(dataloader_a_iterator)
+             inputs_search, labels_search = (sample_a["image"].to(device), sample_a["label"].to(device))
+             if world_size == 1:
+                 for _ in model.weight_parameters():
+                     _.requires_grad = False
+             else:
+                 for _ in model.module.weight_parameters():
+                     _.requires_grad = False
+             dints_space.log_alpha_a.requires_grad = True
+             dints_space.log_alpha_c.requires_grad = True
+
+             # linear increase topology and RAM loss
+             entropy_alpha_c = torch.tensor(0.0).to(device)
+             entropy_alpha_a = torch.tensor(0.0).to(device)
+             ram_cost_full = torch.tensor(0.0).to(device)
+             ram_cost_usage = torch.tensor(0.0).to(device)
+             ram_cost_loss = torch.tensor(0.0).to(device)
+             topology_loss = torch.tensor(0.0).to(device)
+
+             probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
+             entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
+             entropy_alpha_c = -(
+                 F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
+             ).mean()
+             topology_loss = dints_space.get_topology_entropy(probs_a)
+
+             ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
+             ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
+             ram_cost_loss = torch.abs(ram_cost_factor - ram_cost_usage / ram_cost_full)
+
+             arch_optimizer_a.zero_grad()
+             arch_optimizer_c.zero_grad()
+
+             combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)
+
+             if amp:
+                 with autocast():
+                     outputs_search = model(inputs_search)
+                     if output_classes == 2:
+                         loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
+                     else:
+                         loss = loss_func(outputs_search, labels_search)
+
+                     loss += combination_weights * (
+                         (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
+                     )
+
+                 scaler.scale(loss).backward()
+                 scaler.step(arch_optimizer_a)
+                 scaler.step(arch_optimizer_c)
+                 scaler.update()
+             else:
+                 outputs_search = model(inputs_search)
+                 if output_classes == 2:
+                     loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
+                 else:
+                     loss = loss_func(outputs_search, labels_search)
+
+                 loss += 1.0 * (
+                     combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
+                 )
+
+                 loss.backward()
+                 arch_optimizer_a.step()
+                 arch_optimizer_c.step()
+
+             epoch_loss_arch += loss.item()
+             loss_torch_arch[0] += loss.item()
+             loss_torch_arch[1] += 1.0
+
+             if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+                 print(
+                     "[{0}] ".format(str(datetime.now())[:19])
+                     + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
+                 )
+                 writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)
+
+         # synchronizes all processes and reduce results
+         if torch.cuda.device_count() > 1:
+             dist.barrier()
+             dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
+
+         loss_torch = loss_torch.tolist()
+         loss_torch_arch = loss_torch_arch.tolist()
+         if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+             loss_torch_epoch = loss_torch[0] / loss_torch[1]
+             print(
+                 f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
+                 f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
+             )
+
+             if epoch >= num_epochs_warmup:
+                 loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
+                 print(
+                     f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
+                     f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
+                 )
+
+         if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
+             torch.cuda.empty_cache()
+             model.eval()
+             with torch.no_grad():
+                 metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
+                 metric_sum = 0.0
+                 metric_count = 0
+                 metric_mat = []
+                 val_images = None
+                 val_labels = None
+                 val_outputs = None
+
+                 _index = 0
+                 for val_data in val_loader:
+                     val_images = val_data["image"].to(device)
+                     val_labels = val_data["label"].to(device)
+
+                     roi_size = patch_size_valid
+                     sw_batch_size = num_sw_batch_size
+
+                     if amp:
+                         with torch.cuda.amp.autocast():
+                             pred = sliding_window_inference(
+                                 val_images,
+                                 roi_size,
+                                 sw_batch_size,
+                                 lambda x: model(x),
+                                 mode="gaussian",
+                                 overlap=overlap_ratio,
+                             )
+                     else:
+                         pred = sliding_window_inference(
+                             val_images,
+                             roi_size,
+                             sw_batch_size,
+                             lambda x: model(x),
+                             mode="gaussian",
+                             overlap=overlap_ratio,
+                         )
+                     val_outputs = pred
+
+                     val_outputs = post_pred(val_outputs[0, ...])
+                     val_outputs = val_outputs[None, ...]
+                     val_labels = post_label(val_labels[0, ...])
+                     val_labels = val_labels[None, ...]
+
+                     value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
+
+                     print(_index + 1, "/", len(val_loader), value)
+
+                     metric_count += len(value)
+                     metric_sum += value.sum().item()
+                     metric_vals = value.cpu().numpy()
+                     if len(metric_mat) == 0:
+                         metric_mat = metric_vals
+                     else:
+                         metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)
+
+                     for _c in range(output_classes - 1):
+                         val0 = torch.nan_to_num(value[0, _c], nan=0.0)
+                         val1 = 1.0 - torch.isnan(value[0, _c]).float()
+                         metric[2 * _c] += val0 * val1
+                         metric[2 * _c + 1] += val1
+
+                     _index += 1
+
+                 # synchronizes all processes and reduce results
+                 if torch.cuda.device_count() > 1:
+                     dist.barrier()
+                     dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
+
+                 metric = metric.tolist()
+                 if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+                     for _c in range(output_classes - 1):
+                         print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
+                     avg_metric = 0
+                     for _c in range(output_classes - 1):
+                         avg_metric += metric[2 * _c] / metric[2 * _c + 1]
+                     avg_metric = avg_metric / float(output_classes - 1)
+                     print("avg_metric", avg_metric)
+
+                     if avg_metric > best_metric:
+                         best_metric = avg_metric
+                         best_metric_epoch = epoch + 1
+                         best_metric_iterations = idx_iter
+
+                         (node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d) = dints_space.decode()
+                         torch.save(
+                             {
+                                 "node_a": node_a_d,
+                                 "arch_code_a": arch_code_a_d,
+                                 "arch_code_a_max": arch_code_a_max_d,
+                                 "arch_code_c": arch_code_c_d,
+                                 "iter_num": idx_iter,
+                                 "epochs": epoch + 1,
+                                 "best_dsc": best_metric,
+                                 "best_path": best_metric_iterations,
+                             },
+                             os.path.join(arch_ckpt_path, "search_code_" + str(idx_iter) + ".pt"),
+                         )
+                         print("saved new best metric model")
+
+                     dict_file = {}
+                     dict_file["best_avg_dice_score"] = float(best_metric)
+                     dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
+                     dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
+                     with open(os.path.join(arch_ckpt_path, "progress.yaml"), "w") as out_file:
+                         _ = yaml.dump(dict_file, stream=out_file)
+
+                     print(
+                         "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
+                             epoch + 1, avg_metric, best_metric, best_metric_epoch
+                         )
+                     )
+
+                     current_time = time.time()
+                     elapsed_time = (current_time - start_time) / 60.0
+                     with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
+                         f.write(
+                             "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
+                                 epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
+                             )
+                         )
+
+                 if torch.cuda.device_count() > 1:
+                     dist.barrier()
+
+             torch.cuda.empty_cache()
+
+     print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
+
+     if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+         writer.close()
+
+     if torch.cuda.device_count() > 1:
+         dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     from monai.utils import optional_import
+
+     fire, _ = optional_import("fire")
+     fire.Fire()
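The search saves its best architecture as `search_code_<iter>.pt`; a short sketch of inspecting the decoded codes that `configs/train.yaml` and `configs/inference.yaml` later consume (keys taken from the `torch.save` call above):

```
import torch

ckpt = torch.load("models/search_code_18590.pt", map_location="cpu")
for key in ("node_a", "arch_code_a", "arch_code_c", "arch_code_a_max", "epochs", "best_dsc"):
    print(key, getattr(ckpt[key], "shape", ckpt[key]))
```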