complete the model package
- .gitattributes +1 -0
- README.md +79 -0
- configs/evaluate.yaml +64 -0
- configs/inference.yaml +116 -0
- configs/logging.conf +21 -0
- configs/metadata.json +80 -0
- configs/multi_gpu_train.yaml +52 -0
- configs/search.yaml +275 -0
- configs/train.yaml +353 -0
- docs/README.md +72 -0
- docs/license.txt +6 -0
- models/model.pt +3 -0
- models/model.ts +3 -0
- models/search_code_18590.pt +3 -0
- scripts/__init__.py +10 -0
- scripts/prepare_datalist.py +59 -0
- scripts/search.py +520 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/model.ts filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,79 @@
---
tags:
- monai
- medical
library_name: monai
license: unknown
---
# Description
A neural architecture search algorithm for volumetric (3D) segmentation of the pancreas and pancreatic tumor from CT images.

# Model Overview
This model is trained with the state-of-the-art algorithm [1] from the "Medical Segmentation Decathlon Challenge 2018", using 196 training images, 56 validation images, and 28 testing images.

## Data
The training dataset is Task07_Pancreas.tar from http://medicaldecathlon.com/. The data list/split can be created with the script `scripts/prepare_datalist.py`.

## Training configuration
The training was performed with GPUs that have at least 16 GB of memory.

Actual model input: 96 x 96 x 96

## Input and output formats
Input: 1-channel CT image

Output: 3 channels: Label 2: pancreatic tumor; Label 1: pancreas; Label 0: everything else

## Scores
This model achieves the following Dice score on the validation data (our own split of the training dataset):

Mean Dice = 0.72

## Commands example
Create the data split (.json file):

```
python scripts/prepare_datalist.py --path /path-to-Task07_Pancreas/ --output configs/dataset_0.json
```

Execute model searching:

```
python -m scripts.search run --config_file configs/search.yaml
```

Execute multi-GPU model searching (recommended):

```
torchrun --nnodes=1 --nproc_per_node=8 -m scripts.search run --config_file configs/search.yaml
```

Execute training:

```
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.yaml --logging_file configs/logging.conf
```

Override the `train` config to execute multi-GPU training:

```
torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']" --logging_file configs/logging.conf
```

Override the `train` config to execute evaluation with the trained model:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/evaluate.yaml']" --logging_file configs/logging.conf
```

Execute inference:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.yaml --logging_file configs/logging.conf
```

# Disclaimer
This is an example, not to be used for diagnostic purposes.

# References
[1] He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. DiNTS: Differentiable neural network topology search for 3D medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850).
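As context for the commands above, a minimal sketch of how the bundle configs later consume the generated split file, using `monai.data.load_decathlon_datalist` exactly as `train.yaml` and `inference.yaml` do. The two paths are illustrative defaults from this repo, not fixed requirements.

```python
# Sketch: read back the split produced by scripts/prepare_datalist.py.
import monai

datalist = monai.data.load_decathlon_datalist(
    "configs/dataset_0.json",  # split file created by prepare_datalist.py
    data_list_key="training",  # "training", "validation", or "testing"
    base_dir="/workspace/data/msd/Task07_Pancreas",  # prepended to relative paths
)
print(len(datalist), datalist[0])  # list of {"image": ..., "label": ...} dicts
```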
configs/evaluate.yaml
ADDED
@@ -0,0 +1,64 @@
---
validate#postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: pred
    softmax: true
  - _target_: Invertd
    keys:
    - pred
    - label
    transform: "@validate#preprocessing"
    orig_keys: image
    meta_key_postfix: meta_dict
    nearest_interp:
    - false
    - true
    to_tensor: true
  - _target_: AsDiscreted
    keys:
    - pred
    - label
    argmax:
    - true
    - false
    to_onehot: 3
  - _target_: CopyItemsd
    keys: "pred"
    times: 1
    names: "pred_save"
  - _target_: AsDiscreted
    keys:
    - pred_save
    argmax:
    - true
  - _target_: SaveImaged
    keys: pred_save
    meta_keys: pred_meta_dict
    output_dir: "@output_dir"
    resample: false
    squeeze_end_dims: true
validate#dataset:
  _target_: Dataset
  data: "@val_datalist"
  transform: "@validate#preprocessing"
validate#handlers:
- _target_: CheckpointLoader
  load_path: "$@ckpt_dir + '/model.pt'"
  load_dict:
    model: "@network"
- _target_: StatsHandler
  iteration_log: false
- _target_: MetricsSaver
  save_dir: "@output_dir"
  metrics:
  - val_mean_dice
  - val_acc
  metric_details:
  - val_mean_dice
  batch_transform: "$monai.handlers.from_engine(['image_meta_dict'])"
  summary_ops: "*"
evaluating:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- "$@validate#evaluator.run()"
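This file is not standalone: its `validate#...` keys override the matching sections of `train.yaml`. A minimal sketch (not the bundle runtime itself) of the merge that `--config_file "['configs/train.yaml','configs/evaluate.yaml']"` relies on, where later files win key by key:

```python
# Sketch: merge train.yaml + evaluate.yaml the way `monai.bundle run` does.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config(["configs/train.yaml", "configs/evaluate.yaml"])

# The merged validate#dataset is now the plain Dataset from evaluate.yaml,
# not the CacheDataset defined in train.yaml.
print(parser["validate#dataset"]["_target_"])  # -> "Dataset"
```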
configs/inference.yaml
ADDED
@@ -0,0 +1,116 @@
---
imports:
- "$import glob"
- "$import os"
input_channels: 1
output_classes: 3
arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"
bundle_root: "/workspace/MONAI/model-zoo/models/pancreas_ct_dints_segmentation"
output_dir: "$@bundle_root + '/eval'"
dataset_dir: "/workspace/data/msd/Task07_Pancreas"
data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"
datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='testing', base_dir=@dataset_dir)"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
dints_space:
  _target_: monai.networks.nets.TopologyInstance
  channel_mul: 1
  num_blocks: 12
  num_depths: 4
  use_downsample: true
  arch_code:
  - "$@arch_ckpt['arch_code_a']"
  - "$@arch_ckpt['arch_code_c']"
  device: "$torch.device('cuda')"
network_def:
  _target_: monai.networks.nets.DiNTS
  dints_space: "@dints_space"
  in_channels: "@input_channels"
  num_classes: "@output_classes"
  use_downsample: true
  node_a: "$torch.from_numpy(@arch_ckpt['node_a'])"
network: "$@network_def.to(@device)"
preprocessing:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: image
  - _target_: EnsureChannelFirstd
    keys: image
  - _target_: Orientationd
    keys: image
    axcodes: RAS
  - _target_: Spacingd
    keys: image
    pixdim:
    - 1
    - 1
    - 1
    mode: bilinear
  - _target_: ScaleIntensityRanged
    keys: image
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: EnsureTyped
    keys: image
dataset:
  _target_: Dataset
  data: "@datalist"
  transform: "@preprocessing"
dataloader:
  _target_: DataLoader
  dataset: "@dataset"
  batch_size: 1
  shuffle: false
  num_workers: 4
inferer:
  _target_: SlidingWindowInferer
  roi_size:
  - 96
  - 96
  - 96
  sw_batch_size: 4
  overlap: 0.625
postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: pred
    softmax: true
  - _target_: Invertd
    keys: pred
    transform: "@preprocessing"
    orig_keys: image
    meta_key_postfix: meta_dict
    nearest_interp: false
    to_tensor: true
  - _target_: AsDiscreted
    keys: pred
    argmax: true
  - _target_: SaveImaged
    keys: pred
    meta_keys: pred_meta_dict
    output_dir: "@output_dir"
handlers:
- _target_: CheckpointLoader
  load_path: "$@bundle_root + '/models/model.pt'"
  load_dict:
    model: "@network"
- _target_: StatsHandler
  iteration_log: false
evaluator:
  _target_: SupervisedEvaluator
  device: "@device"
  val_data_loader: "@dataloader"
  network: "@network"
  inferer: "@inferer"
  postprocessing: "@postprocessing"
  val_handlers: "@handlers"
  amp: true
evaluating:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- "$@evaluator.run()"
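A minimal functional sketch of the `SlidingWindowInferer` settings above, using the same `roi_size`/`sw_batch_size`/`overlap`; `network` and a preprocessed `image` tensor of shape `[1, 1, D, H, W]` are assumed to exist already (e.g. built via `ConfigParser` from this config):

```python
# Sketch: what the SlidingWindowInferer block does at inference time.
import torch
from monai.inferers import sliding_window_inference

with torch.no_grad():
    pred = sliding_window_inference(
        inputs=image,           # preprocessed CT volume, [B, C, D, H, W]
        roi_size=(96, 96, 96),  # patch size the model was trained on
        sw_batch_size=4,        # windows scored per forward pass
        predictor=network,
        overlap=0.625,          # large overlap smooths window seams
    )
# pred has 3 channels; argmax over the channel axis yields the label map
```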
configs/logging.conf
ADDED
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
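Passing `--logging_file configs/logging.conf` amounts to applying this INI file through the standard library; a minimal sketch:

```python
# Sketch: apply the [logger_root]/consoleHandler/fullFormatter sections above.
import logging
import logging.config

logging.config.fileConfig("configs/logging.conf", disable_existing_loggers=False)
logging.getLogger().info("logging configured")
# -> "<timestamp> - root - INFO - logging configured" on stdout
```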
configs/metadata.json
ADDED
@@ -0,0 +1,80 @@
{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
    "version": "0.1.0",
    "changelog": {
        "0.1.0": "complete the model package",
        "0.0.1": "initialize the model package structure"
    },
    "monai_version": "0.9.0",
    "pytorch_version": "1.12.0",
    "numpy_version": "1.21.2",
    "optional_packages_version": {
        "fire": "0.4.0",
        "nibabel": "3.2.1",
        "pytorch-ignite": "0.4.9"
    },
    "task": "Neural architecture search on pancreas CT segmentation",
    "description": "Searched architectures for volumetric (3D) segmentation of the pancreas from CT image",
    "authors": "MONAI team",
    "copyright": "Copyright (c) MONAI Consortium",
    "data_source": "Task07_Pancreas.tar from http://medicaldecathlon.com/",
    "data_type": "nibabel",
    "image_classes": "single channel data, intensity scaled to [0, 1]",
    "label_classes": "single channel data, 1 is pancreas, 2 is pancreatic tumor, 0 is everything else",
    "pred_classes": "3 channels OneHot data, channel 1 is pancreas, channel 2 is pancreatic tumor, channel 0 is background",
    "eval_metrics": {
        "mean_dice": 0.72
    },
    "intended_use": "This is an example, not to be used for diagnostic purposes",
    "references": [
        "He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. Dints: Differentiable neural network topology search for 3d medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850)."
    ],
    "network_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "hounsfield",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": [96, 96, 96],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": true,
                "channel_def": {
                    "0": "image"
                }
            }
        },
        "outputs": {
            "pred": {
                "type": "image",
                "format": "segmentation",
                "num_channels": 3,
                "spatial_shape": [96, 96, 96],
                "dtype": "float32",
                "value_range": [0, 1, 2],
                "is_patch_data": true,
                "channel_def": {
                    "0": "background",
                    "1": "pancreas",
                    "2": "pancreatic tumor"
                }
            }
        }
    }
}
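The metadata is plain JSON, so tooling can inspect it directly; a minimal sketch using only keys defined above:

```python
# Sketch: read the bundle metadata programmatically.
import json

with open("configs/metadata.json") as f:
    meta = json.load(f)

print(meta["version"], meta["monai_version"])  # 0.1.0 0.9.0
print(meta["eval_metrics"]["mean_dice"])       # 0.72
print(meta["network_data_format"]["outputs"]["pred"]["channel_def"])
# {'0': 'background', '1': 'pancreas', '2': 'pancreatic tumor'}
```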
configs/multi_gpu_train.yaml
ADDED
@@ -0,0 +1,52 @@
---
device: "$torch.device(f'cuda:{dist.get_rank()}')"
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: "$@network_def.to(@device)"
  find_unused_parameters: true
  device_ids:
  - "@device"
optimizer#lr: "$0.0125*dist.get_world_size()"
lr_scheduler#step_size: "$80*dist.get_world_size()"
train#handlers:
- _target_: LrScheduleHandler
  lr_scheduler: "@lr_scheduler"
  print_lr: true
- _target_: ValidationHandler
  validator: "@validate#evaluator"
  epoch_level: true
  interval: "$10*dist.get_world_size()"
- _target_: StatsHandler
  tag_name: train_loss
  output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
- _target_: TensorBoardStatsHandler
  log_dir: "@output_dir"
  tag_name: train_loss
  output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
train#trainer#max_epochs: "$400*dist.get_world_size()"
train#trainer#train_handlers: "$@train#handlers[: -2 if dist.get_rank() > 0 else None]"
validate#evaluator#val_handlers: "$None if dist.get_rank() > 0 else @validate#handlers"
training:
- "$import torch.distributed as dist"
- "$dist.init_process_group(backend='nccl')"
- "$torch.cuda.set_device(@device)"
- "$monai.utils.set_determinism(seed=123)"
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- "$@train#trainer.run()"
- "$dist.destroy_process_group()"
train_data_partition: "$monai.data.partition_dataset(data=@train_datalist, num_partitions=dist.get_world_size(), shuffle=True, even_divisible=True)[dist.get_rank()]"
train#dataset:
  _target_: CacheDataset
  data: "@train_data_partition"
  transform: "@train#preprocessing"
  cache_rate: 1
  num_workers: 4
val_data_partition: "$monai.data.partition_dataset(data=@val_datalist, num_partitions=dist.get_world_size(), shuffle=False, even_divisible=False)[dist.get_rank()]"
validate#dataset:
  _target_: CacheDataset
  data: "@val_data_partition"
  transform: "@validate#preprocessing"
  cache_rate: 1
  num_workers: 4
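Several hyperparameters above scale with world size because each rank sees only its own data partition per epoch; note that `0.0125 * 2` reproduces the single-GPU `lr: 0.025` of `train.yaml`. A trivial sketch of the scaling rules, evaluated outside `torch.distributed` for clarity (world_size would normally come from `dist.get_world_size()`):

```python
# Sketch: effective values of the world-size-scaled settings in this override.
for world_size in (1, 2, 8):
    effective = {
        "optimizer#lr": 0.0125 * world_size,
        "lr_scheduler#step_size": 80 * world_size,
        "validation interval (epochs)": 10 * world_size,
        "train#trainer#max_epochs": 400 * world_size,
    }
    print(world_size, effective)
```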
configs/search.yaml
ADDED
@@ -0,0 +1,275 @@
---
imports:
- "$from scipy import ndimage"
arch_ckpt_path: models
amp: true
data_file_base_dir: /workspace/data/msd/Task07_Pancreas
data_list_file_path: configs/dataset_0.json
determ: true
input_channels: 1
learning_rate: 0.025
learning_rate_arch: 0.001
learning_rate_milestones:
- 0.4
- 0.8
num_images_per_batch: 1
num_epochs: 1430
num_epochs_per_validation: 100
num_epochs_warmup: 715
num_patches_per_image: 1
num_sw_batch_size: 6
output_classes: 3
overlap_ratio: 0.625
patch_size:
- 96
- 96
- 96
patch_size_valid:
- 96
- 96
- 96
ram_cost_factor: 0.8
image_key: image
label_key: label
transform_train:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: EnsureChannelFirstd
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: Orientationd
    keys:
    - "@image_key"
    - "@label_key"
    axcodes: RAS
  - _target_: Spacingd
    keys:
    - "@image_key"
    - "@label_key"
    pixdim:
    - 1
    - 1
    - 1
    mode:
    - bilinear
    - nearest
    align_corners:
    - true
    - true
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  - _target_: ScaleIntensityRanged
    keys: "@image_key"
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$np.float16"
    - "$np.uint8"
  - _target_: CopyItemsd
    keys: "@label_key"
    times: 1
    names:
    - label4crop
  - _target_: Lambdad
    keys: label4crop
    func: "$lambda x, s=@output_classes: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(s)]), axis=0)"
    overwrite: true
  - _target_: EnsureTyped
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  - _target_: SpatialPadd
    keys:
    - "@image_key"
    - "@label_key"
    - label4crop
    spatial_size: "@patch_size"
    mode:
    - reflect
    - constant
    - constant
  - _target_: RandCropByLabelClassesd
    keys:
    - "@image_key"
    - "@label_key"
    label_key: label4crop
    num_classes: "@output_classes"
    ratios: "$[1,] * @output_classes"
    spatial_size: "@patch_size"
    num_samples: "@num_patches_per_image"
  - _target_: Lambdad
    keys: label4crop
    func: "$lambda x: 0"
  - _target_: RandRotated
    keys:
    - "@image_key"
    - "@label_key"
    range_x: 0.3
    range_y: 0.3
    range_z: 0.3
    mode:
    - bilinear
    - nearest
    prob: 0.2
  - _target_: RandZoomd
    keys:
    - "@image_key"
    - "@label_key"
    min_zoom: 0.8
    max_zoom: 1.2
    mode:
    - trilinear
    - nearest
    align_corners:
    - null
    - null
    prob: 0.16
  - _target_: RandGaussianSmoothd
    keys: "@image_key"
    sigma_x:
    - 0.5
    - 1.15
    sigma_y:
    - 0.5
    - 1.15
    sigma_z:
    - 0.5
    - 1.15
    prob: 0.15
  - _target_: RandScaleIntensityd
    keys: "@image_key"
    factors: 0.3
    prob: 0.5
  - _target_: RandShiftIntensityd
    keys: "@image_key"
    offsets: 0.1
    prob: 0.5
  - _target_: RandGaussianNoised
    keys: "@image_key"
    std: 0.01
    prob: 0.15
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 0
    prob: 0.5
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 1
    prob: 0.5
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 2
    prob: 0.5
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$torch.float32"
    - "$torch.uint8"
  - _target_: ToTensord
    keys:
    - "@image_key"
    - "@label_key"
transform_validation:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: EnsureChannelFirstd
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: Orientationd
    keys:
    - "@image_key"
    - "@label_key"
    axcodes: RAS
  - _target_: Spacingd
    keys:
    - "@image_key"
    - "@label_key"
    pixdim:
    - 1
    - 1
    - 1
    mode:
    - bilinear
    - nearest
    align_corners:
    - true
    - true
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  - _target_: ScaleIntensityRanged
    keys: "@image_key"
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$np.float16"
    - "$np.uint8"
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$torch.float32"
    - "$torch.uint8"
  - _target_: ToTensord
    keys:
    - "@image_key"
    - "@label_key"
loss:
  _target_: DiceCELoss
  include_background: false
  to_onehot_y: true
  softmax: true
  squared_pred: true
  batch: true
  smooth_nr: 0.00001
  smooth_dr: 0.00001
dints_space:
  _target_: monai.networks.nets.TopologySearch
  channel_mul: 0.5
  num_blocks: 12
  num_depths: 4
  use_downsample: true
  device: "$torch.device('cuda')"
network:
  _target_: monai.networks.nets.DiNTS
  dints_space: "@dints_space"
  in_channels: "@input_channels"
  num_classes: "@output_classes"
  use_downsample: true
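The `Lambdad` over `label4crop` above builds one dilated mask per class so `RandCropByLabelClassesd` can center crops near, not only on, each class. A toy sketch of that trick with scipy's `binary_dilation` (2 iterations here instead of the config's 48, to keep the toy volume small):

```python
# Sketch: per-class dilated masks as used for label4crop.
import numpy as np
from scipy import ndimage

output_classes = 3
x = np.zeros((1, 9, 9, 9), dtype=np.uint8)  # channel-first toy label volume
x[0, 4, 4, 4] = 1  # one "pancreas" voxel
x[0, 2, 2, 2] = 2  # one "tumor" voxel

label4crop = np.concatenate(
    tuple(
        ndimage.binary_dilation((x == _k).astype(x.dtype), iterations=2).astype(x.dtype)
        for _k in range(output_classes)
    ),
    axis=0,
)
print(label4crop.shape)                      # (3, 9, 9, 9): one mask per class
print((x == 1).sum(), label4crop[1].sum())   # dilated mask covers far more voxels
```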
configs/train.yaml
ADDED
@@ -0,0 +1,353 @@
---
imports:
- "$import glob"
- "$import json"
- "$import os"
- "$import ignite"
- "$from scipy import ndimage"
input_channels: 1
output_classes: 3
arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"
bundle_root: "/workspace/MONAI/model-zoo/models/pancreas_ct_dints_segmentation"
ckpt_dir: "$@bundle_root + '/models'"
output_dir: "$@bundle_root + '/eval'"
dataset_dir: "/workspace/data/msd/Task07_Pancreas"
data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"
train_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)"
val_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
dints_space:
  _target_: monai.networks.nets.TopologyInstance
  channel_mul: 1
  num_blocks: 12
  num_depths: 4
  use_downsample: true
  arch_code:
  - "$@arch_ckpt['arch_code_a']"
  - "$@arch_ckpt['arch_code_c']"
  device: "$torch.device('cuda')"
network_def:
  _target_: monai.networks.nets.DiNTS
  dints_space: "@dints_space"
  in_channels: "@input_channels"
  num_classes: "@output_classes"
  use_downsample: true
  node_a: "$@arch_ckpt['node_a']"
network: "$@network_def.to(@device)"
loss:
  _target_: DiceCELoss
  include_background: false
  to_onehot_y: true
  softmax: true
  squared_pred: true
  batch: true
  smooth_nr: 1.0e-05
  smooth_dr: 1.0e-05
optimizer:
  _target_: torch.optim.SGD
  params: "$@network.parameters()"
  momentum: 0.9
  weight_decay: 4.0e-05
  lr: 0.025
lr_scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  optimizer: "@optimizer"
  step_size: 80
  gamma: 0.5
image_key: image
label_key: label
train:
  deterministic_transforms:
  - _target_: LoadImaged
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: EnsureChannelFirstd
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: Orientationd
    keys:
    - "@image_key"
    - "@label_key"
    axcodes: RAS
  - _target_: Spacingd
    keys:
    - "@image_key"
    - "@label_key"
    pixdim:
    - 1
    - 1
    - 1
    mode:
    - bilinear
    - nearest
    align_corners:
    - true
    - true
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  - _target_: ScaleIntensityRanged
    keys: "@image_key"
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$np.float16"
    - "$np.uint8"
  - _target_: CopyItemsd
    keys: "@label_key"
    times: 1
    names:
    - label4crop
  - _target_: Lambdad
    keys: label4crop
    func: "$lambda x, s=@output_classes: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(s)]), axis=0)"
    overwrite: true
  - _target_: EnsureTyped
    keys:
    - "@image_key"
    - "@label_key"
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  - _target_: SpatialPadd
    keys:
    - "@image_key"
    - "@label_key"
    - label4crop
    spatial_size:
    - 96
    - 96
    - 96
    mode:
    - reflect
    - constant
    - constant
  random_transforms:
  - _target_: RandCropByLabelClassesd
    keys:
    - "@image_key"
    - "@label_key"
    label_key: label4crop
    num_classes: "@output_classes"
    ratios: "$[1,] * @output_classes"
    spatial_size:
    - 96
    - 96
    - 96
    num_samples: 1
  - _target_: Lambdad
    keys: label4crop
    func: "$lambda x: 0"
  - _target_: RandRotated
    keys:
    - "@image_key"
    - "@label_key"
    range_x: 0.3
    range_y: 0.3
    range_z: 0.3
    mode:
    - bilinear
    - nearest
    prob: 0.2
  - _target_: RandZoomd
    keys:
    - "@image_key"
    - "@label_key"
    min_zoom: 0.8
    max_zoom: 1.2
    mode:
    - trilinear
    - nearest
    align_corners:
    - true
    -
    prob: 0.16
  - _target_: RandGaussianSmoothd
    keys: "@image_key"
    sigma_x:
    - 0.5
    - 1.15
    sigma_y:
    - 0.5
    - 1.15
    sigma_z:
    - 0.5
    - 1.15
    prob: 0.15
  - _target_: RandScaleIntensityd
    keys: "@image_key"
    factors: 0.3
    prob: 0.5
  - _target_: RandShiftIntensityd
    keys: "@image_key"
    offsets: 0.1
    prob: 0.5
  - _target_: RandGaussianNoised
    keys: "@image_key"
    std: 0.01
    prob: 0.15
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 0
    prob: 0.5
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 1
    prob: 0.5
  - _target_: RandFlipd
    keys:
    - "@image_key"
    - "@label_key"
    spatial_axis: 2
    prob: 0.5
  - _target_: CastToTyped
    keys:
    - "@image_key"
    - "@label_key"
    dtype:
    - "$torch.float32"
    - "$torch.uint8"
  - _target_: ToTensord
    keys:
    - "@image_key"
    - "@label_key"
  preprocessing:
    _target_: Compose
    transforms: "$@train#deterministic_transforms + @train#random_transforms"
  dataset:
    _target_: CacheDataset
    data: "@train_datalist"
    transform: "@train#preprocessing"
    cache_rate: 0.125
    num_workers: 4
  dataloader:
    _target_: DataLoader
    dataset: "@train#dataset"
    batch_size: 2
    shuffle: true
    num_workers: 4
  inferer:
    _target_: SimpleInferer
  postprocessing:
    _target_: Compose
    transforms:
    - _target_: Activationsd
      keys: pred
      softmax: true
    - _target_: AsDiscreted
      keys:
      - pred
      - label
      argmax:
      - true
      - false
      to_onehot: "@output_classes"
  handlers:
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: 10
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  key_metric:
    train_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
  trainer:
    _target_: SupervisedTrainer
    max_epochs: 400
    device: "@device"
    train_data_loader: "@train#dataloader"
    network: "@network"
    loss_function: "@loss"
    optimizer: "@optimizer"
    inferer: "@train#inferer"
    postprocessing: "@train#postprocessing"
    key_train_metric: "@train#key_metric"
    train_handlers: "@train#handlers"
    amp: true
validate:
  preprocessing:
    _target_: Compose
    transforms: "%train#deterministic_transforms"
  dataset:
    _target_: CacheDataset
    data: "@val_datalist"
    transform: "@validate#preprocessing"
    cache_rate: 0.125
  dataloader:
    _target_: DataLoader
    dataset: "@validate#dataset"
    batch_size: 1
    shuffle: false
    num_workers: 4
  inferer:
    _target_: SlidingWindowInferer
    roi_size:
    - 96
    - 96
    - 96
    sw_batch_size: 6
    overlap: 0.625
  postprocessing: "%train#postprocessing"
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    iteration_log: false
  - _target_: CheckpointSaver
    save_dir: "@ckpt_dir"
    save_dict:
      model: "@network"
    save_key_metric: true
    key_metric_filename: model.pt
  key_metric:
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
  additional_metrics:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@validate#dataloader"
    network: "@network"
    inferer: "@validate#inferer"
    postprocessing: "@validate#postprocessing"
    key_val_metric: "@validate#key_metric"
    additional_metrics: "@validate#additional_metrics"
    val_handlers: "@validate#handlers"
    amp: true
training:
- "$monai.utils.set_determinism(seed=123)"
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- "$@train#trainer.run()"
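A minimal sketch of what the `loss` section above instantiates, built directly for clarity. The forward call expects logits of shape `[B, 3, D, H, W]` and an integer label map `[B, 1, D, H, W]` (`to_onehot_y=True` expands it):

```python
# Sketch: the DiceCELoss configured in train.yaml, constructed by hand.
import torch
from monai.losses import DiceCELoss

loss_fn = DiceCELoss(
    include_background=False,  # ignore channel 0 (background) in the Dice term
    to_onehot_y=True,
    softmax=True,
    squared_pred=True,
    batch=True,                # accumulate Dice statistics over the whole batch
    smooth_nr=1e-5,
    smooth_dr=1e-5,
)

logits = torch.randn(2, 3, 96, 96, 96)
labels = torch.randint(0, 3, (2, 1, 96, 96, 96))
print(loss_fn(logits, labels))  # scalar Dice + cross-entropy loss
```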
docs/README.md
ADDED
@@ -0,0 +1,72 @@
# Description
A neural architecture search algorithm for volumetric (3D) segmentation of the pancreas and pancreatic tumor from CT images.

# Model Overview
This model is trained with the state-of-the-art algorithm [1] from the "Medical Segmentation Decathlon Challenge 2018", using 196 training images, 56 validation images, and 28 testing images.

## Data
The training dataset is Task07_Pancreas.tar from http://medicaldecathlon.com/. The data list/split can be created with the script `scripts/prepare_datalist.py`.

## Training configuration
The training was performed with GPUs that have at least 16 GB of memory.

Actual model input: 96 x 96 x 96

## Input and output formats
Input: 1-channel CT image

Output: 3 channels: Label 2: pancreatic tumor; Label 1: pancreas; Label 0: everything else

## Scores
This model achieves the following Dice score on the validation data (our own split of the training dataset):

Mean Dice = 0.72

## Commands example
Create the data split (.json file):

```
python scripts/prepare_datalist.py --path /path-to-Task07_Pancreas/ --output configs/dataset_0.json
```

Execute model searching:

```
python -m scripts.search run --config_file configs/search.yaml
```

Execute multi-GPU model searching (recommended):

```
torchrun --nnodes=1 --nproc_per_node=8 -m scripts.search run --config_file configs/search.yaml
```

Execute training:

```
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.yaml --logging_file configs/logging.conf
```

Override the `train` config to execute multi-GPU training:

```
torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']" --logging_file configs/logging.conf
```

Override the `train` config to execute evaluation with the trained model:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.yaml','configs/evaluate.yaml']" --logging_file configs/logging.conf
```

Execute inference:

```
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.yaml --logging_file configs/logging.conf
```

# Disclaimer
This is an example, not to be used for diagnostic purposes.

# References
[1] He, Y., Yang, D., Roth, H., Zhao, C. and Xu, D., 2021. DiNTS: Differentiable neural network topology search for 3D medical image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 5841-5850).
docs/license.txt
ADDED
@@ -0,0 +1,6 @@
Third Party Licenses
-----------------------------------------------------------------------

/*********************************************************************/
i. Medical Segmentation Decathlon
     http://medicaldecathlon.com/
models/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:975201057eb16225a1abfd047cb9b293f6d481dc604468d512710f3543f29066
size 616210421
models/model.ts
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:526d2bdb4d88f6f55f2d88eb0c79deeeb90b3ced182269e33cdc5da6e46ea5fb
size 616338455
models/search_code_18590.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01e361e9843e2f4e5ff1599da0abac77013ea38cab8fdd6c9286bb6572c9a32d
size 4335
scripts/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
scripts/prepare_datalist.py
ADDED
@@ -0,0 +1,59 @@
import argparse
import glob
import json
import os

import monai
from sklearn.model_selection import train_test_split


def produce_sample_dict(line: str):
    return {"label": line, "image": line.replace("labelsTr", "imagesTr")}


def produce_datalist(dataset_dir: str):
    """
    This function is used to split the dataset.
    It will produce 196 samples for training; the remaining samples are split
    into validation and test sets in a roughly 2:1 ratio.
    """

    samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
    samples = [_item.replace(os.path.join(dataset_dir, "labelsTr"), "labelsTr") for _item in samples]
    datalist = []
    for line in samples:
        datalist.append(produce_sample_dict(line))
    train_list, other_list = train_test_split(datalist, train_size=196)
    val_list, test_list = train_test_split(other_list, train_size=0.66)

    return {"training": train_list, "validation": val_list, "testing": test_list}


def main(args):
    """
    Split the dataset and output the data list into a json file.
    """
    data_file_base_dir = args.path
    output_json = args.output
    # produce deterministic data splits
    monai.utils.set_determinism(seed=123)
    datalist = produce_datalist(dataset_dir=data_file_base_dir)
    with open(output_json, "w") as f:
        json.dump(datalist, f, ensure_ascii=True, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--path",
        type=str,
        default="/workspace/data/msd/Task07_Pancreas",
        help="root path of MSD Task07_Pancreas dataset.",
    )
    parser.add_argument(
        "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
    )
    args = parser.parse_args()

    main(args)
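An illustrative sketch of the JSON this script writes (file names are hypothetical examples, not real case IDs); the top-level keys match the `data_list_key` values the configs pass to `load_decathlon_datalist`:

```python
# Sketch: shape of the generated dataset_0.json.
example_output = {
    "training": [
        {"label": "labelsTr/pancreas_001.nii.gz", "image": "imagesTr/pancreas_001.nii.gz"},
        # ... 196 entries in total
    ],
    "validation": [
        # ~2/3 of the remaining cases (56 in the split described in the README)
    ],
    "testing": [
        # the rest (28 in the split described in the README)
    ],
}
```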
scripts/search.py
ADDED
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
import json
|
13 |
+
import logging
|
14 |
+
import os
|
15 |
+
import random
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
from datetime import datetime
|
19 |
+
from typing import Sequence, Union
|
20 |
+
|
21 |
+
import monai
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import torch.distributed as dist
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import yaml
|
27 |
+
from monai import transforms
|
28 |
+
from monai.bundle import ConfigParser
|
29 |
+
from monai.data import ThreadDataLoader, partition_dataset
|
30 |
+
from monai.inferers import sliding_window_inference
|
31 |
+
from monai.metrics import compute_meandice
|
32 |
+
from monai.utils import set_determinism
|
33 |
+
from torch.nn.parallel import DistributedDataParallel
|
34 |
+
from torch.utils.tensorboard import SummaryWriter
|
35 |
+
|
36 |
+
|
37 |
+
def run(config_file: Union[str, Sequence[str]]):
|
38 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
39 |
+
|
40 |
+
parser = ConfigParser()
|
41 |
+
parser.read_config(config_file)
|
42 |
+
|
43 |
+
arch_ckpt_path = parser["arch_ckpt_path"]
|
44 |
+
amp = parser["amp"]
|
45 |
+
data_file_base_dir = parser["data_file_base_dir"]
|
46 |
+
data_list_file_path = parser["data_list_file_path"]
|
47 |
+
determ = parser["determ"]
|
48 |
+
learning_rate = parser["learning_rate"]
|
49 |
+
learning_rate_arch = parser["learning_rate_arch"]
|
50 |
+
learning_rate_milestones = np.array(parser["learning_rate_milestones"])
|
51 |
+
num_images_per_batch = parser["num_images_per_batch"]
|
52 |
+
num_epochs = parser["num_epochs"] # around 20k iterations
|
53 |
+
num_epochs_per_validation = parser["num_epochs_per_validation"]
|
54 |
+
num_epochs_warmup = parser["num_epochs_warmup"]
|
55 |
+
num_sw_batch_size = parser["num_sw_batch_size"]
|
56 |
+
output_classes = parser["output_classes"]
|
57 |
+
overlap_ratio = parser["overlap_ratio"]
|
58 |
+
patch_size_valid = parser["patch_size_valid"]
|
59 |
+
ram_cost_factor = parser["ram_cost_factor"]
|
60 |
+
print("[info] GPU RAM cost factor:", ram_cost_factor)
|
61 |
+
|
62 |
+
train_transforms = parser.get_parsed_content("transform_train")
|
63 |
+
val_transforms = parser.get_parsed_content("transform_validation")
|
64 |
+
|
65 |
+
# deterministic training
|
66 |
+
if determ:
|
67 |
+
set_determinism(seed=0)
|
68 |
+
|
69 |
+
print("[info] number of GPUs:", torch.cuda.device_count())
|
70 |
+
if torch.cuda.device_count() > 1:
|
71 |
+
# initialize the distributed training process, every GPU runs in a process
|
72 |
+
dist.init_process_group(backend="nccl", init_method="env://")
|
73 |
+
world_size = dist.get_world_size()
|
74 |
+
else:
|
75 |
+
world_size = 1
|
76 |
+
print("[info] world_size:", world_size)
|
77 |
+
|
78 |
+
with open(data_list_file_path, "r") as f:
|
79 |
+
json_data = json.load(f)
|
80 |
+
|
81 |
+
list_train = json_data["training"]
|
82 |
+
list_valid = json_data["validation"]
|
83 |
+
|
84 |
+
# training data
|
85 |
+
files = []
|
86 |
+
for _i in range(len(list_train)):
|
87 |
+
str_img = os.path.join(data_file_base_dir, list_train[_i]["image"])
|
88 |
+
str_seg = os.path.join(data_file_base_dir, list_train[_i]["label"])
|
89 |
+
|
90 |
+
if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
|
91 |
+
continue
|
92 |
+
|
93 |
+
files.append({"image": str_img, "label": str_seg})
|
94 |
+
train_files = files
|
95 |
+
|
96 |
+
random.shuffle(train_files)
|
97 |
+
|
98 |
+
train_files_w = train_files[: len(train_files) // 2]
|
99 |
+
if torch.cuda.device_count() > 1:
|
100 |
+
train_files_w = partition_dataset(
|
101 |
+
data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
|
102 |
+
)[dist.get_rank()]
|
103 |
+
print("train_files_w:", len(train_files_w))
|
104 |
+
|
105 |
+
train_files_a = train_files[len(train_files) // 2 :]
|
106 |
+
if torch.cuda.device_count() > 1:
|
107 |
+
train_files_a = partition_dataset(
|
108 |
+
data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
|
109 |
+
)[dist.get_rank()]
|
110 |
+
print("train_files_a:", len(train_files_a))
|
111 |
+
|
112 |
+
# validation data
|
113 |
+
files = []
|
114 |
+
for _i in range(len(list_valid)):
|
115 |
+
str_img = os.path.join(data_file_base_dir, list_valid[_i]["image"])
|
116 |
+
str_seg = os.path.join(data_file_base_dir, list_valid[_i]["label"])
|
117 |
+
|
118 |
+
if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
|
119 |
+
continue
|
120 |
+
|
121 |
+
files.append({"image": str_img, "label": str_seg})
|
122 |
+
val_files = files
|
123 |
+
|
124 |
+
if torch.cuda.device_count() > 1:
|
125 |
+
val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
|
126 |
+
dist.get_rank()
|
127 |
+
]
|
128 |
+
print("val_files:", len(val_files))
|
129 |
+
|
130 |
+
# network architecture
|
131 |
+
if torch.cuda.device_count() > 1:
|
132 |
+
device = torch.device(f"cuda:{dist.get_rank()}")
|
133 |
+
else:
|
134 |
+
device = torch.device("cuda:0")
|
135 |
+
torch.cuda.set_device(device)
|
136 |
+
|
137 |
+
if torch.cuda.device_count() > 1:
|
138 |
+
train_ds_a = monai.data.CacheDataset(
|
139 |
+
data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8
|
140 |
+
)
|
141 |
+
train_ds_w = monai.data.CacheDataset(
|
142 |
+
data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8
|
143 |
+
)
|
144 |
+
val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
|
145 |
+
else:
|
146 |
+
train_ds_a = monai.data.CacheDataset(
|
147 |
+
data=train_files_a, transform=train_transforms, cache_rate=0.125, num_workers=8
|
148 |
+
)
|
149 |
+
train_ds_w = monai.data.CacheDataset(
|
150 |
+
data=train_files_w, transform=train_transforms, cache_rate=0.125, num_workers=8
|
151 |
+
)
|
152 |
+
val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.125, num_workers=2)
|
153 |
+
|
154 |
+
train_loader_a = ThreadDataLoader(train_ds_a, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
|
155 |
+
train_loader_w = ThreadDataLoader(train_ds_w, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
|
156 |
+
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
|
157 |
+
|
158 |
+
model = parser.get_parsed_content("network")
|
159 |
+
dints_space = parser.get_parsed_content("dints_space")
|
160 |
+
|
161 |
+
model = model.to(device)
|
162 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
163 |
+
|
164 |
+
post_pred = transforms.Compose(
|
165 |
+
[transforms.EnsureType(), transforms.AsDiscrete(argmax=True, to_onehot=output_classes)]
|
166 |
+
)
|
167 |
+
post_label = transforms.Compose([transforms.EnsureType(), transforms.AsDiscrete(to_onehot=output_classes)])
|
168 |
+
|
169 |
+
# loss function
|
170 |
+
loss_func = parser.get_parsed_content("loss")
|
171 |
+
|
172 |
+
# optimizer
|
173 |
+
optimizer = torch.optim.SGD(
|
174 |
+
model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
|
175 |
+
)
|
176 |
+
arch_optimizer_a = torch.optim.Adam(
|
177 |
+
[dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
|
178 |
+
)
|
179 |
+
arch_optimizer_c = torch.optim.Adam(
|
180 |
+
[dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
|
181 |
+
)
|
182 |
+
|
183 |
+
if torch.cuda.device_count() > 1:
|
184 |
+
model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)
|
185 |
+
|
186 |
+
# amp
|
187 |
+
if amp:
|
188 |
+
from torch.cuda.amp import GradScaler, autocast
|
189 |
+
|
190 |
+
scaler = GradScaler()
|
191 |
+
if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
|
192 |
+
print("[info] amp enabled")
|
193 |
+
|
194 |
+
# start a typical PyTorch training
|
195 |
+
val_interval = num_epochs_per_validation
|
196 |
+
best_metric = -1
|
197 |
+
best_metric_epoch = -1
|
198 |
+
idx_iter = 0
|
199 |
+
|
200 |
+
if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
|
201 |
+
writer = SummaryWriter(log_dir=os.path.join(arch_ckpt_path, "Events"))
|
202 |
+
|
203 |
+
with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
|
204 |
+
f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")
|
205 |
+
|
206 |
+
dataloader_a_iterator = iter(train_loader_a)
|
207 |
+
|
208 |
+
start_time = time.time()
|
209 |
+
for epoch in range(num_epochs):
|
210 |
+
decay = 0.5 ** np.sum(
|
211 |
+
[(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
|
212 |
+
)
|
213 |
+
lr = learning_rate * decay * world_size
|
214 |
+
for param_group in optimizer.param_groups:
|
215 |
+
param_group["lr"] = lr
|
216 |
+
|
217 |
+
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
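            # weight update: train the supernet weights with the architecture
            # logits frozen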
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            # during warmup, skip the architecture update below
            if epoch < num_epochs_warmup:
                continue

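            # architecture update: fetch the next batch from the arch split,
            # restarting its iterator once exhausted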
            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = (sample_a["image"].to(device), sample_a["label"].to(device))
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

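            # Auxiliary search losses: the two entropy terms push the connection
            # and operation probabilities toward decisive (near one-hot) choices,
            # the RAM term penalizes deviation of predicted memory usage from the
            # budget fraction `ram_cost_factor`, and the topology entropy favors
            # feasible connectivity patterns.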
            # topology and RAM losses, linearly ramped in via combination_weights below
            entropy_alpha_c = torch.tensor(0.0).to(device)
            entropy_alpha_a = torch.tensor(0.0).to(device)
            ram_cost_full = torch.tensor(0.0).to(device)
            ram_cost_usage = torch.tensor(0.0).to(device)
            ram_cost_loss = torch.tensor(0.0).to(device)
            topology_loss = torch.tensor(0.0).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(
                F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
            ).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(ram_cost_factor - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

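            # combination_weights ramps linearly from 0 to 1 across the
            # post-warmup epochs, phasing the regularizers in gradually.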
            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * (
                        (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                    )

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                # same objective as the amp branch above
                loss += combination_weights * (
                    (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                )

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19])
                    + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
                )
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # synchronize all processes and reduce the loss accumulators
        if torch.cuda.device_count() > 1:
            dist.barrier()
            dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
            dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)

        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

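        # validation: sliding-window inference over the full volume with Gaussian
        # blending, then per-class Dice aggregated across ranks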
        if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

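                    # compute_meandice returns one score per foreground class
                    # (shape [1, output_classes - 1]); a score can be NaN when
                    # the class is absent from the ground truth, which the
                    # masking below accounts for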
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        # mask out NaN scores (class absent in ground truth)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronize all processes and reduce results
                if torch.cuda.device_count() > 1:
                    dist.barrier()
                    dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)

                metric = metric.tolist()
                if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

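                    # decode() discretizes the softmax-relaxed search space into
                    # final architecture codes; a checkpoint keyed by the
                    # iteration count is written at every validation round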
                    (node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d) = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(arch_ckpt_path, "search_code_" + str(idx_iter) + ".pt"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(arch_ckpt_path, "progress.yaml"), "w") as out_file:
                        _ = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, avg_metric, best_metric, best_metric_epoch
                        )
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

            if torch.cuda.device_count() > 1:
                dist.barrier()

        torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
        writer.close()

    if torch.cuda.device_count() > 1:
        dist.destroy_process_group()

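# python-fire exposes the module's public functions as command-line subcommands,
# which is how the search entry point is invoked.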
if __name__ == "__main__":
    from monai.utils import optional_import

    fire, _ = optional_import("fire")
    fire.Fire()