monai
medical
katielink commited on
Commit
2a6977a
1 Parent(s): 53f8b32

Initial version

Browse files
README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - monai
4
+ - medical
5
+ library_name: monai
6
+ license: unknown
7
+ ---
8
+
9
+ # 2D Cardiac Valve Landmark Regressor
10
+
11
+ This network identifies 10 different landmarks in 2D+t MR images of the heart (2 chamber, 3 chamber, and 4 chamber) representing the insertion locations of valve leaflets into the myocardial wall. These coordinates are used in part of the construction of 3D FEM cardiac models suitable for physics simulation of heart functions.
12
+
13
+ Input images are individual 2D slices from the time series, and the output from the network is a `(2, 10)` set of 2D points in `HW` image coordinate space. The 10 coordinates correspond to the attachment point for these valves:
14
+
15
+ 1. Mitral anterior in 2CH
16
+ 2. Mitral posterior in 2CH
17
+ 3. Mitral septal in 3CH
18
+ 4. Mitral free wall in 3CH
19
+ 5. Mitral septal in 4CH
20
+ 6. Mitral free wall in 4CH
21
+ 7. Aortic septal
22
+ 8. Aortic free wall
23
+ 9. Tricuspid septal
24
+ 10. Tricuspid free wall
25
+
26
+ Landmarks which do not appear in a particular image are predicted to be `(0, 0)` or close to this location. The mitral valve is expected to appear in all three views. Landmarks are not provided for the pulmonary valve.
27
+
28
+ Example plot of landmarks on a single frame, see [view_results.ipynb](./view_results.ipynb) for visualising network output:
29
+
30
+ ![Landmark Example Image](./prediction_example.png)
31
+
32
+ ## Training
33
+
34
+ The training script `train.json` is provided to train the network using a dataset of image pairs containing the MR image and a landmark image. This is done to reuse image-based transforms which do not currently operate on geometry. A number of other transforms are provided in `valve_landmarks.py` to implement Fourier-space dropout, image shifting which preserve landmarks, and smooth-field deformation applied to images and landmarks.
35
+
36
+ The dataset used for training unfortunately cannot be made public, however the training script can be used with any NPZ file containing the training image stack in key `trainImgs` and landmark image stack in `trainLMImgs`, plus `testImgs` and `testLMImgs` containing validation data. The landmark images are defined as 0 for every non-landmark pixel, with landmark pixels contaning the following values for each landmark type:
37
+
38
+ * 10: Mitral anterior in 2CH
39
+ * 15: Mitral posterior in 2CH
40
+ * 20: Mitral septal in 3CH
41
+ * 25: Mitral free wall in 3CH
42
+ * 30: Mitral septal in 4CH
43
+ * 35: Mitral free wall in 4CH
44
+ * 100: Aortic septal
45
+ * 150: Aortic free wall
46
+ * 200: Tricuspid septal
47
+ * 250: Tricuspid free wall
48
+
49
+ The following command will train with the default NPZ filename `./valvelandmarks.npz`:
50
+
51
+ ```sh
52
+ PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json \
53
+ --config_file configs/train.json --bundle_root . --dataset_file /path/to/data --output_dir /path/to/outputs
54
+ ```
55
+
56
+ ## Inference
57
+
58
+ The included `inference.json` script will run inference on a directory containing Nifti files whose images have shape `(256, 256, 1, N)` for `N` timesteps. For each image the output in the `output_dir` directory will be a npy file containing a result array of shape `(N, 2, 10)` storing the 10 coordinates for each `N` timesteps. Invoking this script can be done as follows, assuming the current directory is the bundle directory:
59
+
60
+ ```sh
61
+ PYTHONPATH=./scripts python -m monai.bundle run evaluating --meta_file configs/metadata.json \
62
+ --config_file configs/inference.json --bundle_root . --dataset_dir /path/to/data --output_dir /path/to/outputs
63
+ ```
64
+
65
+ It is important to set the `PYTHONPATH` variable since code in the provided scripts directory is necessary for inference. The provided test Nifti file can be placed in a directory which is then used as the `dataset_dir` value. This image was derived from [the AMRG Cardiac Atlas dataset](http://www.cardiacatlas.org/studies/amrg-cardiac-atlas) (AMRG Cardiac Atlas, Auckland MRI Research Group, Auckland, New Zealand). The results from this inference can be visualised by changing path values in [view_results.ipynb](./view_results.ipynb).
66
+
67
+
68
+ ### Reference
69
+
70
+ The work for this model and its application is described in:
71
+
72
+ `Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. in E Puyol Anton, M Pop, M Sermesant, V Campello, A Lalande, K Lekadir, A Suinesiaputra, O Camara & A Young (eds), Statistical Atlases and Computational Models of the Heart. MandMs and EMIDEC Challenges - 11th International Workshop, STACOM 2020, Held in Conjunction with MICCAI 2020, Revised Selected Papers. Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics), vol. 12592 LNCS, Springer Science and Business Media Deutschland GmbH, pp. 146-155, 11th International Workshop on Statistical Atlases and Computational Models of the Heart, STACOM 2020 held in Conjunction with MICCAI 2020, Lima, Peru, 4/10/2020. https://doi.org/10.1007/978-3-030-68107-4_15`
configs/inference.json ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import os",
4
+ "$import glob",
5
+ "$import scripts"
6
+ ],
7
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
8
+ "ckpt_path": "$@bundle_root + '/models/model.pt'",
9
+ "dataset_dir": "/workspace/data",
10
+ "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii*')))",
11
+ "output_dir": "./output",
12
+ "network_def": {
13
+ "_target_": "scripts.valve_landmarks.PointRegressor",
14
+ "in_shape": [
15
+ 1,
16
+ 256,
17
+ 256
18
+ ],
19
+ "out_shape": [
20
+ 2,
21
+ 10
22
+ ],
23
+ "channels": [
24
+ 8,
25
+ 16,
26
+ 32,
27
+ 64,
28
+ 128
29
+ ],
30
+ "strides": [
31
+ 2,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2
36
+ ]
37
+ },
38
+ "network": "$@network_def.to(@device)",
39
+ "preprocessing": {
40
+ "_target_": "Compose",
41
+ "transforms": [
42
+ {
43
+ "_target_": "LoadImage",
44
+ "image_only": "true"
45
+ },
46
+ {
47
+ "_target_": "EnsureType",
48
+ "device": "@device"
49
+ },
50
+ {
51
+ "_target_": "Transpose",
52
+ "indices": [
53
+ 2,
54
+ 0,
55
+ 1,
56
+ 3
57
+ ]
58
+ },
59
+ {
60
+ "_target_": "ScaleIntensity"
61
+ }
62
+ ]
63
+ },
64
+ "dataset": {
65
+ "_target_": "Dataset",
66
+ "data": "@datalist",
67
+ "transform": "@preprocessing"
68
+ },
69
+ "dataloader": {
70
+ "_target_": "DataLoader",
71
+ "dataset": "@dataset",
72
+ "batch_size": 1,
73
+ "shuffle": false,
74
+ "num_workers": 0
75
+ },
76
+ "inferer": {
77
+ "_target_": "scripts.valve_landmarks.LandmarkInferer",
78
+ "spatial_dim": 2,
79
+ "stack_dim": 1
80
+ },
81
+ "postprocessing": {
82
+ "_target_": "Compose",
83
+ "transforms": [
84
+ {
85
+ "_target_": "SqueezeDimd",
86
+ "keys": "pred",
87
+ "dim": 0
88
+ },
89
+ {
90
+ "_target_": "scripts.valve_landmarks.NpySaverd",
91
+ "keys": "pred",
92
+ "output_dir": "@output_dir",
93
+ "data_root_dir": "@dataset_dir"
94
+ }
95
+ ]
96
+ },
97
+ "handlers": [
98
+ {
99
+ "_target_": "CheckpointLoader",
100
+ "_disabled_": "$not os.path.exists(@ckpt_path)",
101
+ "load_path": "@ckpt_path",
102
+ "load_dict": {
103
+ "net": "@network"
104
+ }
105
+ }
106
+ ],
107
+ "evaluator": {
108
+ "_target_": "SupervisedEvaluator",
109
+ "device": "@device",
110
+ "val_data_loader": "@dataloader",
111
+ "network": "@network",
112
+ "inferer": "@inferer",
113
+ "postprocessing": "@postprocessing",
114
+ "val_handlers": "@handlers",
115
+ "decollate": false,
116
+ "prepare_batch": "$lambda batch, dev,nb: (batch.to(dev),())"
117
+ },
118
+ "evaluating": [
119
+ "$@evaluator.run()"
120
+ ]
121
+ }
configs/metadata.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220729.json",
3
+ "version": "0.1.0",
4
+ "changelog": {
5
+ "0.1.0": "Initial version"
6
+ },
7
+ "monai_version": "1.0.0rc1",
8
+ "pytorch_version": "1.10.2",
9
+ "numpy_version": "1.21.2",
10
+ "optional_packages_version": {},
11
+ "task": "Given long axis MR images of the heart, identify valve insertion points through the full cardiac cycle",
12
+ "description": "This network is used to find where valves attach to heart to help construct 3D FEM models for computation. The output is an array of 10 2D coordinates.",
13
+ "authors": "Eric Kerfoot",
14
+ "copyright": "Copyright (c) Eric Kerfoot",
15
+ "references": [
16
+ "Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. https://doi.org/10.1007/978-3-030-68107-4_15"
17
+ ],
18
+ "intended_use": "This is suitable for research purposes only",
19
+ "image_classes": "Single channel data, intensity scaled to [0, 1]",
20
+ "data_source": "Non-public dataset comprised of hand-annotated full cycle long axis MR images",
21
+ "coordinate_values": {
22
+ "0": 10,
23
+ "1": 15,
24
+ "2": 20,
25
+ "3": 25,
26
+ "4": 30,
27
+ "5": 35,
28
+ "6": 100,
29
+ "7": 150,
30
+ "8": 200,
31
+ "9": 250
32
+ },
33
+ "coordinate_meanings": {
34
+ "0": "mitral anterior 2CH",
35
+ "1": "mitral posterior 2CH",
36
+ "2": "mitral septal 3CH",
37
+ "3": "mitral free wall 3CH",
38
+ "4": "mitral septal 4CH",
39
+ "5": "mitral free wall 4CH",
40
+ "6": "aortic septal",
41
+ "7": "aortic free wall",
42
+ "8": "tricuspid septal",
43
+ "9": "tricuspid free wall"
44
+ },
45
+ "network_data_format": {
46
+ "inputs": {
47
+ "image": {
48
+ "type": "image",
49
+ "format": "magnitude",
50
+ "modality": "MR",
51
+ "num_channels": 1,
52
+ "spatial_shape": [
53
+ 256,
54
+ 256
55
+ ],
56
+ "dtype": "float32",
57
+ "value_range": [],
58
+ "is_patch_data": false,
59
+ "channel_def": {
60
+ "0": "image"
61
+ }
62
+ }
63
+ },
64
+ "outputs": {
65
+ "pred": {
66
+ "type": "tuples",
67
+ "format": "points",
68
+ "num_channels": 2,
69
+ "spatial_shape": [
70
+ 2,
71
+ 10
72
+ ],
73
+ "dtype": "float32",
74
+ "value_range": [],
75
+ "is_patch_data": false,
76
+ "channel_def": {
77
+ "0": "Y Dimension",
78
+ "1": "X Dimension"
79
+ }
80
+ }
81
+ }
82
+ }
83
+ }
configs/train.json ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import datetime",
4
+ "$import numpy as np",
5
+ "$import torch",
6
+ "$import ignite",
7
+ "$import scripts"
8
+ ],
9
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
10
+ "ckpt_path": "$@bundle_root + '/models/model.pt'",
11
+ "dataset_file": "./valvelandmarks.npz",
12
+ "output_dir": "$datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')",
13
+ "network_def": {
14
+ "_target_": "scripts.valve_landmarks.PointRegressor",
15
+ "in_shape": [
16
+ 1,
17
+ 256,
18
+ 256
19
+ ],
20
+ "out_shape": [
21
+ 2,
22
+ 10
23
+ ],
24
+ "channels": [
25
+ 8,
26
+ 16,
27
+ 32,
28
+ 64,
29
+ 128
30
+ ],
31
+ "strides": [
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2
37
+ ]
38
+ },
39
+ "network": "$@network_def.to(@device)",
40
+ "im_shape": [
41
+ 1,
42
+ 256,
43
+ 256
44
+ ],
45
+ "both_keys": [
46
+ "image",
47
+ "label"
48
+ ],
49
+ "rand_prob": 0.5,
50
+ "train_transforms": {
51
+ "_target_": "Compose",
52
+ "transforms": [
53
+ {
54
+ "_target_": "EnsureTyped",
55
+ "keys": "@both_keys",
56
+ "data_type": "numpy",
57
+ "dtype": "$(np.float32, np.int32)"
58
+ },
59
+ {
60
+ "_target_": "EnsureTyped",
61
+ "keys": "@both_keys"
62
+ },
63
+ {
64
+ "_target_": "ScaleIntensityd",
65
+ "keys": "image"
66
+ },
67
+ {
68
+ "_target_": "EnsureChannelFirstd",
69
+ "keys": "@both_keys"
70
+ },
71
+ {
72
+ "_target_": "RandAxisFlipd",
73
+ "keys": "@both_keys",
74
+ "prob": "@rand_prob"
75
+ },
76
+ {
77
+ "_target_": "RandRotate90d",
78
+ "keys": "@both_keys",
79
+ "prob": "@rand_prob"
80
+ },
81
+ {
82
+ "_target_": "RandSmoothFieldAdjustIntensityd",
83
+ "keys": "image",
84
+ "prob": "@rand_prob",
85
+ "spatial_size": "@im_shape",
86
+ "rand_size": [
87
+ 5,
88
+ 5
89
+ ],
90
+ "gamma": [
91
+ 0.1,
92
+ 1
93
+ ],
94
+ "mode": "$monai.utils.InterpolateMode.BICUBIC",
95
+ "align_corners": true
96
+ },
97
+ {
98
+ "_target_": "RandGaussianNoised",
99
+ "keys": "image",
100
+ "prob": "@rand_prob",
101
+ "std": 0.05
102
+ },
103
+ {
104
+ "_target_": "scripts.valve_landmarks.RandFourierDropoutd",
105
+ "keys": "image",
106
+ "prob": "@rand_prob"
107
+ },
108
+ {
109
+ "_target_": "scripts.valve_landmarks.RandImageLMDeformd",
110
+ "prob": "@rand_prob",
111
+ "spatial_size": [
112
+ 256,
113
+ 256
114
+ ],
115
+ "rand_size": [
116
+ 7,
117
+ 7
118
+ ],
119
+ "pad": 2,
120
+ "field_mode": "$monai.utils.InterpolateMode.BICUBIC",
121
+ "align_corners": true,
122
+ "def_range": 0.05
123
+ },
124
+ {
125
+ "_target_": "scripts.valve_landmarks.RandLMShiftd",
126
+ "keys": "@both_keys",
127
+ "prob": "@rand_prob",
128
+ "spatial_size": [
129
+ 256,
130
+ 256
131
+ ],
132
+ "max_shift": 8
133
+ },
134
+ {
135
+ "_target_": "Lambdad",
136
+ "keys": "label",
137
+ "func": "$scripts.valve_landmarks.convert_lm_image_t"
138
+ }
139
+ ]
140
+ },
141
+ "eval_transforms": {
142
+ "_target_": "Compose",
143
+ "transforms": [
144
+ {
145
+ "_target_": "EnsureTyped",
146
+ "keys": "@both_keys",
147
+ "data_type": "numpy",
148
+ "dtype": "$(np.float32, np.int32)"
149
+ },
150
+ {
151
+ "_target_": "EnsureTyped",
152
+ "keys": "@both_keys"
153
+ },
154
+ {
155
+ "_target_": "ScaleIntensityd",
156
+ "keys": "image"
157
+ },
158
+ {
159
+ "_target_": "EnsureChannelFirstd",
160
+ "keys": "@both_keys"
161
+ },
162
+ {
163
+ "_target_": "Lambdad",
164
+ "keys": "label",
165
+ "func": "$scripts.valve_landmarks.convert_lm_image_t"
166
+ }
167
+ ]
168
+ },
169
+ "train_dataset": {
170
+ "_target_": "NPZDictItemDataset",
171
+ "npzfile": "$@dataset_file",
172
+ "keys": {
173
+ "trainImgs": "image",
174
+ "trainLMImgs": "label"
175
+ },
176
+ "transform": "@train_transforms"
177
+ },
178
+ "eval_dataset": {
179
+ "_target_": "NPZDictItemDataset",
180
+ "npzfile": "$@dataset_file",
181
+ "keys": {
182
+ "testImgs": "image",
183
+ "testLMImgs": "label"
184
+ },
185
+ "transform": "@eval_transforms"
186
+ },
187
+ "num_iters": 400,
188
+ "batch_size": 600,
189
+ "num_epochs": 100,
190
+ "num_substeps": 3,
191
+ "sampler": {
192
+ "_target_": "torch.utils.data.WeightedRandomSampler",
193
+ "weights": "$torch.ones(len(@train_dataset))",
194
+ "replacement": true,
195
+ "num_samples": "$@num_iters*@batch_size"
196
+ },
197
+ "train_dataloader": {
198
+ "_target_": "ThreadDataLoader",
199
+ "dataset": "@train_dataset",
200
+ "batch_size": "@batch_size",
201
+ "repeats": "@num_substeps",
202
+ "num_workers": 8,
203
+ "sampler": "@sampler"
204
+ },
205
+ "eval_dataloader": {
206
+ "_target_": "DataLoader",
207
+ "dataset": "@eval_dataset",
208
+ "batch_size": "@batch_size",
209
+ "num_workers": 8
210
+ },
211
+ "lossfn": {
212
+ "_target_": "torch.nn.L1Loss"
213
+ },
214
+ "optimizer": {
215
+ "_target_": "torch.optim.Adam",
216
+ "params": "$@network.parameters()",
217
+ "lr": 0.0001
218
+ },
219
+ "evaluator": {
220
+ "_target_": "SupervisedEvaluator",
221
+ "device": "@device",
222
+ "val_data_loader": "@eval_dataloader",
223
+ "network": "@network",
224
+ "key_val_metric": {
225
+ "val_mean_dist": {
226
+ "_target_": "ignite.metrics.MeanPairwiseDistance",
227
+ "output_transform": "$scripts.valve_landmarks._output_lm_trans"
228
+ }
229
+ },
230
+ "metric_cmp_fn": "$lambda current, prev: prev < 0 or current < prev",
231
+ "val_handlers": [
232
+ {
233
+ "_target_": "StatsHandler",
234
+ "output_transform": "$lambda x: None"
235
+ }
236
+ ]
237
+ },
238
+ "handlers": [
239
+ {
240
+ "_target_": "ValidationHandler",
241
+ "validator": "@evaluator",
242
+ "epoch_level": true,
243
+ "interval": 1
244
+ },
245
+ {
246
+ "_target_": "CheckpointSaver",
247
+ "save_dir": "@output_dir",
248
+ "save_dict": {
249
+ "net": "@network"
250
+ },
251
+ "save_interval": 1,
252
+ "save_final": true,
253
+ "epoch_level": true
254
+ }
255
+ ],
256
+ "trainer": {
257
+ "_target_": "SupervisedTrainer",
258
+ "max_epochs": "@num_epochs",
259
+ "device": "@device",
260
+ "train_data_loader": "@train_dataloader",
261
+ "network": "@network",
262
+ "loss_function": "@lossfn",
263
+ "optimizer": "@optimizer",
264
+ "key_train_metric": null,
265
+ "train_handlers": "@handlers"
266
+ },
267
+ "training": [
268
+ "$@trainer.run()"
269
+ ]
270
+ }
docs/AMRGAtlas_0031.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddac587d7ced3c7f892113e379256f055895162a5077f741c1bdc79e1a70b558
3
+ size 2190289
docs/AMRGAtlas_0031_key-pred.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48c44c9b85bb080e33a97b5421fde7037e039e235dbfbf08eff22a8fd2901b89
3
+ size 2688
docs/README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # 2D Cardiac Valve Landmark Regressor
3
+
4
+ This network identifies 10 different landmarks in 2D+t MR images of the heart (2 chamber, 3 chamber, and 4 chamber) representing the insertion locations of valve leaflets into the myocardial wall. These coordinates are used in part of the construction of 3D FEM cardiac models suitable for physics simulation of heart functions.
5
+
6
+ Input images are individual 2D slices from the time series, and the output from the network is a `(2, 10)` set of 2D points in `HW` image coordinate space. The 10 coordinates correspond to the attachment point for these valves:
7
+
8
+ 1. Mitral anterior in 2CH
9
+ 2. Mitral posterior in 2CH
10
+ 3. Mitral septal in 3CH
11
+ 4. Mitral free wall in 3CH
12
+ 5. Mitral septal in 4CH
13
+ 6. Mitral free wall in 4CH
14
+ 7. Aortic septal
15
+ 8. Aortic free wall
16
+ 9. Tricuspid septal
17
+ 10. Tricuspid free wall
18
+
19
+ Landmarks which do not appear in a particular image are predicted to be `(0, 0)` or close to this location. The mitral valve is expected to appear in all three views. Landmarks are not provided for the pulmonary valve.
20
+
21
+ Example plot of landmarks on a single frame, see [view_results.ipynb](./view_results.ipynb) for visualising network output:
22
+
23
+ ![Landmark Example Image](./prediction_example.png)
24
+
25
+ ## Training
26
+
27
+ The training script `train.json` is provided to train the network using a dataset of image pairs containing the MR image and a landmark image. This is done to reuse image-based transforms which do not currently operate on geometry. A number of other transforms are provided in `valve_landmarks.py` to implement Fourier-space dropout, image shifting which preserve landmarks, and smooth-field deformation applied to images and landmarks.
28
+
29
+ The dataset used for training unfortunately cannot be made public, however the training script can be used with any NPZ file containing the training image stack in key `trainImgs` and landmark image stack in `trainLMImgs`, plus `testImgs` and `testLMImgs` containing validation data. The landmark images are defined as 0 for every non-landmark pixel, with landmark pixels contaning the following values for each landmark type:
30
+
31
+ * 10: Mitral anterior in 2CH
32
+ * 15: Mitral posterior in 2CH
33
+ * 20: Mitral septal in 3CH
34
+ * 25: Mitral free wall in 3CH
35
+ * 30: Mitral septal in 4CH
36
+ * 35: Mitral free wall in 4CH
37
+ * 100: Aortic septal
38
+ * 150: Aortic free wall
39
+ * 200: Tricuspid septal
40
+ * 250: Tricuspid free wall
41
+
42
+ The following command will train with the default NPZ filename `./valvelandmarks.npz`:
43
+
44
+ ```sh
45
+ PYTHONPATH=./scripts python -m monai.bundle run training --meta_file configs/metadata.json \
46
+ --config_file configs/train.json --bundle_root . --dataset_file /path/to/data --output_dir /path/to/outputs
47
+ ```
48
+
49
+ ## Inference
50
+
51
+ The included `inference.json` script will run inference on a directory containing Nifti files whose images have shape `(256, 256, 1, N)` for `N` timesteps. For each image the output in the `output_dir` directory will be a npy file containing a result array of shape `(N, 2, 10)` storing the 10 coordinates for each `N` timesteps. Invoking this script can be done as follows, assuming the current directory is the bundle directory:
52
+
53
+ ```sh
54
+ PYTHONPATH=./scripts python -m monai.bundle run evaluating --meta_file configs/metadata.json \
55
+ --config_file configs/inference.json --bundle_root . --dataset_dir /path/to/data --output_dir /path/to/outputs
56
+ ```
57
+
58
+ It is important to set the `PYTHONPATH` variable since code in the provided scripts directory is necessary for inference. The provided test Nifti file can be placed in a directory which is then used as the `dataset_dir` value. This image was derived from [the AMRG Cardiac Atlas dataset](http://www.cardiacatlas.org/studies/amrg-cardiac-atlas) (AMRG Cardiac Atlas, Auckland MRI Research Group, Auckland, New Zealand). The results from this inference can be visualised by changing path values in [view_results.ipynb](./view_results.ipynb).
59
+
60
+
61
+ ### Reference
62
+
63
+ The work for this model and its application is described in:
64
+
65
+ `Kerfoot, E, King, CE, Ismail, T, Nordsletten, D & Miller, R 2021, Estimation of Cardiac Valve Annuli Motion with Deep Learning. in E Puyol Anton, M Pop, M Sermesant, V Campello, A Lalande, K Lekadir, A Suinesiaputra, O Camara & A Young (eds), Statistical Atlases and Computational Models of the Heart. MandMs and EMIDEC Challenges - 11th International Workshop, STACOM 2020, Held in Conjunction with MICCAI 2020, Revised Selected Papers. Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics), vol. 12592 LNCS, Springer Science and Business Media Deutschland GmbH, pp. 146-155, 11th International Workshop on Statistical Atlases and Computational Models of the Heart, STACOM 2020 held in Conjunction with MICCAI 2020, Lima, Peru, 4/10/2020. https://doi.org/10.1007/978-3-030-68107-4_15`
docs/license.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Eric Kerfoot
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
docs/prediction_example.png ADDED
docs/view_results.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0187ee65b150c0c693b16e6a65d0b3f659bed163423a1e4a5a6b3bb4fbeb7bf
3
+ size 13349733
scripts/__init__.py ADDED
File without changes
scripts/valve_landmarks.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 Eric Kerfoot under MIT license, see license.txt
2
+
3
+ import os
4
+ from typing import Any, Callable, Sequence
5
+
6
+ import monai
7
+ import monai.transforms as mt
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from monai.data.meta_obj import get_track_meta
12
+ from monai.networks.blocks import ConvDenseBlock, Convolution
13
+ from monai.networks.layers import Flatten, Reshape
14
+ from monai.networks.nets import Regressor
15
+ from monai.networks.utils import meshgrid_ij
16
+ from monai.utils import CommonKeys
17
+ from monai.utils import ImageMetaKey as Key
18
+ from monai.utils import convert_to_numpy, convert_to_tensor
19
+
20
+ # relates the label in training images to index of landmark point
21
+ LM_INDICES = {
22
+ 10: 0, # mitral anterior 2CH
23
+ 15: 1, # mitral posterior 2CH
24
+ 20: 2, # mitral septal 3CH
25
+ 25: 3, # mitral free wall 3CH
26
+ 30: 4, # mitral septal 4CH
27
+ 35: 5, # mitral free wall 4CH
28
+ 100: 6, # aortic septal
29
+ 150: 7, # aortic free wall
30
+ 200: 8, # tricuspid septal
31
+ 250: 9, # tricuspid free wall
32
+ }
33
+
34
+ output_trans = monai.handlers.from_engine(["pred", "label"])
35
+
36
+
37
+ def _output_lm_trans(data):
38
+ pred, label = output_trans(data)
39
+ return [p.permute(1, 0) for p in pred], [l.permute(1, 0) for l in label]
40
+
41
+
42
+ def convert_lm_image_t(lm_image):
43
+ """Convert a landmark image into a (2,N) tensor of landmark coordinates."""
44
+ lmarray = torch.zeros((2, len(LM_INDICES)), dtype=torch.float32).to(lm_image.device)
45
+
46
+ for _, y, x in np.argwhere(lm_image.cpu().numpy() != 0):
47
+ im_id = int(lm_image[0, y, x])
48
+ lm_index = LM_INDICES[im_id]
49
+
50
+ lmarray[0, lm_index] = y
51
+ lmarray[1, lm_index] = x
52
+
53
+ return lmarray
54
+
55
+
56
+ class ParallelCat(nn.Module):
57
+ """
58
+ Apply the same input to each of the given modules and concatenate their results together.
59
+
60
+ Args:
61
+ catmodules: sequence of nn.Module objects to apply inputs to
62
+ cat_dim: dimension to concatenate along when combining outputs
63
+ """
64
+
65
+ def __init__(self, catmodules: Sequence[nn.Module], cat_dim: int = 1):
66
+ super().__init__()
67
+ self.cat_dim = cat_dim
68
+
69
+ for i, s in enumerate(catmodules):
70
+ self.add_module(f"catmodule_{i}", s)
71
+
72
+ def forward(self, x):
73
+ tensors = [s(x) for s in self.children()]
74
+ return torch.cat(tensors, self.cat_dim)
75
+
76
+
77
+ class PointRegressor(Regressor):
78
+ """Regressor defined as a sequence of dense blocks followed by convolution/linear layers for each landmark."""
79
+
80
+ def _get_layer(self, in_channels, out_channels, strides, is_last):
81
+ dout = out_channels - in_channels
82
+ dilations = [1, 2, 4]
83
+ dchannels = [dout // 3, dout // 3, dout // 3 + dout % 3]
84
+
85
+ db = ConvDenseBlock(
86
+ spatial_dims=self.dimensions,
87
+ in_channels=in_channels,
88
+ channels=dchannels,
89
+ dilations=dilations,
90
+ kernel_size=self.kernel_size,
91
+ num_res_units=self.num_res_units,
92
+ act=self.act,
93
+ norm=self.norm,
94
+ dropout=self.dropout,
95
+ bias=self.bias,
96
+ )
97
+
98
+ conv = Convolution(
99
+ spatial_dims=self.dimensions,
100
+ in_channels=out_channels,
101
+ out_channels=out_channels,
102
+ strides=strides,
103
+ kernel_size=self.kernel_size,
104
+ act=self.act,
105
+ norm=self.norm,
106
+ dropout=self.dropout,
107
+ bias=self.bias,
108
+ conv_only=is_last,
109
+ )
110
+
111
+ return nn.Sequential(db, conv)
112
+
113
+ def _get_final_layer(self, in_shape):
114
+ point_paths = []
115
+
116
+ for _ in range(self.out_shape[1]):
117
+ conv = Convolution(
118
+ spatial_dims=self.dimensions,
119
+ in_channels=in_shape[0],
120
+ out_channels=in_shape[0] * 2,
121
+ strides=2,
122
+ kernel_size=self.kernel_size,
123
+ act=self.act,
124
+ norm=self.norm,
125
+ dropout=self.dropout,
126
+ conv_only=True,
127
+ )
128
+ linear = nn.Linear(int(np.product(in_shape)) // 2, self.out_shape[0])
129
+ point_paths.append(nn.Sequential(conv, Flatten(), linear))
130
+
131
+ return torch.nn.Sequential(ParallelCat(point_paths), Reshape(*self.out_shape))
132
+
133
+
134
+ class LandmarkInferer(monai.inferers.Inferer):
135
+ """Applies inference on 2D slices from 3D volumes."""
136
+
137
+ def __init__(self, spatial_dim=0, stack_dim=-1):
138
+ self.spatial_dim = spatial_dim
139
+ self.stack_dim = stack_dim
140
+
141
+ def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any):
142
+ if inputs.ndim != 5:
143
+ raise ValueError(f"Input volume to inferer must have shape BCDHW, input shape is {inputs.shape}")
144
+
145
+ results = []
146
+ input_slices = [slice(None) for _ in range(inputs.ndim)]
147
+
148
+ for idx in range(inputs.shape[self.spatial_dim + 2]):
149
+ input_slices[self.spatial_dim + 2] = idx
150
+ input_2d = inputs[input_slices] # BCDHW -> BCHW by iterating over one of DHW
151
+
152
+ result = network(input_2d, *args, **kwargs)
153
+ results.append(result)
154
+
155
+ result = torch.stack(results, self.stack_dim)
156
+ return result
157
+
158
+
159
+ class NpySaverd(mt.MapTransform):
160
+ """Saves tensors/arrays to Numpy npy files."""
161
+
162
+ def __init__(self, keys, output_dir, data_root_dir):
163
+ super().__init__(keys)
164
+ self.output_dir = output_dir
165
+ self.data_root_dir = data_root_dir
166
+ self.folder_layout = monai.data.FolderLayout(
167
+ self.output_dir, extension=".npy", data_root_dir=self.data_root_dir
168
+ )
169
+
170
+ def __call__(self, d):
171
+ if not os.path.exists(self.output_dir):
172
+ os.makedirs(self.output_dir, exist_ok=True)
173
+
174
+ for key in self.key_iterator(d):
175
+ orig_filename = d[key].meta[Key.FILENAME_OR_OBJ]
176
+ if isinstance(orig_filename, (list, tuple)):
177
+ orig_filename = orig_filename[0]
178
+
179
+ out_filename = self.folder_layout.filename(orig_filename, key=key)
180
+
181
+ np.save(out_filename, convert_to_numpy(d[key]))
182
+
183
+ return d
184
+
185
+
186
+ class FourierDropout(mt.Transform, mt.Fourier):
187
+ """
188
+ Apply dropout in Fourier space to corrupt images. This works by zeroing out pixels with greater probability the
189
+ farther from the centre they are. All pixels closer than `min_dist` to the center are preserved, all beyond
190
+ `max_dist` become 0. Distances from the centre to an edge in a given dimension are defined as 1.0.
191
+
192
+ Args:
193
+ min_dist: minimum distance to apply dropout, must be >0, smaller values will cause greater corruption
194
+ max_dist: maximal distance to apply dropout, must be greater than `min_dist`, all pixels beyond become 0
195
+ """
196
+
197
+ def __init__(self, min_dist: float = 0.1, max_dist: float = 0.9):
198
+ super().__init__()
199
+ self.min_dist = min_dist
200
+ self.max_dist = max_dist
201
+ self.prob_field = None
202
+ self.field_shape = None
203
+
204
+ def _get_prob_field(self, shape):
205
+ shape = tuple(shape)
206
+ if shape != self.field_shape:
207
+ self.field_shape = shape
208
+ spaces = [torch.linspace(-1, 1, s) for s in shape[1:]]
209
+ grids = meshgrid_ij(*spaces)
210
+ # middle is 0, mid edges 1, corners sqrt(2)
211
+ self.prob_field = torch.stack(grids).pow_(2).sum(axis=0).sqrt_()
212
+
213
+ return self.prob_field
214
+
215
+ def __call__(self, im):
216
+ probfield = self._get_prob_field(im.shape).to(im.device)
217
+
218
+ # rand range from min_dist to max_dist
219
+ dropout = torch.rand_like(im).mul_(self.max_dist - self.min_dist).add_(self.min_dist)
220
+ # keep pixel if dropout value is greater than distance from center, so less likely farther from center
221
+ dropout = dropout.ge_(probfield)
222
+
223
+ result = self.shift_fourier(im, im.ndim - 1)
224
+ result.mul_(dropout)
225
+ result = self.inv_shift_fourier(result, im.ndim - 1)
226
+
227
+ return convert_to_tensor(result, track_meta=get_track_meta())
228
+
229
+
230
+ class RandFourierDropout(mt.RandomizableTransform):
231
+ def __init__(self, min_dist=0.1, max_dist=0.9, prob=0.1):
232
+ mt.RandomizableTransform.__init__(self, prob)
233
+ self.dropper = FourierDropout(min_dist, max_dist)
234
+
235
+ def __call__(self, im, randomize: bool = True):
236
+ if randomize:
237
+ self.randomize(None)
238
+
239
+ if self._do_transform:
240
+ im = self.dropper(im)
241
+ else:
242
+ im = convert_to_tensor(im, track_meta=get_track_meta())
243
+
244
+ return im
245
+
246
+
247
+ class RandFourierDropoutd(mt.RandomizableTransform, mt.MapTransform):
248
+ def __init__(self, keys, min_dist=0.1, max_dist=0.9, prob=0.1):
249
+ mt.RandomizableTransform.__init__(self, prob)
250
+ mt.MapTransform.__init__(self, keys)
251
+ self.dropper = FourierDropout(min_dist, max_dist)
252
+
253
+ def __call__(self, data, randomize: bool = True):
254
+ d = dict(data)
255
+
256
+ if randomize:
257
+ self.randomize(None)
258
+
259
+ for key in self.key_iterator(d):
260
+ if self._do_transform:
261
+ d[key] = self.dropper(d[key])
262
+ else:
263
+ d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
264
+
265
+ return d
266
+
267
+
268
+ class RandImageLMDeformd(mt.RandSmoothDeform):
269
+ """Apply smooth random deformation to the image and landmark locations."""
270
+
271
+ def __call__(self, d):
272
+ d = dict(d)
273
+ old_label = d[CommonKeys.LABEL]
274
+ new_label = torch.zeros_like(old_label)
275
+
276
+ d[CommonKeys.IMAGE] = super().__call__(d[CommonKeys.IMAGE])
277
+
278
+ if self._do_transform:
279
+ field = self.sfield()
280
+ labels = np.argwhere(d[CommonKeys.LABEL][0].cpu().numpy() > 0)
281
+
282
+ # moving the landmarks this way prevents losing some to
283
+ # interpolation errors if deformation were applied the landmark image
284
+ for y, x in labels:
285
+ dy = int(field[0, y, x] * new_label.shape[1] / 2)
286
+ dx = int(field[1, y, x] * new_label.shape[2] / 2)
287
+
288
+ new_label[:, y - dy, x - dx] = old_label[:, y, x]
289
+
290
+ d[CommonKeys.LABEL] = new_label
291
+
292
+ return d
293
+
294
+
295
+ class RandLMShiftd(mt.RandomizableTransform, mt.MapTransform):
296
+ """Randomly shift the image and landmark image in either direction in integer amounts."""
297
+
298
+ def __init__(self, keys, spatial_size, max_shift=0, prob=0.1):
299
+ mt.RandomizableTransform.__init__(self, prob=prob)
300
+ mt.MapTransform.__init__(self, keys=keys)
301
+
302
+ self.spatial_size = tuple(spatial_size)
303
+ self.max_shift = max_shift
304
+ self.padder = mt.BorderPad(self.max_shift)
305
+ self.unpadder = mt.CenterSpatialCrop(self.spatial_size)
306
+ self.shift = (0,) * len(self.spatial_size)
307
+ self.roll_dims = list(range(1, len(self.spatial_size) + 1))
308
+
309
+ def randomize(self, data):
310
+ super().randomize(None)
311
+ if self._do_transform:
312
+ rs = torch.randint(-self.max_shift, self.max_shift, (len(self.spatial_size),), dtype=torch.int32)
313
+ self.shift = tuple(rs.tolist())
314
+
315
+ def __call__(self, d, randomize: bool = True):
316
+ d = dict(d)
317
+
318
+ if randomize:
319
+ self.randomize(None)
320
+
321
+ if self._do_transform:
322
+ for key in self.key_iterator(d):
323
+ imp = self.padder(d[key])
324
+ ims = torch.roll(imp, self.shift, self.roll_dims) # prevents interpolation of landmark image
325
+ d[key] = self.unpadder(ims)
326
+
327
+ return d