Initial version
Browse files
- README.md +72 -0
- configs/inference.json +121 -0
- configs/metadata.json +83 -0
- configs/train.json +270 -0
- docs/AMRGAtlas_0031.nii.gz +3 -0
- docs/AMRGAtlas_0031_key-pred.npy +3 -0
- docs/README.md +65 -0
- docs/license.txt +21 -0
- docs/prediction_example.png +0 -0
- docs/view_results.ipynb +0 -0
- models/model.pt +3 -0
- scripts/__init__.py +0 -0
- scripts/valve_landmarks.py +327 -0
README.md
ADDED
@@ -0,0 +1,72 @@
---
tags:
- monai
- medical
library_name: monai
license: unknown
---

# 2D Cardiac Valve Landmark Regressor

This network identifies 10 landmarks in 2D+t MR images of the heart (2 chamber, 3 chamber, and 4 chamber views) representing the insertion locations of valve leaflets into the myocardial wall. These coordinates are used as part of the construction of 3D FEM cardiac models suitable for physics simulation of heart function.

Input images are individual 2D slices from the time series, and the output from the network is a `(2, 10)` set of 2D points in `HW` image coordinate space. The 10 coordinates correspond to the attachment points of these valves:

1. Mitral anterior in 2CH
2. Mitral posterior in 2CH
3. Mitral septal in 3CH
4. Mitral free wall in 3CH
5. Mitral septal in 4CH
6. Mitral free wall in 4CH
7. Aortic septal
8. Aortic free wall
9. Tricuspid septal
10. Tricuspid free wall

Landmarks which do not appear in a particular image are predicted to be `(0, 0)` or close to this location. The mitral valve is expected to appear in all three views. Landmarks are not provided for the pulmonary valve.

Example plot of landmarks on a single frame; see [view_results.ipynb](./view_results.ipynb) for visualising network output:

![Landmark Example Image](./prediction_example.png)

## Training

The training script `train.json` is provided to train the network using a dataset of image pairs containing the MR image and a landmark image. This is done to reuse image-based transforms, which do not currently operate on geometry. A number of other transforms are provided in `valve_landmarks.py` to implement Fourier-space dropout, image shifting which preserves landmarks, and smooth-field deformation applied to images and landmarks.
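
These custom transforms can also be exercised on their own; a minimal sketch, assuming a single-channel `(1, 256, 256)` tensor and that the bundle's `scripts` directory is on `PYTHONPATH`:

```python
import torch

from scripts.valve_landmarks import FourierDropout

# Corrupt an image by zeroing k-space samples, with dropout more likely
# farther from the centre of Fourier space (low frequencies are kept).
im = torch.rand(1, 256, 256)
dropped = FourierDropout(min_dist=0.1, max_dist=0.9)(im)
print(dropped.shape)  # torch.Size([1, 256, 256])
```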

The dataset used for training unfortunately cannot be made public; however, the training script can be used with any NPZ file containing the training image stack under the key `trainImgs` and the landmark image stack under `trainLMImgs`, plus `testImgs` and `testLMImgs` containing validation data. The landmark images are defined as 0 for every non-landmark pixel, with landmark pixels containing the following values for each landmark type (a sketch of assembling such a file follows the list):

* 10: Mitral anterior in 2CH
* 15: Mitral posterior in 2CH
* 20: Mitral septal in 3CH
* 25: Mitral free wall in 3CH
* 30: Mitral septal in 4CH
* 35: Mitral free wall in 4CH
* 100: Aortic septal
* 150: Aortic free wall
* 200: Tricuspid septal
* 250: Tricuspid free wall
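
A minimal sketch of assembling a compatible NPZ file, assuming the image stacks and integer landmark masks are already prepared (the shapes and counts here are placeholders):

```python
import numpy as np

# Placeholder stacks: float32 images and int32 landmark masks of spatial size
# 256x256, with mask pixels using the values listed above.
train_imgs = np.zeros((100, 256, 256), dtype=np.float32)
train_lm_imgs = np.zeros((100, 256, 256), dtype=np.int32)
test_imgs = np.zeros((20, 256, 256), dtype=np.float32)
test_lm_imgs = np.zeros((20, 256, 256), dtype=np.int32)

# The key names must match those expected by the NPZDictItemDataset entries in train.json.
np.savez(
    "valvelandmarks.npz",
    trainImgs=train_imgs,
    trainLMImgs=train_lm_imgs,
    testImgs=test_imgs,
    testLMImgs=test_lm_imgs,
)
```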

The following command will train with the default NPZ filename `./valvelandmarks.npz`:

```sh
PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json \
    --config_file configs/train.json --bundle_root . --dataset_file /path/to/data --output_dir /path/to/outputs
```
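
Assuming the standard `monai.bundle run` behaviour of accepting `--<name> <value>` overrides for config entries, values defined in `train.json` such as `num_epochs` or `batch_size` can be changed without editing the file:

```sh
PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json \
    --config_file configs/train.json --bundle_root . --dataset_file /path/to/data \
    --output_dir /path/to/outputs --num_epochs 20
```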

## Inference

The included `inference.json` script will run inference on a directory containing Nifti files whose images have shape `(256, 256, 1, N)` for `N` timesteps. For each image, the output in the `output_dir` directory will be an npy file containing a result array of shape `(N, 2, 10)`, storing the 10 coordinates for each of the `N` timesteps. Invoking this script can be done as follows, assuming the current directory is the bundle directory:

```sh
PYTHONPATH=./scripts python -m monai.bundle run evaluating --meta_file configs/metadata.json \
    --config_file configs/inference.json --bundle_root . --dataset_dir /path/to/data --output_dir /path/to/outputs
```
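
For a quick trial with the bundled example image, with the directories here being arbitrary placeholder choices:

```sh
mkdir -p /tmp/valve_in /tmp/valve_out
cp docs/AMRGAtlas_0031.nii.gz /tmp/valve_in/
PYTHONPATH=./scripts python -m monai.bundle run evaluating --meta_file configs/metadata.json \
    --config_file configs/inference.json --bundle_root . --dataset_dir /tmp/valve_in --output_dir /tmp/valve_out
```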

It is important to set the `PYTHONPATH` variable since code in the provided `scripts` directory is necessary for inference. The provided test Nifti file can be placed in a directory which is then used as the `dataset_dir` value. This image was derived from [the AMRG Cardiac Atlas dataset](http://www.cardiacatlas.org/studies/amrg-cardiac-atlas) (AMRG Cardiac Atlas, Auckland MRI Research Group, Auckland, New Zealand). The results from this inference can be visualised by changing path values in [view_results.ipynb](./view_results.ipynb).
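
As an alternative to the notebook, a minimal plotting sketch using the example prediction shipped in `docs` (assuming `nibabel` and `matplotlib` are installed):

```python
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

image = nib.load("docs/AMRGAtlas_0031.nii.gz").get_fdata()  # (256, 256, 1, N)
points = np.load("docs/AMRGAtlas_0031_key-pred.npy")  # (N, 2, 10)

frame = 0
plt.imshow(image[..., 0, frame], cmap="gray")
# coordinates are (y, x) in HW space: row 0 is vertical, row 1 horizontal
plt.scatter(points[frame, 1], points[frame, 0], c="r", s=12)
plt.show()
```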


### Reference

The work for this model and its application is described in:

`Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. in E Puyol Anton, M Pop, M Sermesant, V Campello, A Lalande, K Lekadir, A Suinesiaputra, O Camara & A Young (eds), Statistical Atlases and Computational Models of the Heart. MandMs and EMIDEC Challenges - 11th International Workshop, STACOM 2020, Held in Conjunction with MICCAI 2020, Revised Selected Papers. Lecture Notes in Computer Science, vol. 12592 LNCS, Springer Science and Business Media Deutschland GmbH, pp. 146-155, Lima, Peru, 4/10/2020. https://doi.org/10.1007/978-3-030-68107-4_15`
configs/inference.json
ADDED
@@ -0,0 +1,121 @@
{
    "imports": [
        "$import os",
        "$import glob",
        "$import scripts"
    ],
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "ckpt_path": "$@bundle_root + '/models/model.pt'",
    "dataset_dir": "/workspace/data",
    "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii*')))",
    "output_dir": "./output",
    "network_def": {
        "_target_": "scripts.valve_landmarks.PointRegressor",
        "in_shape": [1, 256, 256],
        "out_shape": [2, 10],
        "channels": [8, 16, 32, 64, 128],
        "strides": [2, 2, 2, 2, 2]
    },
    "network": "$@network_def.to(@device)",
    "preprocessing": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "LoadImage",
                "image_only": true
            },
            {
                "_target_": "EnsureType",
                "device": "@device"
            },
            {
                "_target_": "Transpose",
                "indices": [2, 0, 1, 3]
            },
            {
                "_target_": "ScaleIntensity"
            }
        ]
    },
    "dataset": {
        "_target_": "Dataset",
        "data": "@datalist",
        "transform": "@preprocessing"
    },
    "dataloader": {
        "_target_": "DataLoader",
        "dataset": "@dataset",
        "batch_size": 1,
        "shuffle": false,
        "num_workers": 0
    },
    "inferer": {
        "_target_": "scripts.valve_landmarks.LandmarkInferer",
        "spatial_dim": 2,
        "stack_dim": 1
    },
    "postprocessing": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "SqueezeDimd",
                "keys": "pred",
                "dim": 0
            },
            {
                "_target_": "scripts.valve_landmarks.NpySaverd",
                "keys": "pred",
                "output_dir": "@output_dir",
                "data_root_dir": "@dataset_dir"
            }
        ]
    },
    "handlers": [
        {
            "_target_": "CheckpointLoader",
            "_disabled_": "$not os.path.exists(@ckpt_path)",
            "load_path": "@ckpt_path",
            "load_dict": {
                "net": "@network"
            }
        }
    ],
    "evaluator": {
        "_target_": "SupervisedEvaluator",
        "device": "@device",
        "val_data_loader": "@dataloader",
        "network": "@network",
        "inferer": "@inferer",
        "postprocessing": "@postprocessing",
        "val_handlers": "@handlers",
        "decollate": false,
        "prepare_batch": "$lambda batch, dev, nb: (batch.to(dev), ())"
    },
    "evaluating": [
        "$@evaluator.run()"
    ]
}
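
A shapes-only sketch of the data flow this config sets up, using a hypothetical `N` of 30 timesteps:

```python
import numpy as np

x = np.zeros((256, 256, 1, 30))  # loaded Nifti image: H, W, 1, N
x = x.transpose(2, 0, 1, 3)      # Transpose indices [2, 0, 1, 3] -> (1, 256, 256, 30)
batch = x[None]                  # DataLoader adds the batch dim -> (1, 1, 256, 256, 30)
# LandmarkInferer(spatial_dim=2, stack_dim=1) runs the network on each of the
# 30 frames and stacks the (1, 2, 10) outputs at dim 1 -> (1, 30, 2, 10);
# SqueezeDimd(dim=0) then gives the (30, 2, 10) array saved by NpySaverd.
```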
configs/metadata.json
ADDED
@@ -0,0 +1,83 @@
{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220729.json",
    "version": "0.1.0",
    "changelog": {
        "0.1.0": "Initial version"
    },
    "monai_version": "1.0.0rc1",
    "pytorch_version": "1.10.2",
    "numpy_version": "1.21.2",
    "optional_packages_version": {},
    "task": "Given long axis MR images of the heart, identify valve insertion points through the full cardiac cycle",
    "description": "This network is used to find where valves attach to the heart in order to help construct 3D FEM models for computation. The output is an array of 10 2D coordinates.",
    "authors": "Eric Kerfoot",
    "copyright": "Copyright (c) Eric Kerfoot",
    "references": [
        "Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. https://doi.org/10.1007/978-3-030-68107-4_15"
    ],
    "intended_use": "This is suitable for research purposes only",
    "image_classes": "Single channel data, intensity scaled to [0, 1]",
    "data_source": "Non-public dataset comprised of hand-annotated full cycle long axis MR images",
    "coordinate_values": {
        "0": 10,
        "1": 15,
        "2": 20,
        "3": 25,
        "4": 30,
        "5": 35,
        "6": 100,
        "7": 150,
        "8": 200,
        "9": 250
    },
    "coordinate_meanings": {
        "0": "mitral anterior 2CH",
        "1": "mitral posterior 2CH",
        "2": "mitral septal 3CH",
        "3": "mitral free wall 3CH",
        "4": "mitral septal 4CH",
        "5": "mitral free wall 4CH",
        "6": "aortic septal",
        "7": "aortic free wall",
        "8": "tricuspid septal",
        "9": "tricuspid free wall"
    },
    "network_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "magnitude",
                "modality": "MR",
                "num_channels": 1,
                "spatial_shape": [256, 256],
                "dtype": "float32",
                "value_range": [],
                "is_patch_data": false,
                "channel_def": {
                    "0": "image"
                }
            }
        },
        "outputs": {
            "pred": {
                "type": "tuples",
                "format": "points",
                "num_channels": 2,
                "spatial_shape": [2, 10],
                "dtype": "float32",
                "value_range": [],
                "is_patch_data": false,
                "channel_def": {
                    "0": "Y Dimension",
                    "1": "X Dimension"
                }
            }
        }
    }
}
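
Assuming the standard `monai.bundle` tooling is available, this file can be checked against the referenced schema with something like:

```sh
python -m monai.bundle verify_metadata --meta_file configs/metadata.json --filepath /tmp/meta_schema.json
```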
configs/train.json
ADDED
@@ -0,0 +1,270 @@
{
    "imports": [
        "$import datetime",
        "$import numpy as np",
        "$import torch",
        "$import ignite",
        "$import scripts"
    ],
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "ckpt_path": "$@bundle_root + '/models/model.pt'",
    "dataset_file": "./valvelandmarks.npz",
    "output_dir": "$datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')",
    "network_def": {
        "_target_": "scripts.valve_landmarks.PointRegressor",
        "in_shape": [1, 256, 256],
        "out_shape": [2, 10],
        "channels": [8, 16, 32, 64, 128],
        "strides": [2, 2, 2, 2, 2]
    },
    "network": "$@network_def.to(@device)",
    "im_shape": [1, 256, 256],
    "both_keys": ["image", "label"],
    "rand_prob": 0.5,
    "train_transforms": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "EnsureTyped",
                "keys": "@both_keys",
                "data_type": "numpy",
                "dtype": "$(np.float32, np.int32)"
            },
            {
                "_target_": "EnsureTyped",
                "keys": "@both_keys"
            },
            {
                "_target_": "ScaleIntensityd",
                "keys": "image"
            },
            {
                "_target_": "EnsureChannelFirstd",
                "keys": "@both_keys"
            },
            {
                "_target_": "RandAxisFlipd",
                "keys": "@both_keys",
                "prob": "@rand_prob"
            },
            {
                "_target_": "RandRotate90d",
                "keys": "@both_keys",
                "prob": "@rand_prob"
            },
            {
                "_target_": "RandSmoothFieldAdjustIntensityd",
                "keys": "image",
                "prob": "@rand_prob",
                "spatial_size": "@im_shape",
                "rand_size": [5, 5],
                "gamma": [0.1, 1],
                "mode": "$monai.utils.InterpolateMode.BICUBIC",
                "align_corners": true
            },
            {
                "_target_": "RandGaussianNoised",
                "keys": "image",
                "prob": "@rand_prob",
                "std": 0.05
            },
            {
                "_target_": "scripts.valve_landmarks.RandFourierDropoutd",
                "keys": "image",
                "prob": "@rand_prob"
            },
            {
                "_target_": "scripts.valve_landmarks.RandImageLMDeformd",
                "prob": "@rand_prob",
                "spatial_size": [256, 256],
                "rand_size": [7, 7],
                "pad": 2,
                "field_mode": "$monai.utils.InterpolateMode.BICUBIC",
                "align_corners": true,
                "def_range": 0.05
            },
            {
                "_target_": "scripts.valve_landmarks.RandLMShiftd",
                "keys": "@both_keys",
                "prob": "@rand_prob",
                "spatial_size": [256, 256],
                "max_shift": 8
            },
            {
                "_target_": "Lambdad",
                "keys": "label",
                "func": "$scripts.valve_landmarks.convert_lm_image_t"
            }
        ]
    },
    "eval_transforms": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "EnsureTyped",
                "keys": "@both_keys",
                "data_type": "numpy",
                "dtype": "$(np.float32, np.int32)"
            },
            {
                "_target_": "EnsureTyped",
                "keys": "@both_keys"
            },
            {
                "_target_": "ScaleIntensityd",
                "keys": "image"
            },
            {
                "_target_": "EnsureChannelFirstd",
                "keys": "@both_keys"
            },
            {
                "_target_": "Lambdad",
                "keys": "label",
                "func": "$scripts.valve_landmarks.convert_lm_image_t"
            }
        ]
    },
    "train_dataset": {
        "_target_": "NPZDictItemDataset",
        "npzfile": "$@dataset_file",
        "keys": {
            "trainImgs": "image",
            "trainLMImgs": "label"
        },
        "transform": "@train_transforms"
    },
    "eval_dataset": {
        "_target_": "NPZDictItemDataset",
        "npzfile": "$@dataset_file",
        "keys": {
            "testImgs": "image",
            "testLMImgs": "label"
        },
        "transform": "@eval_transforms"
    },
    "num_iters": 400,
    "batch_size": 600,
    "num_epochs": 100,
    "num_substeps": 3,
    "sampler": {
        "_target_": "torch.utils.data.WeightedRandomSampler",
        "weights": "$torch.ones(len(@train_dataset))",
        "replacement": true,
        "num_samples": "$@num_iters*@batch_size"
    },
    "train_dataloader": {
        "_target_": "ThreadDataLoader",
        "dataset": "@train_dataset",
        "batch_size": "@batch_size",
        "repeats": "@num_substeps",
        "num_workers": 8,
        "sampler": "@sampler"
    },
    "eval_dataloader": {
        "_target_": "DataLoader",
        "dataset": "@eval_dataset",
        "batch_size": "@batch_size",
        "num_workers": 8
    },
    "lossfn": {
        "_target_": "torch.nn.L1Loss"
    },
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@network.parameters()",
        "lr": 0.0001
    },
    "evaluator": {
        "_target_": "SupervisedEvaluator",
        "device": "@device",
        "val_data_loader": "@eval_dataloader",
        "network": "@network",
        "key_val_metric": {
            "val_mean_dist": {
                "_target_": "ignite.metrics.MeanPairwiseDistance",
                "output_transform": "$scripts.valve_landmarks._output_lm_trans"
            }
        },
        "metric_cmp_fn": "$lambda current, prev: prev < 0 or current < prev",
        "val_handlers": [
            {
                "_target_": "StatsHandler",
                "output_transform": "$lambda x: None"
            }
        ]
    },
    "handlers": [
        {
            "_target_": "ValidationHandler",
            "validator": "@evaluator",
            "epoch_level": true,
            "interval": 1
        },
        {
            "_target_": "CheckpointSaver",
            "save_dir": "@output_dir",
            "save_dict": {
                "net": "@network"
            },
            "save_interval": 1,
            "save_final": true,
            "epoch_level": true
        }
    ],
    "trainer": {
        "_target_": "SupervisedTrainer",
        "max_epochs": "@num_epochs",
        "device": "@device",
        "train_data_loader": "@train_dataloader",
        "network": "@network",
        "loss_function": "@lossfn",
        "optimizer": "@optimizer",
        "key_train_metric": null,
        "train_handlers": "@handlers"
    },
    "training": [
        "$@trainer.run()"
    ]
}
docs/AMRGAtlas_0031.nii.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ddac587d7ced3c7f892113e379256f055895162a5077f741c1bdc79e1a70b558
size 2190289
docs/AMRGAtlas_0031_key-pred.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:48c44c9b85bb080e33a97b5421fde7037e039e235dbfbf08eff22a8fd2901b89
size 2688
docs/README.md
ADDED
@@ -0,0 +1,65 @@
# 2D Cardiac Valve Landmark Regressor

This network identifies 10 landmarks in 2D+t MR images of the heart (2 chamber, 3 chamber, and 4 chamber views) representing the insertion locations of valve leaflets into the myocardial wall. These coordinates are used as part of the construction of 3D FEM cardiac models suitable for physics simulation of heart function.

Input images are individual 2D slices from the time series, and the output from the network is a `(2, 10)` set of 2D points in `HW` image coordinate space. The 10 coordinates correspond to the attachment points of these valves:

1. Mitral anterior in 2CH
2. Mitral posterior in 2CH
3. Mitral septal in 3CH
4. Mitral free wall in 3CH
5. Mitral septal in 4CH
6. Mitral free wall in 4CH
7. Aortic septal
8. Aortic free wall
9. Tricuspid septal
10. Tricuspid free wall

Landmarks which do not appear in a particular image are predicted to be `(0, 0)` or close to this location. The mitral valve is expected to appear in all three views. Landmarks are not provided for the pulmonary valve.

Example plot of landmarks on a single frame; see [view_results.ipynb](./view_results.ipynb) for visualising network output:

![Landmark Example Image](./prediction_example.png)

## Training

The training script `train.json` is provided to train the network using a dataset of image pairs containing the MR image and a landmark image. This is done to reuse image-based transforms, which do not currently operate on geometry. A number of other transforms are provided in `valve_landmarks.py` to implement Fourier-space dropout, image shifting which preserves landmarks, and smooth-field deformation applied to images and landmarks.

The dataset used for training unfortunately cannot be made public; however, the training script can be used with any NPZ file containing the training image stack under the key `trainImgs` and the landmark image stack under `trainLMImgs`, plus `testImgs` and `testLMImgs` containing validation data. The landmark images are defined as 0 for every non-landmark pixel, with landmark pixels containing the following values for each landmark type:

* 10: Mitral anterior in 2CH
* 15: Mitral posterior in 2CH
* 20: Mitral septal in 3CH
* 25: Mitral free wall in 3CH
* 30: Mitral septal in 4CH
* 35: Mitral free wall in 4CH
* 100: Aortic septal
* 150: Aortic free wall
* 200: Tricuspid septal
* 250: Tricuspid free wall

The following command will train with the default NPZ filename `./valvelandmarks.npz`:

```sh
PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json \
    --config_file configs/train.json --bundle_root . --dataset_file /path/to/data --output_dir /path/to/outputs
```

## Inference

The included `inference.json` script will run inference on a directory containing Nifti files whose images have shape `(256, 256, 1, N)` for `N` timesteps. For each image, the output in the `output_dir` directory will be an npy file containing a result array of shape `(N, 2, 10)`, storing the 10 coordinates for each of the `N` timesteps. Invoking this script can be done as follows, assuming the current directory is the bundle directory:

```sh
PYTHONPATH=./scripts python -m monai.bundle run evaluating --meta_file configs/metadata.json \
    --config_file configs/inference.json --bundle_root . --dataset_dir /path/to/data --output_dir /path/to/outputs
```

It is important to set the `PYTHONPATH` variable since code in the provided `scripts` directory is necessary for inference. The provided test Nifti file can be placed in a directory which is then used as the `dataset_dir` value. This image was derived from [the AMRG Cardiac Atlas dataset](http://www.cardiacatlas.org/studies/amrg-cardiac-atlas) (AMRG Cardiac Atlas, Auckland MRI Research Group, Auckland, New Zealand). The results from this inference can be visualised by changing path values in [view_results.ipynb](./view_results.ipynb).


### Reference

The work for this model and its application is described in:

`Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. in E Puyol Anton, M Pop, M Sermesant, V Campello, A Lalande, K Lekadir, A Suinesiaputra, O Camara & A Young (eds), Statistical Atlases and Computational Models of the Heart. MandMs and EMIDEC Challenges - 11th International Workshop, STACOM 2020, Held in Conjunction with MICCAI 2020, Revised Selected Papers. Lecture Notes in Computer Science, vol. 12592 LNCS, Springer Science and Business Media Deutschland GmbH, pp. 146-155, Lima, Peru, 4/10/2020. https://doi.org/10.1007/978-3-030-68107-4_15`
docs/license.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Eric Kerfoot

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
docs/prediction_example.png
ADDED
docs/view_results.ipynb
ADDED
The diff for this file is too large to render.
models/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0187ee65b150c0c693b16e6a65d0b3f659bed163423a1e4a5a6b3bb4fbeb7bf
size 13349733
scripts/__init__.py
ADDED
File without changes
scripts/valve_landmarks.py
ADDED
@@ -0,0 +1,327 @@
# Copyright (c) 2022 Eric Kerfoot under MIT license, see license.txt

import os
from typing import Any, Callable, Sequence

import monai
import monai.transforms as mt
import numpy as np
import torch
import torch.nn as nn
from monai.data.meta_obj import get_track_meta
from monai.networks.blocks import ConvDenseBlock, Convolution
from monai.networks.layers import Flatten, Reshape
from monai.networks.nets import Regressor
from monai.networks.utils import meshgrid_ij
from monai.utils import CommonKeys
from monai.utils import ImageMetaKey as Key
from monai.utils import convert_to_numpy, convert_to_tensor

# relates the label value in training images to the index of the landmark point
LM_INDICES = {
    10: 0,  # mitral anterior 2CH
    15: 1,  # mitral posterior 2CH
    20: 2,  # mitral septal 3CH
    25: 3,  # mitral free wall 3CH
    30: 4,  # mitral septal 4CH
    35: 5,  # mitral free wall 4CH
    100: 6,  # aortic septal
    150: 7,  # aortic free wall
    200: 8,  # tricuspid septal
    250: 9,  # tricuspid free wall
}

output_trans = monai.handlers.from_engine(["pred", "label"])


def _output_lm_trans(data):
    pred, label = output_trans(data)
    return [p.permute(1, 0) for p in pred], [l.permute(1, 0) for l in label]


def convert_lm_image_t(lm_image):
    """Convert a landmark image into a (2, N) tensor of landmark coordinates."""
    lmarray = torch.zeros((2, len(LM_INDICES)), dtype=torch.float32).to(lm_image.device)

    for _, y, x in np.argwhere(lm_image.cpu().numpy() != 0):
        im_id = int(lm_image[0, y, x])
        lm_index = LM_INDICES[im_id]

        lmarray[0, lm_index] = y
        lmarray[1, lm_index] = x

    return lmarray
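
# Hypothetical usage sketch (not executed anywhere in the bundle): for a
# (1, H, W) landmark image `lm` whose non-zero pixels take values from
# LM_INDICES, `coords = convert_lm_image_t(lm)` gives a (2, 10) tensor in
# which column LM_INDICES[v] holds the (y, x) position of the pixel valued v,
# and columns for absent landmarks remain (0, 0).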


class ParallelCat(nn.Module):
    """
    Apply the same input to each of the given modules and concatenate their results together.

    Args:
        catmodules: sequence of nn.Module objects to apply inputs to
        cat_dim: dimension to concatenate along when combining outputs
    """

    def __init__(self, catmodules: Sequence[nn.Module], cat_dim: int = 1):
        super().__init__()
        self.cat_dim = cat_dim

        for i, s in enumerate(catmodules):
            self.add_module(f"catmodule_{i}", s)

    def forward(self, x):
        tensors = [s(x) for s in self.children()]
        return torch.cat(tensors, self.cat_dim)


class PointRegressor(Regressor):
    """Regressor defined as a sequence of dense blocks followed by convolution/linear layers for each landmark."""

    def _get_layer(self, in_channels, out_channels, strides, is_last):
        dout = out_channels - in_channels
        dilations = [1, 2, 4]
        dchannels = [dout // 3, dout // 3, dout // 3 + dout % 3]

        db = ConvDenseBlock(
            spatial_dims=self.dimensions,
            in_channels=in_channels,
            channels=dchannels,
            dilations=dilations,
            kernel_size=self.kernel_size,
            num_res_units=self.num_res_units,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        conv = Convolution(
            spatial_dims=self.dimensions,
            in_channels=out_channels,
            out_channels=out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
            conv_only=is_last,
        )

        return nn.Sequential(db, conv)

    def _get_final_layer(self, in_shape):
        point_paths = []

        for _ in range(self.out_shape[1]):
            conv = Convolution(
                spatial_dims=self.dimensions,
                in_channels=in_shape[0],
                out_channels=in_shape[0] * 2,
                strides=2,
                kernel_size=self.kernel_size,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                conv_only=True,
            )
            linear = nn.Linear(int(np.prod(in_shape)) // 2, self.out_shape[0])
            point_paths.append(nn.Sequential(conv, Flatten(), linear))

        return torch.nn.Sequential(ParallelCat(point_paths), Reshape(*self.out_shape))
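
# Hedged instantiation sketch mirroring the network_def entries in the bundle
# configs (left as a comment so importing this module stays side-effect free):
#
#   net = PointRegressor(in_shape=(1, 256, 256), out_shape=(2, 10),
#                        channels=(8, 16, 32, 64, 128), strides=(2, 2, 2, 2, 2))
#   out = net(torch.rand(4, 1, 256, 256))  # -> shape (4, 2, 10)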


class LandmarkInferer(monai.inferers.Inferer):
    """Applies inference on 2D slices taken from 3D volumes."""

    def __init__(self, spatial_dim=0, stack_dim=-1):
        self.spatial_dim = spatial_dim
        self.stack_dim = stack_dim

    def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any):
        if inputs.ndim != 5:
            raise ValueError(f"Input volume to inferer must have shape BCDHW, input shape is {inputs.shape}")

        results = []
        input_slices = [slice(None) for _ in range(inputs.ndim)]

        for idx in range(inputs.shape[self.spatial_dim + 2]):
            input_slices[self.spatial_dim + 2] = idx
            # BCDHW -> BCHW by fixing the index in one of DHW; indexing with a
            # tuple avoids advanced-indexing interpretation of the slice list
            input_2d = inputs[tuple(input_slices)]

            result = network(input_2d, *args, **kwargs)
            results.append(result)

        return torch.stack(results, self.stack_dim)


class NpySaverd(mt.MapTransform):
    """Saves tensors/arrays to Numpy npy files."""

    def __init__(self, keys, output_dir, data_root_dir):
        super().__init__(keys)
        self.output_dir = output_dir
        self.data_root_dir = data_root_dir
        self.folder_layout = monai.data.FolderLayout(
            self.output_dir, extension=".npy", data_root_dir=self.data_root_dir
        )

    def __call__(self, d):
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)

        for key in self.key_iterator(d):
            orig_filename = d[key].meta[Key.FILENAME_OR_OBJ]
            if isinstance(orig_filename, (list, tuple)):
                orig_filename = orig_filename[0]

            out_filename = self.folder_layout.filename(orig_filename, key=key)

            np.save(out_filename, convert_to_numpy(d[key]))

        return d


class FourierDropout(mt.Transform, mt.Fourier):
    """
    Apply dropout in Fourier space to corrupt images. This works by zeroing out pixels with greater probability the
    farther from the centre they are. All pixels closer than `min_dist` to the centre are preserved, and all beyond
    `max_dist` become 0. Distances from the centre to an edge in a given dimension are defined as 1.0.

    Args:
        min_dist: minimum distance to apply dropout, must be >0, smaller values will cause greater corruption
        max_dist: maximal distance to apply dropout, must be greater than `min_dist`, all pixels beyond become 0
    """

    def __init__(self, min_dist: float = 0.1, max_dist: float = 0.9):
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.prob_field = None
        self.field_shape = None

    def _get_prob_field(self, shape):
        shape = tuple(shape)
        if shape != self.field_shape:
            self.field_shape = shape
            spaces = [torch.linspace(-1, 1, s) for s in shape[1:]]
            grids = meshgrid_ij(*spaces)
            # middle is 0, mid edges 1, corners sqrt(2)
            self.prob_field = torch.stack(grids).pow_(2).sum(axis=0).sqrt_()

        return self.prob_field

    def __call__(self, im):
        probfield = self._get_prob_field(im.shape).to(im.device)

        # random values in the range min_dist to max_dist
        dropout = torch.rand_like(im).mul_(self.max_dist - self.min_dist).add_(self.min_dist)
        # keep a pixel if its dropout value is greater than its distance from the
        # centre, so pixels are less likely to be kept the farther out they are
        dropout = dropout.ge_(probfield)

        result = self.shift_fourier(im, im.ndim - 1)
        result.mul_(dropout)
        result = self.inv_shift_fourier(result, im.ndim - 1)

        return convert_to_tensor(result, track_meta=get_track_meta())


class RandFourierDropout(mt.RandomizableTransform):
    def __init__(self, min_dist=0.1, max_dist=0.9, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob)
        self.dropper = FourierDropout(min_dist, max_dist)

    def __call__(self, im, randomize: bool = True):
        if randomize:
            self.randomize(None)

        if self._do_transform:
            im = self.dropper(im)
        else:
            im = convert_to_tensor(im, track_meta=get_track_meta())

        return im


class RandFourierDropoutd(mt.RandomizableTransform, mt.MapTransform):
    def __init__(self, keys, min_dist=0.1, max_dist=0.9, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob)
        mt.MapTransform.__init__(self, keys)
        self.dropper = FourierDropout(min_dist, max_dist)

    def __call__(self, data, randomize: bool = True):
        d = dict(data)

        if randomize:
            self.randomize(None)

        for key in self.key_iterator(d):
            if self._do_transform:
                d[key] = self.dropper(d[key])
            else:
                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())

        return d


class RandImageLMDeformd(mt.RandSmoothDeform):
    """Apply smooth random deformation to the image and landmark locations."""

    def __call__(self, d):
        d = dict(d)
        old_label = d[CommonKeys.LABEL]
        new_label = torch.zeros_like(old_label)

        d[CommonKeys.IMAGE] = super().__call__(d[CommonKeys.IMAGE])

        if self._do_transform:
            field = self.sfield()
            labels = np.argwhere(d[CommonKeys.LABEL][0].cpu().numpy() > 0)

            # moving the landmarks this way prevents losing some of them to
            # interpolation errors, which could happen if the deformation were
            # applied to the landmark image itself
            for y, x in labels:
                dy = int(field[0, y, x] * new_label.shape[1] / 2)
                dx = int(field[1, y, x] * new_label.shape[2] / 2)

                new_label[:, y - dy, x - dx] = old_label[:, y, x]

            d[CommonKeys.LABEL] = new_label

        return d


class RandLMShiftd(mt.RandomizableTransform, mt.MapTransform):
    """Randomly shift the image and landmark image in either direction by integer amounts."""

    def __init__(self, keys, spatial_size, max_shift=0, prob=0.1):
        mt.RandomizableTransform.__init__(self, prob=prob)
        mt.MapTransform.__init__(self, keys=keys)

        self.spatial_size = tuple(spatial_size)
        self.max_shift = max_shift
        self.padder = mt.BorderPad(self.max_shift)
        self.unpadder = mt.CenterSpatialCrop(self.spatial_size)
        self.shift = (0,) * len(self.spatial_size)
        self.roll_dims = list(range(1, len(self.spatial_size) + 1))

    def randomize(self, data):
        super().randomize(None)
        if self._do_transform:
            rs = torch.randint(-self.max_shift, self.max_shift, (len(self.spatial_size),), dtype=torch.int32)
            self.shift = tuple(rs.tolist())

    def __call__(self, d, randomize: bool = True):
        d = dict(d)

        if randomize:
            self.randomize(None)

        if self._do_transform:
            for key in self.key_iterator(d):
                imp = self.padder(d[key])
                # rolling the padded image prevents interpolation of the landmark image
                ims = torch.roll(imp, self.shift, self.roll_dims)
                d[key] = self.unpadder(ims)

        return d