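# NOTE(review): the eight lines below look like web-page header residue rather
# than configuration; they are kept verbatim as comments so that the file
# parses as a single YAML document.
# monai
# medical
# katielink's picture
# re-train model with updated dints implementation
# 49f392e
# raw
# history blame
# 8.24 kB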
---
imports:
- "$import glob"
- "$import json"
- "$import os"
- "$import ignite"
- "$from scipy import ndimage"
# Global bundle settings referenced throughout this file via "@<name>".

# Channel count of the network input and number of predicted classes
# (the latter is also used by the class-balanced crop and one-hot transforms).
input_channels: 1
output_classes: 3

# Checkpoint holding the searched architecture. The entries read from it in
# this file are 'arch_code_a', 'arch_code_c' and 'node_a'.
arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
# Map the checkpoint onto @device rather than a hard-coded CUDA device so the
# CPU fallback declared in `device` below also applies when loading it.
# NOTE(review): torch.load unpickles the file -- only point this at trusted
# checkpoints.
arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=@device)"

# Filesystem layout: everything is derived from `bundle_root` except the
# dataset location.
bundle_root: "."
ckpt_dir: "$@bundle_root + '/models'"
output_dir: "$@bundle_root + '/eval'"
dataset_dir: "/workspace/data/msd/Task07_Pancreas"
data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"

# Training / validation file lists read from the same datalist JSON, with
# paths resolved relative to `dataset_dir`.
train_datalist: >-
  $monai.data.load_decathlon_datalist(@data_list_file_path,
  data_list_key='training', base_dir=@dataset_dir)
val_datalist: >-
  $monai.data.load_decathlon_datalist(@data_list_file_path,
  data_list_key='validation', base_dir=@dataset_dir)

# First CUDA device when available, otherwise CPU.
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
# Topology built from the two architecture codes stored in the checkpoint;
# it is handed to DiNTS as its `dints_space` argument.
dints_space:
  _target_: monai.networks.nets.TopologyInstance
  channel_mul: 1
  num_blocks: 12
  num_depths: 4
  use_downsample: true
  arch_code:
  - "$@arch_ckpt['arch_code_a']"
  - "$@arch_ckpt['arch_code_c']"
  # Follow the top-level `device` setting (which falls back to CPU when CUDA
  # is unavailable) instead of hard-coding CUDA, so the topology lives on the
  # same device the network is moved to.
  device: "@device"
# DiNTS segmentation network assembled from the `dints_space` topology and
# the 'node_a' entry of the architecture checkpoint.
network_def:
  _target_: monai.networks.nets.DiNTS
  in_channels: "@input_channels"
  num_classes: "@output_classes"
  dints_space: "@dints_space"
  node_a: "$@arch_ckpt['node_a']"
  use_downsample: true

# The instantiated network, moved onto the configured device.
network: "$@network_def.to(@device)"
# Training loss (MONAI DiceCELoss). Every key other than `_target_` is passed
# to the constructor; see the DiceCELoss API reference for their meaning.
loss:
  _target_: DiceCELoss
  # How predictions and targets are prepared before the loss is computed.
  softmax: true
  to_onehot_y: true
  include_background: false
  # Dice formulation options.
  squared_pred: true
  batch: true
  smooth_nr: 1.0e-05
  smooth_dr: 1.0e-05
# SGD with momentum and weight decay over all parameters of `network`.
optimizer:
  _target_: torch.optim.SGD
  params: "$@network.parameters()"
  lr: 0.025
  momentum: 0.9
  weight_decay: 4.0e-05
# Step decay schedule attached to `optimizer`: the learning rate is scaled by
# `gamma` once every `step_size` scheduler steps.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  optimizer: "@optimizer"
  gamma: 0.5
  step_size: 80
# Dictionary keys the transforms use to address the image and its label.
image_key: "image"
label_key: "label"
# Used as the `interval` of the epoch-level ValidationHandler.
val_interval: 10
# Training side of the workflow: preprocessing, data loading, inference,
# postprocessing, event handlers, metrics and the trainer itself.
train:
  # Non-random preprocessing. The validation section reuses this exact list
  # through the "%train#deterministic_transforms" macro.
  deterministic_transforms:
  - _target_: LoadImaged
    keys: ["@image_key", "@label_key"]
  - _target_: EnsureChannelFirstd
    keys: ["@image_key", "@label_key"]
  - _target_: Orientationd
    keys: ["@image_key", "@label_key"]
    axcodes: RAS
  # Resample to pixdim 1 x 1 x 1: bilinear for the image, nearest for the
  # label.
  - _target_: Spacingd
    keys: ["@image_key", "@label_key"]
    pixdim: [1, 1, 1]
    mode: [bilinear, nearest]
    align_corners: [true, true]
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  # Clip intensities to [-87, 199] and rescale them to [0, 1].
  - _target_: ScaleIntensityRanged
    keys: "@image_key"
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: CastToTyped
    keys: ["@image_key", "@label_key"]
    dtype: ["$np.float16", "$np.uint8"]
  # Copy the label to `label4crop`, then replace that copy with one binary
  # mask per class, each dilated for 48 iterations and concatenated on
  # axis 0. The result is the `label_key` of the random crop further down.
  - _target_: CopyItemsd
    keys: "@label_key"
    times: 1
    names: [label4crop]
  - _target_: Lambdad
    keys: label4crop
    func: >-
      $lambda x, s=@output_classes:
      np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype),
      iterations=48).astype(float) for _k in range(s)]), axis=0)
    overwrite: true
  - _target_: EnsureTyped
    keys: ["@image_key", "@label_key"]
  - _target_: CastToTyped
    keys: "@image_key"
    dtype: "$torch.float32"
  # Pad to at least the 96 x 96 x 96 patch size used for cropping.
  - _target_: SpatialPadd
    keys: ["@image_key", "@label_key", label4crop]
    spatial_size: [96, 96, 96]
    mode: [reflect, constant, constant]

  # Random cropping and augmentation, applied for training only.
  random_transforms:
  # One 96 x 96 x 96 patch per sample, with equal ratios for every class,
  # guided by the `label4crop` masks.
  - _target_: RandCropByLabelClassesd
    keys: ["@image_key", "@label_key"]
    label_key: label4crop
    num_classes: "@output_classes"
    ratios: "$[1,] * @output_classes"
    spatial_size: [96, 96, 96]
    num_samples: 1
  # `label4crop` is not referenced again after the crop; its value is mapped
  # to 0 here (presumably to release the large mask -- confirm).
  - _target_: Lambdad
    keys: label4crop
    func: "$lambda x: 0"
  - _target_: RandRotated
    keys: ["@image_key", "@label_key"]
    range_x: 0.3
    range_y: 0.3
    range_z: 0.3
    mode: [bilinear, nearest]
    prob: 0.2
  - _target_: RandZoomd
    keys: ["@image_key", "@label_key"]
    min_zoom: 0.8
    max_zoom: 1.2
    mode: [trilinear, nearest]
    # Second entry is deliberately null (no align_corners for the
    # nearest-mode label).
    align_corners: [true, null]
    prob: 0.16
  - _target_: RandGaussianSmoothd
    keys: "@image_key"
    sigma_x: [0.5, 1.15]
    sigma_y: [0.5, 1.15]
    sigma_z: [0.5, 1.15]
    prob: 0.15
  - _target_: RandScaleIntensityd
    keys: "@image_key"
    factors: 0.3
    prob: 0.5
  - _target_: RandShiftIntensityd
    keys: "@image_key"
    offsets: 0.1
    prob: 0.5
  - _target_: RandGaussianNoised
    keys: "@image_key"
    std: 0.01
    prob: 0.15
  # Independent random flips along each of the three spatial axes.
  - _target_: RandFlipd
    keys: ["@image_key", "@label_key"]
    spatial_axis: 0
    prob: 0.5
  - _target_: RandFlipd
    keys: ["@image_key", "@label_key"]
    spatial_axis: 1
    prob: 0.5
  - _target_: RandFlipd
    keys: ["@image_key", "@label_key"]
    spatial_axis: 2
    prob: 0.5
  - _target_: CastToTyped
    keys: ["@image_key", "@label_key"]
    dtype: ["$torch.float32", "$torch.uint8"]
  - _target_: ToTensord
    keys: ["@image_key", "@label_key"]

  # Full training pipeline: deterministic list followed by the random list.
  preprocessing:
    _target_: Compose
    transforms: "$@train#deterministic_transforms + @train#random_transforms"

  dataset:
    _target_: CacheDataset
    data: "@train_datalist"
    transform: "@train#preprocessing"
    cache_rate: 0.125
    num_workers: 4

  dataloader:
    _target_: DataLoader
    dataset: "@train#dataset"
    batch_size: 2
    shuffle: true
    num_workers: 4

  inferer:
    _target_: SimpleInferer

  # Softmax on the prediction, then argmax on `pred` only and one-hot
  # encoding of both `pred` and `label`. Validation reuses this block via
  # "%train#postprocessing".
  postprocessing:
    _target_: Compose
    transforms:
    - _target_: Activationsd
      keys: pred
      softmax: true
    - _target_: AsDiscreted
      keys: [pred, label]
      argmax: [true, false]
      to_onehot: "@output_classes"

  handlers:
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true
  # Runs the validation evaluator every `val_interval` epochs.
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: "@val_interval"
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"

  key_metric:
    train_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"

  # Wires all of the pieces above into the supervised training loop.
  trainer:
    _target_: SupervisedTrainer
    max_epochs: 400
    device: "@device"
    train_data_loader: "@train#dataloader"
    network: "@network"
    loss_function: "@loss"
    optimizer: "@optimizer"
    inferer: "@train#inferer"
    postprocessing: "@train#postprocessing"
    key_train_metric: "@train#key_metric"
    train_handlers: "@train#handlers"
    amp: true
# Validation side of the workflow; its evaluator is triggered from the
# ValidationHandler in the training handlers.
validate:
  # Same deterministic transforms as training, pulled in with a macro.
  preprocessing:
    _target_: Compose
    transforms: "%train#deterministic_transforms"

  dataset:
    _target_: CacheDataset
    data: "@val_datalist"
    transform: "@validate#preprocessing"
    cache_rate: 0.125

  dataloader:
    _target_: DataLoader
    dataset: "@validate#dataset"
    batch_size: 1
    shuffle: false
    num_workers: 4

  # Sliding-window inference with 96 x 96 x 96 windows.
  inferer:
    _target_: SlidingWindowInferer
    roi_size: [96, 96, 96]
    sw_batch_size: 6
    overlap: 0.625

  # Same postprocessing as training, pulled in with a macro.
  postprocessing: "%train#postprocessing"

  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    iteration_log: false
  # Saves the network as model.pt in `ckpt_dir`, keyed on the key metric.
  - _target_: CheckpointSaver
    save_dir: "@ckpt_dir"
    save_dict:
      model: "@network"
    save_key_metric: true
    key_metric_filename: model.pt

  key_metric:
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"

  additional_metrics:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"

  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@validate#dataloader"
    network: "@network"
    inferer: "@validate#inferer"
    postprocessing: "@validate#postprocessing"
    key_val_metric: "@validate#key_metric"
    additional_metrics: "@validate#additional_metrics"
    val_handlers: "@validate#handlers"
    amp: true
# Setup expressions: fix the random seed (123) for determinism.
initialize: ["$monai.utils.set_determinism(seed=123)"]
# Entry point: start the training loop of the trainer defined under `train`.
run: ["$@train#trainer.run()"]