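# NOTE: the next eight lines are web-page residue captured together with the
# file; they are not part of the configuration and are kept only as comments.
# monai
# medical
# katielink's picture
# re-train model with updated dints implementation
# 49f392e
# raw
# history blame
# 2.77 kB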
---
imports:
- "$import glob"
- "$import os"
# Channel count passed to network_def as in_channels.
input_channels: 1
# Class count passed to network_def as num_classes.
output_classes: 3
# Checkpoint holding the entries read elsewhere in this file
# ('arch_code_a', 'arch_code_c', 'node_a').
arch_ckpt_path: "$@bundle_root + '/models/search_code_18590.pt'"
# Map the checkpoint onto @device instead of a hard-coded CUDA device so the
# CPU fallback declared by `device` below also applies to this load.
# NOTE(review): torch.load unpickles the file; only point this at a trusted
# checkpoint.
arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=@device)"
# Root directory that the model, config and output paths are built from.
bundle_root: "."
# Directory handed to SaveImaged in postprocessing.
output_dir: "$@bundle_root + '/eval'"
# Base directory used to resolve the entries of the data list.
dataset_dir: "/workspace/data/msd/Task07_Pancreas"
data_list_file_path: "$@bundle_root + '/configs/dataset_0.json'"
# Reads the 'testing' section of the data list; the continuation line is
# indented so it is unambiguously part of the quoted scalar.
datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='testing',
  base_dir=@dataset_dir)"
# First CUDA device when available, otherwise CPU.
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
# Topology instance built from the two architecture codes stored in @arch_ckpt.
dints_space:
  _target_: monai.networks.nets.TopologyInstance
  channel_mul: 1
  num_blocks: 12
  num_depths: 4
  use_downsample: true
  arch_code:
  - "$@arch_ckpt['arch_code_a']"
  - "$@arch_ckpt['arch_code_c']"
  # Reuse the top-level `device` (CUDA when available, else CPU) instead of a
  # hard-coded torch.device('cuda'), so the topology is created on the same
  # device the rest of the pipeline uses.
  device: "@device"
# DiNTS network definition; its topology comes from @dints_space and its
# channel/class counts from the top-level parameters.
network_def:
  _target_: monai.networks.nets.DiNTS
  in_channels: '@input_channels'
  num_classes: '@output_classes'
  dints_space: '@dints_space'
  use_downsample: true
  # 'node_a' is converted from a NumPy array to a tensor before being passed in.
  node_a: "$torch.from_numpy(@arch_ckpt['node_a'])"
# The instantiated network moved to @device.
network: '$@network_def.to(@device)'
# Transform chain applied to the `image` key of every sample, in list order.
preprocessing:
  _target_: Compose
  transforms:
  # Read the image referenced by the `image` key.
  - _target_: LoadImaged
    keys: image
  - _target_: EnsureChannelFirstd
    keys: image
  # Reorient to RAS axis codes.
  - _target_: Orientationd
    keys: image
    axcodes: RAS
  # Resample to a pixdim of 1 along each of the three axes, bilinear mode.
  - _target_: Spacingd
    keys: image
    pixdim: [1, 1, 1]
    mode: bilinear
  # Map intensities from [-87, 199] to [0, 1], clipping values outside.
  - _target_: ScaleIntensityRanged
    keys: image
    a_min: -87
    a_max: 199
    b_min: 0
    b_max: 1
    clip: true
  - _target_: EnsureTyped
    keys: image
# Dataset over @datalist with @preprocessing applied to each item.
dataset:
  _target_: Dataset
  transform: '@preprocessing'
  data: '@datalist'
# Loader over @dataset: one sample per batch, no shuffling, 4 workers.
dataloader:
  _target_: DataLoader
  dataset: '@dataset'
  num_workers: 4
  batch_size: 1
  shuffle: false
# Sliding-window inference with a 96x96x96 window, 4 windows per forward
# batch, and an overlap setting of 0.625.
inferer:
  _target_: SlidingWindowInferer
  roi_size: [96, 96, 96]
  sw_batch_size: 4
  overlap: 0.625
# Transform chain applied to the `pred` key after inference, in list order.
postprocessing:
  _target_: Compose
  transforms:
  # Softmax over the prediction.
  - _target_: Activationsd
    keys: pred
    softmax: true
  # Undo @preprocessing on `pred`, using `image` as the original key;
  # nearest-neighbour interpolation is not forced (nearest_interp: false).
  - _target_: Invertd
    keys: pred
    transform: '@preprocessing'
    orig_keys: image
    meta_key_postfix: meta_dict
    nearest_interp: false
    to_tensor: true
  # Argmax, applied after the inversion above.
  - _target_: AsDiscreted
    keys: pred
    argmax: true
  # Write the result under @output_dir using the metadata in `pred_meta_dict`.
  - _target_: SaveImaged
    keys: pred
    meta_keys: pred_meta_dict
    output_dir: '@output_dir'
# Handlers attached to the evaluator through `val_handlers`.
handlers:
# Load the trained weights from models/model.pt into @network.
- _target_: CheckpointLoader
  load_path: "$@bundle_root + '/models/model.pt'"
  load_dict:
    model: "@network"
  # Map the stored tensors onto @device so a checkpoint saved on a GPU can
  # still be loaded when `device` falls back to CPU.
  map_location: "@device"
# Stats logging with per-iteration output switched off.
- _target_: StatsHandler
  iteration_log: false
# Evaluator wiring together the network, data loader, inferer,
# postprocessing chain and handlers defined in this file.
evaluator:
  _target_: SupervisedEvaluator
  network: '@network'
  device: '@device'
  val_data_loader: '@dataloader'
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  val_handlers: '@handlers'
  # Automatic mixed precision is requested.
  amp: true
# Expressions listed under `initialize`: turns on torch.backends.cudnn.benchmark.
initialize: ["$setattr(torch.backends.cudnn, 'benchmark', True)"]
# Expressions listed under `run`: starts the evaluator.
run: ['$@evaluator.run()']