monai
medical
katielink commited on
Commit
7712592
1 Parent(s): e883afe

complete the model package

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/demo.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - monai
4
+ - medical
5
+ library_name: monai
6
+ license: unknown
7
+ ---
8
+ # Description
9
+ Detailed whole brain segmentation is an essential quantitative technique in medical image analysis, which provides a non-invasive way of measuring brain regions from a clinical acquired structural magnetic resonance imaging (MRI).
10
+ We provide the pre-trained model for inferencing whole brain segmentation with 133 structures.
11
+
12
+ A tutorial and release of model for whole brain segmentation using the 3D transformer-based segmentation model UNEST.
13
+
14
+ Authors:
15
+ Xin Yu (xin.yu@vanderbilt.edu) (Primary)
16
+
17
+ Yinchi Zhou (yinchi.zhou@vanderbilt.edu) | Yucheng Tang (yuchengt@nvidia.com)
18
+
19
+ <p align="center">
20
+ -------------------------------------------------------------------------------------
21
+ </p>
22
+
23
+ ![](./demo.png) <br>
24
+ <p align="center">
25
+ Fig.1 - The demonstration of T1w MRI images registered in MNI space and the whole brain segmentation labels with 133 classes</p>
26
+
27
+
28
+
29
+ # Model Overview
30
+ A pre-trained larger UNEST base model [1] for volumetric (3D) whole brain segmentation with T1w MR images.
31
+ To leverage information across embedded sequences, ”shifted window” transformers
32
+ are proposed for dense predictions and modeling multi-scale features. However, these
33
+ attempts that aim to complicate the self-attention range often yield high computation
34
+ complexity and data inefficiency. Inspired by the aggregation function in the nested
35
+ ViT, we propose a new design of a 3D U-shape medical segmentation model with
36
+ Nested Transformers (UNesT) hierarchically with the 3D block aggregation function,
37
+ that learn locality behaviors for small structures or small dataset. This design retains
38
+ the original global self-attention mechanism and achieves information communication
39
+ across patches by stacking transformer encoders hierarchically.
40
+
41
+ ![](./unest.png) <br>
42
+ <p align="center">
43
+ Fig.2 - The network architecture of UNEST Base model
44
+ </p>
45
+
46
+
47
+ ## Data
48
+ The training data is from the Vanderbilt University and Vanderbilt University Medical Center with public released OASIS and CANDI datsets.
49
+ Training and testing data are MRI T1-weighted (T1w) 3D volumes coming from 3 different sites. There are a total of 133 classes in the whole brain segmentation task.
50
+ Among 50 T1w MRI scans from Open Access Series on Imaging Studies (OASIS) (Marcus et al., 2007) dataset, 45 scans are used for training and the other 5 for validation.
51
+ The testing cohort contains Colin27 T1w scan (Aubert-Broche et al., 2006) and 13 T1w MRI scans from the Child and Adolescent Neuro Development Initiative (CANDI)
52
+ (Kennedy et al., 2012). All data are registered to the MNI space using the MNI305 (Evans et al., 1993) template and preprocessed follow the method in (Huo et al., 2019). Input images are randomly cropped to the size of 96 × 96 × 96.
53
+
54
+ ### Important
55
+
56
+ ```diff
57
+ + All the brain MRI images for training are registered to Affine registration from the target image to the MNI305 template using NiftyReg.
58
+ + The data should be in the MNI305 space before inference.
59
+
60
+
61
+ Registration to MNI Space: Sample suggestion. E.g., use ANTS or other tools for registering T1 MRI image to MNI305 Space.
62
+
63
+
64
+ ```
65
+ pip install antspyx
66
+ ```
67
+ Sample ANTS registration
68
+ ```
69
+
70
+ import ants
71
+ import sys
72
+ import os
73
+
74
+ fixed_image = ants.image_read('<fixed_image_path>')
75
+ moving_image = ants.image_read('<moving_image_path>')
76
+ transform = ants.registration(fixed_image,moving_image,'Affine')
77
+
78
+ reg3t = ants.apply_transforms(fixed_image,moving_image,transform['fwdtransforms'][0])
79
+ ants.image_write(reg3t,output_image_path)
80
+
81
+ ```
82
+ ## Training configuration
83
+ The training and inference was performed with at least one 24GB-memory GPU.
84
+
85
+ Actual Model Input: 96 x 96 x 96
86
+
87
+ ## Input and output formats
88
+ Input: 1 channel T1w MRI image in MNI305 Space.
89
+
90
+
91
+ ## commands example
92
+ Download trained checkpoint model to ./model/model.pt:
93
+
94
+
95
+ Add scripts component: To run the workflow with customized components, PYTHONPATH should be revised to include the path to the customized component:
96
+
97
+ ```
98
+ export PYTHONPATH=$PYTHONPATH: '<path to the bundle root dir>/scripts'
99
+
100
+ ```
101
+
102
+
103
+ Execute inference:
104
+
105
+ ```
106
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
107
+ ```
108
+
109
+
110
+ ## More examples output
111
+ ![](./wholebrain.png) <br>
112
+ <p align="center">
113
+ Fig.3 - The output prediction comparison with variant and ground truth
114
+ </p>
115
+
116
+
117
+ ## Complete ROI of the whole brain segmentation
118
+ 133 brain structures are segmented.
119
+
120
+ | #1 | #2 | #3 | #4 |
121
+ | :------------ | :---------- | :-------- | :-------- |
122
+ | 0: background | 1 : 3rd-Ventricle | 2 : 4th-Ventricle | 3 : Right-Accumbens-Area |
123
+ | 4 : Left-Accumbens-Area | 5 : Right-Amygdala | 6 : Left-Amygdala | 7 : Brain-Stem |
124
+ | 8 : Right-Caudate | 9 : Left-Caudate | 10 : Right-Cerebellum-Exterior | 11 : Left-Cerebellum-Exterior |
125
+ | 12 : Right-Cerebellum-White-Matter | 13 : Left-Cerebellum-White-Matter | 14 : Right-Cerebral-White-Matter | 15 : Left-Cerebral-White-Matter |
126
+ | 16 : Right-Hippocampus | 17 : Left-Hippocampus | 18 : Right-Inf-Lat-Vent | 19 : Left-Inf-Lat-Vent |
127
+ | 20 : Right-Lateral-Ventricle | 21 : Left-Lateral-Ventricle | 22 : Right-Pallidum | 23 : Left-Pallidum |
128
+ | 24 : Right-Putamen | 25 : Left-Putamen | 26 : Right-Thalamus-Proper | 27 : Left-Thalamus-Proper |
129
+ | 28 : Right-Ventral-DC | 29 : Left-Ventral-DC | 30 : Cerebellar-Vermal-Lobules-I-V | 31 : Cerebellar-Vermal-Lobules-VI-VII |
130
+ | 32 : Cerebellar-Vermal-Lobules-VIII-X | 33 : Left-Basal-Forebrain | 34 : Right-Basal-Forebrain | 35 : Right-ACgG--anterior-cingulate-gyrus |
131
+ | 36 : Left-ACgG--anterior-cingulate-gyrus | 37 : Right-AIns--anterior-insula | 38 : Left-AIns--anterior-insula | 39 : Right-AOrG--anterior-orbital-gyrus |
132
+ | 40 : Left-AOrG--anterior-orbital-gyrus | 41 : Right-AnG---angular-gyrus | 42 : Left-AnG---angular-gyrus | 43 : Right-Calc--calcarine-cortex |
133
+ | 44 : Left-Calc--calcarine-cortex | 45 : Right-CO----central-operculum | 46 : Left-CO----central-operculum | 47 : Right-Cun---cuneus |
134
+ | 48 : Left-Cun---cuneus | 49 : Right-Ent---entorhinal-area | 50 : Left-Ent---entorhinal-area | 51 : Right-FO----frontal-operculum |
135
+ | 52 : Left-FO----frontal-operculum | 53 : Right-FRP---frontal-pole | 54 : Left-FRP---frontal-pole | 55 : Right-FuG---fusiform-gyrus |
136
+ | 56 : Left-FuG---fusiform-gyrus | 57 : Right-GRe---gyrus-rectus | 58 : Left-GRe---gyrus-rectus | 59 : Right-IOG---inferior-occipital-gyrus ,
137
+ | 60 : Left-IOG---inferior-occipital-gyrus | 61 : Right-ITG---inferior-temporal-gyrus | 62 : Left-ITG---inferior-temporal-gyrus | 63 : Right-LiG---lingual-gyrus |
138
+ | 64 : Left-LiG---lingual-gyrus | 65 : Right-LOrG--lateral-orbital-gyrus | 66 : Left-LOrG--lateral-orbital-gyrus | 67 : Right-MCgG--middle-cingulate-gyrus |
139
+ | 68 : Left-MCgG--middle-cingulate-gyrus | 69 : Right-MFC---medial-frontal-cortex | 70 : Left-MFC---medial-frontal-cortex | 71 : Right-MFG---middle-frontal-gyrus |
140
+ | 72 : Left-MFG---middle-frontal-gyrus | 73 : Right-MOG---middle-occipital-gyrus | 74 : Left-MOG---middle-occipital-gyrus | 75 : Right-MOrG--medial-orbital-gyrus |
141
+ | 76 : Left-MOrG--medial-orbital-gyrus | 77 : Right-MPoG--postcentral-gyrus | 78 : Left-MPoG--postcentral-gyrus | 79 : Right-MPrG--precentral-gyrus |
142
+ | 80 : Left-MPrG--precentral-gyrus | 81 : Right-MSFG--superior-frontal-gyrus | 82 : Left-MSFG--superior-frontal-gyrus | 83 : Right-MTG---middle-temporal-gyrus |
143
+ | 84 : Left-MTG---middle-temporal-gyrus | 85 : Right-OCP---occipital-pole | 86 : Left-OCP---occipital-pole | 87 : Right-OFuG--occipital-fusiform-gyrus |
144
+ | 88 : Left-OFuG--occipital-fusiform-gyrus | 89 : Right-OpIFG-opercular-part-of-the-IFG | 90 : Left-OpIFG-opercular-part-of-the-IFG | 91 : Right-OrIFG-orbital-part-of-the-IFG |
145
+ | 92 : Left-OrIFG-orbital-part-of-the-IFG | 93 : Right-PCgG--posterior-cingulate-gyrus | 94 : Left-PCgG--posterior-cingulate-gyrus | 95 : Right-PCu---precuneus |
146
+ | 96 : Left-PCu---precuneus | 97 : Right-PHG---parahippocampal-gyrus | 98 : Left-PHG---parahippocampal-gyrus | 99 : Right-PIns--posterior-insula |
147
+ | 100 : Left-PIns--posterior-insula | 101 : Right-PO----parietal-operculum | 102 : Left-PO----parietal-operculum | 103 : Right-PoG---postcentral-gyrus |
148
+ | 104 : Left-PoG---postcentral-gyrus | 105 : Right-POrG--posterior-orbital-gyrus | 106 : Left-POrG--posterior-orbital-gyrus | 107 : Right-PP----planum-polare |
149
+ | 108 : Left-PP----planum-polare | 109 : Right-PrG---precentral-gyrus | 110 : Left-PrG---precentral-gyrus | 111 : Right-PT----planum-temporale |
150
+ | 112 : Left-PT----planum-temporale | 113 : Right-SCA---subcallosal-area | 114 : Left-SCA---subcallosal-area | 115 : Right-SFG---superior-frontal-gyrus |
151
+ | 116 : Left-SFG---superior-frontal-gyrus | 117 : Right-SMC---supplementary-motor-cortex | 118 : Left-SMC---supplementary-motor-cortex | 119 : Right-SMG---supramarginal-gyrus |
152
+ | 120 : Left-SMG---supramarginal-gyrus | 121 : Right-SOG---superior-occipital-gyrus | 122 : Left-SOG---superior-occipital-gyrus | 123 : Right-SPL---superior-parietal-lobule |
153
+ | 124 : Left-SPL---superior-parietal-lobule | 125 : Right-STG---superior-temporal-gyrus | 126 : Left-STG---superior-temporal-gyrus | 127 : Right-TMP---temporal-pole |
154
+ | 128 : Left-TMP---temporal-pole | 129 : Right-TrIFG-triangular-part-of-the-IFG | 130 : Left-TrIFG-triangular-part-of-the-IFG | 131 : Right-TTG---transverse-temporal-gyrus |
155
+ | 132 : Left-TTG---transverse-temporal-gyrus |
156
+
157
+
158
+ ## Bundle Integration in MONAI Lable
159
+ The inference pipleine can be easily used by the MONAI Label server and 3D Slicer for fast labeling T1w MRI images in MNI space.
160
+
161
+ ![](./3DSlicer_use.png) <br>
162
+
163
+ # Disclaimer
164
+ This is an example, not to be used for diagnostic purposes.
165
+
166
+ # References
167
+ [1] Yu, Xin, Yinchi Zhou, Yucheng Tang et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf
168
+
169
+ [2] Zizhao Zhang et al. Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding. AAAI Conference on Artificial Intelligence (AAAI) 2022
170
+
171
+ [3] Huo, Yuankai, et al. 3D whole brain segmentation using spatially localized atlas network tiles. NeuroImage 194 (2019): 105-119.
configs/inference.json ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": ".",
7
+ "output_dir": "$@bundle_root + '/eval'",
8
+ "dataset_dir": "$@bundle_root + '/dataset/images'",
9
+ "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii.gz')))",
10
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
11
+ "network_def": {
12
+ "_target_": "scripts.networks.unest_base_patch_4.UNesT",
13
+ "in_channels": 1,
14
+ "out_channels": 133,
15
+ "patch_size": 4,
16
+ "depths": [
17
+ 2,
18
+ 2,
19
+ 8
20
+ ],
21
+ "embed_dim": [
22
+ 128,
23
+ 256,
24
+ 512
25
+ ],
26
+ "num_heads": [
27
+ 4,
28
+ 8,
29
+ 16
30
+ ]
31
+ },
32
+ "network": "$@network_def.to(@device)",
33
+ "preprocessing": {
34
+ "_target_": "Compose",
35
+ "transforms": [
36
+ {
37
+ "_target_": "LoadImaged",
38
+ "keys": "image"
39
+ },
40
+ {
41
+ "_target_": "EnsureChannelFirstd",
42
+ "keys": "image"
43
+ },
44
+ {
45
+ "_target_": "NormalizeIntensityd",
46
+ "keys": "image",
47
+ "nonzero": "True",
48
+ "channel_wise": "True"
49
+ },
50
+ {
51
+ "_target_": "EnsureTyped",
52
+ "keys": "image"
53
+ }
54
+ ]
55
+ },
56
+ "dataset": {
57
+ "_target_": "Dataset",
58
+ "data": "$[{'image': i} for i in @datalist]",
59
+ "transform": "@preprocessing"
60
+ },
61
+ "dataloader": {
62
+ "_target_": "DataLoader",
63
+ "dataset": "@dataset",
64
+ "batch_size": 1,
65
+ "shuffle": false,
66
+ "num_workers": 4
67
+ },
68
+ "inferer": {
69
+ "_target_": "SlidingWindowInferer",
70
+ "roi_size": [
71
+ 96,
72
+ 96,
73
+ 96
74
+ ],
75
+ "sw_batch_size": 4,
76
+ "overlap": 0.7
77
+ },
78
+ "postprocessing": {
79
+ "_target_": "Compose",
80
+ "transforms": [
81
+ {
82
+ "_target_": "Activationsd",
83
+ "keys": "pred",
84
+ "softmax": true
85
+ },
86
+ {
87
+ "_target_": "Invertd",
88
+ "keys": "pred",
89
+ "transform": "@preprocessing",
90
+ "orig_keys": "image",
91
+ "meta_key_postfix": "meta_dict",
92
+ "nearest_interp": false,
93
+ "to_tensor": true
94
+ },
95
+ {
96
+ "_target_": "AsDiscreted",
97
+ "keys": "pred",
98
+ "argmax": true
99
+ },
100
+ {
101
+ "_target_": "SaveImaged",
102
+ "keys": "pred",
103
+ "meta_keys": "pred_meta_dict",
104
+ "output_dir": "@output_dir"
105
+ }
106
+ ]
107
+ },
108
+ "handlers": [
109
+ {
110
+ "_target_": "CheckpointLoader",
111
+ "load_path": "$@bundle_root + '/models/model.pt'",
112
+ "load_dict": {
113
+ "model": "@network"
114
+ },
115
+ "strict": "True"
116
+ },
117
+ {
118
+ "_target_": "StatsHandler",
119
+ "iteration_log": false
120
+ }
121
+ ],
122
+ "evaluator": {
123
+ "_target_": "SupervisedEvaluator",
124
+ "device": "@device",
125
+ "val_data_loader": "@dataloader",
126
+ "network": "@network",
127
+ "inferer": "@inferer",
128
+ "postprocessing": "@postprocessing",
129
+ "val_handlers": "@handlers",
130
+ "amp": false
131
+ },
132
+ "evaluating": [
133
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
134
+ "$@evaluator.run()"
135
+ ]
136
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
+ "version": "0.1.0",
4
+ "changelog": {
5
+ "0.1.0": "complete the model package"
6
+ },
7
+ "monai_version": "0.9.1",
8
+ "pytorch_version": "1.10.0",
9
+ "numpy_version": "1.21.2",
10
+ "optional_packages_version": {
11
+ "nibabel": "3.2.1",
12
+ "pytorch-ignite": "0.4.8",
13
+ "einops": "0.4.1",
14
+ "fire": "0.4.0",
15
+ "timm": "0.6.7"
16
+ },
17
+ "task": "Whole Brain Segmentation",
18
+ "description": "A 3D transformer-based model for whole brain segmentation from T1W MRI image",
19
+ "authors": "Vanderbilt University + MONAI team",
20
+ "copyright": "Copyright (c) MONAI Consortium",
21
+ "data_source": "",
22
+ "data_type": "nibabel",
23
+ "image_classes": "single channel data, intensity scaled to [0, 1]",
24
+ "label_classes": "133 Classes",
25
+ "pred_classes": "133 Classes",
26
+ "eval_metrics": {
27
+ "mean_dice": 0.71
28
+ },
29
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
30
+ "references": [
31
+ "Xin, et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf"
32
+ ],
33
+ "network_data_format": {
34
+ "inputs": {
35
+ "image": {
36
+ "type": "image",
37
+ "format": "hounsfield",
38
+ "modality": "MRI",
39
+ "num_channels": 1,
40
+ "spatial_shape": [
41
+ 96,
42
+ 96,
43
+ 96
44
+ ],
45
+ "dtype": "float32",
46
+ "value_range": [
47
+ 0,
48
+ 1
49
+ ],
50
+ "is_patch_data": true,
51
+ "channel_def": {
52
+ "0": "image"
53
+ }
54
+ }
55
+ },
56
+ "outputs": {
57
+ "pred": {
58
+ "type": "image",
59
+ "format": "segmentation",
60
+ "num_channels": 133,
61
+ "spatial_shape": [
62
+ 96,
63
+ 96,
64
+ 96
65
+ ],
66
+ "dtype": "float32",
67
+ "value_range": [
68
+ 0,
69
+ 1
70
+ ],
71
+ "is_patch_data": true,
72
+ "channel_def": {
73
+ "0": "background",
74
+ "1": "3rd-Ventricle",
75
+ "2": "4th-Ventricle",
76
+ "3": "Right-Accumbens-Area",
77
+ "4": "Left-Accumbens-Area",
78
+ "5": "Right-Amygdala",
79
+ "6": "Left-Amygdala",
80
+ "7": "Brain-Stem",
81
+ "8": "Right-Caudate",
82
+ "9": "Left-Caudate",
83
+ "10": "Right-Cerebellum-Exterior",
84
+ "11": "Left-Cerebellum-Exterior",
85
+ "12": "Right-Cerebellum-White-Matter",
86
+ "13": "Left-Cerebellum-White-Matter",
87
+ "14": "Right-Cerebral-White-Matter",
88
+ "15": "Left-Cerebral-White-Matter",
89
+ "16": "Right-Hippocampus",
90
+ "17": "Left-Hippocampus",
91
+ "18": "Right-Inf-Lat-Vent",
92
+ "19": "Left-Inf-Lat-Vent",
93
+ "20": "Right-Lateral-Ventricle",
94
+ "21": "Left-Lateral-Ventricle",
95
+ "22": "Right-Pallidum",
96
+ "23": "Left-Pallidum",
97
+ "24": "Right-Putamen",
98
+ "25": "Left-Putamen",
99
+ "26": "Right-Thalamus-Proper",
100
+ "27": "Left-Thalamus-Proper",
101
+ "28": "Right-Ventral-DC",
102
+ "29": "Left-Ventral-DC",
103
+ "30": "Cerebellar-Vermal-Lobules-I-V",
104
+ "31": "Cerebellar-Vermal-Lobules-VI-VII",
105
+ "32": "Cerebellar-Vermal-Lobules-VIII-X",
106
+ "33": "Left-Basal-Forebrain",
107
+ "34": "Right-Basal-Forebrain",
108
+ "35": "Right-ACgG--anterior-cingulate-gyrus",
109
+ "36": "Left-ACgG--anterior-cingulate-gyrus",
110
+ "37": "Right-AIns--anterior-insula",
111
+ "38": "Left-AIns--anterior-insula",
112
+ "39": "Right-AOrG--anterior-orbital-gyrus",
113
+ "40": "Left-AOrG--anterior-orbital-gyrus",
114
+ "41": "Right-AnG---angular-gyrus",
115
+ "42": "Left-AnG---angular-gyrus",
116
+ "43": "Right-Calc--calcarine-cortex",
117
+ "44": "Left-Calc--calcarine-cortex",
118
+ "45": "Right-CO----central-operculum",
119
+ "46": "Left-CO----central-operculum",
120
+ "47": "Right-Cun---cuneus",
121
+ "48": "Left-Cun---cuneus",
122
+ "49": "Right-Ent---entorhinal-area",
123
+ "50": "Left-Ent---entorhinal-area",
124
+ "51": "Right-FO----frontal-operculum",
125
+ "52": "Left-FO----frontal-operculum",
126
+ "53": "Right-FRP---frontal-pole",
127
+ "54": "Left-FRP---frontal-pole",
128
+ "55": "Right-FuG---fusiform-gyrus ",
129
+ "56": "Left-FuG---fusiform-gyrus",
130
+ "57": "Right-GRe---gyrus-rectus",
131
+ "58": "Left-GRe---gyrus-rectus",
132
+ "59": "Right-IOG---inferior-occipital-gyrus",
133
+ "60": "Left-IOG---inferior-occipital-gyrus",
134
+ "61": "Right-ITG---inferior-temporal-gyrus",
135
+ "62": "Left-ITG---inferior-temporal-gyrus",
136
+ "63": "Right-LiG---lingual-gyrus",
137
+ "64": "Left-LiG---lingual-gyrus",
138
+ "65": "Right-LOrG--lateral-orbital-gyrus",
139
+ "66": "Left-LOrG--lateral-orbital-gyrus",
140
+ "67": "Right-MCgG--middle-cingulate-gyrus",
141
+ "68": "Left-MCgG--middle-cingulate-gyrus",
142
+ "69": "Right-MFC---medial-frontal-cortex",
143
+ "70": "Left-MFC---medial-frontal-cortex",
144
+ "71": "Right-MFG---middle-frontal-gyrus",
145
+ "72": "Left-MFG---middle-frontal-gyrus",
146
+ "73": "Right-MOG---middle-occipital-gyrus",
147
+ "74": "Left-MOG---middle-occipital-gyrus",
148
+ "75": "Right-MOrG--medial-orbital-gyrus",
149
+ "76": "Left-MOrG--medial-orbital-gyrus",
150
+ "77": "Right-MPoG--postcentral-gyrus",
151
+ "78": "Left-MPoG--postcentral-gyrus",
152
+ "79": "Right-MPrG--precentral-gyrus",
153
+ "80": "Left-MPrG--precentral-gyrus",
154
+ "81": "Right-MSFG--superior-frontal-gyrus",
155
+ "82": "Left-MSFG--superior-frontal-gyrus",
156
+ "83": "Right-MTG---middle-temporal-gyrus",
157
+ "84": "Left-MTG---middle-temporal-gyrus",
158
+ "85": "Right-OCP---occipital-pole",
159
+ "86": "Left-OCP---occipital-pole",
160
+ "87": "Right-OFuG--occipital-fusiform-gyrus",
161
+ "88": "Left-OFuG--occipital-fusiform-gyrus",
162
+ "89": "Right-OpIFG-opercular-part-of-the-IFG",
163
+ "90": "Left-OpIFG-opercular-part-of-the-IFG",
164
+ "91": "Right-OrIFG-orbital-part-of-the-IFG",
165
+ "92": "Left-OrIFG-orbital-part-of-the-IFG",
166
+ "93": "Right-PCgG--posterior-cingulate-gyrus",
167
+ "94": "Left-PCgG--posterior-cingulate-gyrus",
168
+ "95": "Right-PCu---precuneus",
169
+ "96": "Left-PCu---precuneus",
170
+ "97": "Right-PHG---parahippocampal-gyrus",
171
+ "98": "Left-PHG---parahippocampal-gyrus",
172
+ "99": "Right-PIns--posterior-insula",
173
+ "100": "Left-PIns--posterior-insula",
174
+ "101": "Right-PO----parietal-operculum",
175
+ "102": "Left-PO----parietal-operculum",
176
+ "103": "Right-PoG---postcentral-gyrus",
177
+ "104": "Left-PoG---postcentral-gyrus",
178
+ "105": "Right-POrG--posterior-orbital-gyrus",
179
+ "106": "Left-POrG--posterior-orbital-gyrus",
180
+ "107": "Right-PP----planum-polare",
181
+ "108": "Left-PP----planum-polare",
182
+ "109": "Right-PrG---precentral-gyrus",
183
+ "110": "Left-PrG---precentral-gyrus",
184
+ "111": "Right-PT----planum-temporale",
185
+ "112": "Left-PT----planum-temporale",
186
+ "113": "Right-SCA---subcallosal-area",
187
+ "114": "Left-SCA---subcallosal-area",
188
+ "115": "Right-SFG---superior-frontal-gyrus",
189
+ "116": "Left-SFG---superior-frontal-gyrus",
190
+ "117": "Right-SMC---supplementary-motor-cortex",
191
+ "118": "Left-SMC---supplementary-motor-cortex",
192
+ "119": "Right-SMG---supramarginal-gyrus",
193
+ "120": "Left-SMG---supramarginal-gyrus",
194
+ "121": "Right-SOG---superior-occipital-gyrus",
195
+ "122": "Left-SOG---superior-occipital-gyrus",
196
+ "123": "Right-SPL---superior-parietal-lobule",
197
+ "124": "Left-SPL---superior-parietal-lobule",
198
+ "125": "Right-STG---superior-temporal-gyrus",
199
+ "126": "Left-STG---superior-temporal-gyrus",
200
+ "127": "Right-TMP---temporal-pole",
201
+ "128": "Left-TMP---temporal-pole",
202
+ "129": "Right-TrIFG-triangular-part-of-the-IFG",
203
+ "130": "Left-TrIFG-triangular-part-of-the-IFG",
204
+ "131": "Right-TTG---transverse-temporal-gyrus",
205
+ "132": "Left-TTG---transverse-temporal-gyrus"
206
+ }
207
+ }
208
+ }
209
+ }
210
+ }
docs/3DSlicer_use.png ADDED
docs/README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description
2
+ Detailed whole brain segmentation is an essential quantitative technique in medical image analysis, which provides a non-invasive way of measuring brain regions from a clinical acquired structural magnetic resonance imaging (MRI).
3
+ We provide the pre-trained model for inferencing whole brain segmentation with 133 structures.
4
+
5
+ A tutorial and release of model for whole brain segmentation using the 3D transformer-based segmentation model UNEST.
6
+
7
+ Authors:
8
+ Xin Yu (xin.yu@vanderbilt.edu) (Primary)
9
+
10
+ Yinchi Zhou (yinchi.zhou@vanderbilt.edu) | Yucheng Tang (yuchengt@nvidia.com)
11
+
12
+ <p align="center">
13
+ -------------------------------------------------------------------------------------
14
+ </p>
15
+
16
+ ![](./demo.png) <br>
17
+ <p align="center">
18
+ Fig.1 - The demonstration of T1w MRI images registered in MNI space and the whole brain segmentation labels with 133 classes</p>
19
+
20
+
21
+
22
+ # Model Overview
23
+ A pre-trained larger UNEST base model [1] for volumetric (3D) whole brain segmentation with T1w MR images.
24
+ To leverage information across embedded sequences, ”shifted window” transformers
25
+ are proposed for dense predictions and modeling multi-scale features. However, these
26
+ attempts that aim to complicate the self-attention range often yield high computation
27
+ complexity and data inefficiency. Inspired by the aggregation function in the nested
28
+ ViT, we propose a new design of a 3D U-shape medical segmentation model with
29
+ Nested Transformers (UNesT) hierarchically with the 3D block aggregation function,
30
+ that learn locality behaviors for small structures or small dataset. This design retains
31
+ the original global self-attention mechanism and achieves information communication
32
+ across patches by stacking transformer encoders hierarchically.
33
+
34
+ ![](./unest.png) <br>
35
+ <p align="center">
36
+ Fig.2 - The network architecture of UNEST Base model
37
+ </p>
38
+
39
+
40
+ ## Data
41
+ The training data is from the Vanderbilt University and Vanderbilt University Medical Center with public released OASIS and CANDI datsets.
42
+ Training and testing data are MRI T1-weighted (T1w) 3D volumes coming from 3 different sites. There are a total of 133 classes in the whole brain segmentation task.
43
+ Among 50 T1w MRI scans from Open Access Series on Imaging Studies (OASIS) (Marcus et al., 2007) dataset, 45 scans are used for training and the other 5 for validation.
44
+ The testing cohort contains Colin27 T1w scan (Aubert-Broche et al., 2006) and 13 T1w MRI scans from the Child and Adolescent Neuro Development Initiative (CANDI)
45
+ (Kennedy et al., 2012). All data are registered to the MNI space using the MNI305 (Evans et al., 1993) template and preprocessed follow the method in (Huo et al., 2019). Input images are randomly cropped to the size of 96 × 96 × 96.
46
+
47
+ ### Important
48
+
49
+ ```diff
50
+ + All the brain MRI images for training are registered to Affine registration from the target image to the MNI305 template using NiftyReg.
51
+ + The data should be in the MNI305 space before inference.
52
+
53
+
54
+ Registration to MNI Space: Sample suggestion. E.g., use ANTS or other tools for registering T1 MRI image to MNI305 Space.
55
+
56
+
57
+ ```
58
+ pip install antspyx
59
+ ```
60
+ Sample ANTS registration
61
+ ```
62
+
63
+ import ants
64
+ import sys
65
+ import os
66
+
67
+ fixed_image = ants.image_read('<fixed_image_path>')
68
+ moving_image = ants.image_read('<moving_image_path>')
69
+ transform = ants.registration(fixed_image,moving_image,'Affine')
70
+
71
+ reg3t = ants.apply_transforms(fixed_image,moving_image,transform['fwdtransforms'][0])
72
+ ants.image_write(reg3t,output_image_path)
73
+
74
+ ```
75
+ ## Training configuration
76
+ The training and inference was performed with at least one 24GB-memory GPU.
77
+
78
+ Actual Model Input: 96 x 96 x 96
79
+
80
+ ## Input and output formats
81
+ Input: 1 channel T1w MRI image in MNI305 Space.
82
+
83
+
84
+ ## commands example
85
+ Download trained checkpoint model to ./model/model.pt:
86
+
87
+
88
+ Add scripts component: To run the workflow with customized components, PYTHONPATH should be revised to include the path to the customized component:
89
+
90
+ ```
91
+ export PYTHONPATH=$PYTHONPATH: '<path to the bundle root dir>/scripts'
92
+
93
+ ```
94
+
95
+
96
+ Execute inference:
97
+
98
+ ```
99
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
100
+ ```
101
+
102
+
103
+ ## More examples output
104
+ ![](./wholebrain.png) <br>
105
+ <p align="center">
106
+ Fig.3 - The output prediction comparison with variant and ground truth
107
+ </p>
108
+
109
+
110
+ ## Complete ROI of the whole brain segmentation
111
+ 133 brain structures are segmented.
112
+
113
+ | #1 | #2 | #3 | #4 |
114
+ | :------------ | :---------- | :-------- | :-------- |
115
+ | 0: background | 1 : 3rd-Ventricle | 2 : 4th-Ventricle | 3 : Right-Accumbens-Area |
116
+ | 4 : Left-Accumbens-Area | 5 : Right-Amygdala | 6 : Left-Amygdala | 7 : Brain-Stem |
117
+ | 8 : Right-Caudate | 9 : Left-Caudate | 10 : Right-Cerebellum-Exterior | 11 : Left-Cerebellum-Exterior |
118
+ | 12 : Right-Cerebellum-White-Matter | 13 : Left-Cerebellum-White-Matter | 14 : Right-Cerebral-White-Matter | 15 : Left-Cerebral-White-Matter |
119
+ | 16 : Right-Hippocampus | 17 : Left-Hippocampus | 18 : Right-Inf-Lat-Vent | 19 : Left-Inf-Lat-Vent |
120
+ | 20 : Right-Lateral-Ventricle | 21 : Left-Lateral-Ventricle | 22 : Right-Pallidum | 23 : Left-Pallidum |
121
+ | 24 : Right-Putamen | 25 : Left-Putamen | 26 : Right-Thalamus-Proper | 27 : Left-Thalamus-Proper |
122
+ | 28 : Right-Ventral-DC | 29 : Left-Ventral-DC | 30 : Cerebellar-Vermal-Lobules-I-V | 31 : Cerebellar-Vermal-Lobules-VI-VII |
123
+ | 32 : Cerebellar-Vermal-Lobules-VIII-X | 33 : Left-Basal-Forebrain | 34 : Right-Basal-Forebrain | 35 : Right-ACgG--anterior-cingulate-gyrus |
124
+ | 36 : Left-ACgG--anterior-cingulate-gyrus | 37 : Right-AIns--anterior-insula | 38 : Left-AIns--anterior-insula | 39 : Right-AOrG--anterior-orbital-gyrus |
125
+ | 40 : Left-AOrG--anterior-orbital-gyrus | 41 : Right-AnG---angular-gyrus | 42 : Left-AnG---angular-gyrus | 43 : Right-Calc--calcarine-cortex |
126
+ | 44 : Left-Calc--calcarine-cortex | 45 : Right-CO----central-operculum | 46 : Left-CO----central-operculum | 47 : Right-Cun---cuneus |
127
+ | 48 : Left-Cun---cuneus | 49 : Right-Ent---entorhinal-area | 50 : Left-Ent---entorhinal-area | 51 : Right-FO----frontal-operculum |
128
+ | 52 : Left-FO----frontal-operculum | 53 : Right-FRP---frontal-pole | 54 : Left-FRP---frontal-pole | 55 : Right-FuG---fusiform-gyrus |
129
+ | 56 : Left-FuG---fusiform-gyrus | 57 : Right-GRe---gyrus-rectus | 58 : Left-GRe---gyrus-rectus | 59 : Right-IOG---inferior-occipital-gyrus ,
130
+ | 60 : Left-IOG---inferior-occipital-gyrus | 61 : Right-ITG---inferior-temporal-gyrus | 62 : Left-ITG---inferior-temporal-gyrus | 63 : Right-LiG---lingual-gyrus |
131
+ | 64 : Left-LiG---lingual-gyrus | 65 : Right-LOrG--lateral-orbital-gyrus | 66 : Left-LOrG--lateral-orbital-gyrus | 67 : Right-MCgG--middle-cingulate-gyrus |
132
+ | 68 : Left-MCgG--middle-cingulate-gyrus | 69 : Right-MFC---medial-frontal-cortex | 70 : Left-MFC---medial-frontal-cortex | 71 : Right-MFG---middle-frontal-gyrus |
133
+ | 72 : Left-MFG---middle-frontal-gyrus | 73 : Right-MOG---middle-occipital-gyrus | 74 : Left-MOG---middle-occipital-gyrus | 75 : Right-MOrG--medial-orbital-gyrus |
134
+ | 76 : Left-MOrG--medial-orbital-gyrus | 77 : Right-MPoG--postcentral-gyrus | 78 : Left-MPoG--postcentral-gyrus | 79 : Right-MPrG--precentral-gyrus |
135
+ | 80 : Left-MPrG--precentral-gyrus | 81 : Right-MSFG--superior-frontal-gyrus | 82 : Left-MSFG--superior-frontal-gyrus | 83 : Right-MTG---middle-temporal-gyrus |
136
+ | 84 : Left-MTG---middle-temporal-gyrus | 85 : Right-OCP---occipital-pole | 86 : Left-OCP---occipital-pole | 87 : Right-OFuG--occipital-fusiform-gyrus |
137
+ | 88 : Left-OFuG--occipital-fusiform-gyrus | 89 : Right-OpIFG-opercular-part-of-the-IFG | 90 : Left-OpIFG-opercular-part-of-the-IFG | 91 : Right-OrIFG-orbital-part-of-the-IFG |
138
+ | 92 : Left-OrIFG-orbital-part-of-the-IFG | 93 : Right-PCgG--posterior-cingulate-gyrus | 94 : Left-PCgG--posterior-cingulate-gyrus | 95 : Right-PCu---precuneus |
139
+ | 96 : Left-PCu---precuneus | 97 : Right-PHG---parahippocampal-gyrus | 98 : Left-PHG---parahippocampal-gyrus | 99 : Right-PIns--posterior-insula |
140
+ | 100 : Left-PIns--posterior-insula | 101 : Right-PO----parietal-operculum | 102 : Left-PO----parietal-operculum | 103 : Right-PoG---postcentral-gyrus |
141
+ | 104 : Left-PoG---postcentral-gyrus | 105 : Right-POrG--posterior-orbital-gyrus | 106 : Left-POrG--posterior-orbital-gyrus | 107 : Right-PP----planum-polare |
142
+ | 108 : Left-PP----planum-polare | 109 : Right-PrG---precentral-gyrus | 110 : Left-PrG---precentral-gyrus | 111 : Right-PT----planum-temporale |
143
+ | 112 : Left-PT----planum-temporale | 113 : Right-SCA---subcallosal-area | 114 : Left-SCA---subcallosal-area | 115 : Right-SFG---superior-frontal-gyrus |
144
+ | 116 : Left-SFG---superior-frontal-gyrus | 117 : Right-SMC---supplementary-motor-cortex | 118 : Left-SMC---supplementary-motor-cortex | 119 : Right-SMG---supramarginal-gyrus |
145
+ | 120 : Left-SMG---supramarginal-gyrus | 121 : Right-SOG---superior-occipital-gyrus | 122 : Left-SOG---superior-occipital-gyrus | 123 : Right-SPL---superior-parietal-lobule |
146
+ | 124 : Left-SPL---superior-parietal-lobule | 125 : Right-STG---superior-temporal-gyrus | 126 : Left-STG---superior-temporal-gyrus | 127 : Right-TMP---temporal-pole |
147
+ | 128 : Left-TMP---temporal-pole | 129 : Right-TrIFG-triangular-part-of-the-IFG | 130 : Left-TrIFG-triangular-part-of-the-IFG | 131 : Right-TTG---transverse-temporal-gyrus |
148
+ | 132 : Left-TTG---transverse-temporal-gyrus |
149
+
150
+
151
+ ## Bundle Integration in MONAI Lable
152
+ The inference pipleine can be easily used by the MONAI Label server and 3D Slicer for fast labeling T1w MRI images in MNI space.
153
+
154
+ ![](./3DSlicer_use.png) <br>
155
+
156
+ # Disclaimer
157
+ This is an example, not to be used for diagnostic purposes.
158
+
159
+ # References
160
+ [1] Yu, Xin, Yinchi Zhou, Yucheng Tang et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf
161
+
162
+ [2] Zizhao Zhang et al. Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding. AAAI Conference on Artificial Intelligence (AAAI) 2022
163
+
164
+ [3] Huo, Yuankai, et al. 3D whole brain segmentation using spatially localized atlas network tiles. NeuroImage 194 (2019): 105-119.
docs/demo.png ADDED

Git LFS Details

  • SHA256: 92aae2d9b2901de18b445d6e6efdf48b6c3d8bb5e66ee55c3fde152e13f952f7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
docs/license.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
docs/unest.png ADDED
docs/wholebrain.png ADDED
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79a52ccd77bc35d05410f39788a1b063af3eb3b809b42241335c18aed27ec422
3
+ size 348901503
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/nest/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from .utils import (
3
+ Conv3dSame,
4
+ DropPath,
5
+ Linear,
6
+ Mlp,
7
+ _assert,
8
+ conv3d_same,
9
+ create_conv3d,
10
+ create_pool3d,
11
+ get_padding,
12
+ get_same_padding,
13
+ pad_same,
14
+ to_ntuple,
15
+ trunc_normal_,
16
+ )
scripts/networks/nest/utils.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import collections.abc
5
+ import math
6
+ import warnings
7
+ from itertools import repeat
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from torch import _assert
16
+ except ImportError:
17
+
18
+ def _assert(condition: bool, message: str):
19
+ assert condition, message
20
+
21
+
22
+ def drop_block_2d(
23
+ x,
24
+ drop_prob: float = 0.1,
25
+ block_size: int = 7,
26
+ gamma_scale: float = 1.0,
27
+ with_noise: bool = False,
28
+ inplace: bool = False,
29
+ batchwise: bool = False,
30
+ ):
31
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
32
+
33
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
34
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
35
+ """
36
+ b, c, h, w = x.shape
37
+ total_size = w * h
38
+ clipped_block_size = min(block_size, min(w, h))
39
+ # seed_drop_rate, the gamma parameter
40
+ gamma = (
41
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
42
+ )
43
+
44
+ # Forces the block to be inside the feature map.
45
+ w_i, h_i = torch.meshgrid(torch.arange(w).to(x.device), torch.arange(h).to(x.device))
46
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < w - (clipped_block_size - 1) // 2)) & (
47
+ (h_i >= clipped_block_size // 2) & (h_i < h - (clipped_block_size - 1) // 2)
48
+ )
49
+ valid_block = torch.reshape(valid_block, (1, 1, h, w)).to(dtype=x.dtype)
50
+
51
+ if batchwise:
52
+ # one mask for whole batch, quite a bit faster
53
+ uniform_noise = torch.rand((1, c, h, w), dtype=x.dtype, device=x.device)
54
+ else:
55
+ uniform_noise = torch.rand_like(x)
56
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
57
+ block_mask = -F.max_pool2d(
58
+ -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
59
+ )
60
+
61
+ if with_noise:
62
+ normal_noise = torch.randn((1, c, h, w), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
63
+ if inplace:
64
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
65
+ else:
66
+ x = x * block_mask + normal_noise * (1 - block_mask)
67
+ else:
68
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
69
+ if inplace:
70
+ x.mul_(block_mask * normalize_scale)
71
+ else:
72
+ x = x * block_mask * normalize_scale
73
+ return x
74
+
75
+
76
+ def drop_block_fast_2d(
77
+ x: torch.Tensor,
78
+ drop_prob: float = 0.1,
79
+ block_size: int = 7,
80
+ gamma_scale: float = 1.0,
81
+ with_noise: bool = False,
82
+ inplace: bool = False,
83
+ ):
84
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
85
+
86
+ DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
87
+ block mask at edges.
88
+ """
89
+ b, c, h, w = x.shape
90
+ total_size = w * h
91
+ clipped_block_size = min(block_size, min(w, h))
92
+ gamma = (
93
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
94
+ )
95
+
96
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
97
+ block_mask = F.max_pool2d(
98
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
99
+ )
100
+
101
+ if with_noise:
102
+ normal_noise = torch.empty_like(x).normal_()
103
+ if inplace:
104
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
105
+ else:
106
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
107
+ else:
108
+ block_mask = 1 - block_mask
109
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
110
+ if inplace:
111
+ x.mul_(block_mask * normalize_scale)
112
+ else:
113
+ x = x * block_mask * normalize_scale
114
+ return x
115
+
116
+
117
+ class DropBlock2d(nn.Module):
118
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
119
+
120
+ def __init__(
121
+ self, drop_prob=0.1, block_size=7, gamma_scale=1.0, with_noise=False, inplace=False, batchwise=False, fast=True
122
+ ):
123
+ super(DropBlock2d, self).__init__()
124
+ self.drop_prob = drop_prob
125
+ self.gamma_scale = gamma_scale
126
+ self.block_size = block_size
127
+ self.with_noise = with_noise
128
+ self.inplace = inplace
129
+ self.batchwise = batchwise
130
+ self.fast = fast # FIXME finish comparisons of fast vs not
131
+
132
+ def forward(self, x):
133
+ if not self.training or not self.drop_prob:
134
+ return x
135
+ if self.fast:
136
+ return drop_block_fast_2d(
137
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
138
+ )
139
+ else:
140
+ return drop_block_2d(
141
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
142
+ )
143
+
144
+
145
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
146
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
147
+
148
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
149
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
150
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
151
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
152
+ 'survival rate' as the argument.
153
+
154
+ """
155
+ if drop_prob == 0.0 or not training:
156
+ return x
157
+ keep_prob = 1 - drop_prob
158
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
159
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
160
+ if keep_prob > 0.0 and scale_by_keep:
161
+ random_tensor.div_(keep_prob)
162
+ return x * random_tensor
163
+
164
+
165
+ class DropPath(nn.Module):
166
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
167
+
168
+ def __init__(self, drop_prob=None, scale_by_keep=True):
169
+ super(DropPath, self).__init__()
170
+ self.drop_prob = drop_prob
171
+ self.scale_by_keep = scale_by_keep
172
+
173
+ def forward(self, x):
174
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
175
+
176
+
177
+ def create_conv3d(in_channels, out_channels, kernel_size, **kwargs):
178
+ """Select a 2d convolution implementation based on arguments
179
+ Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv3d, or CondConv2d.
180
+
181
+ Used extensively by EfficientNet, MobileNetv3 and related networks.
182
+ """
183
+
184
+ depthwise = kwargs.pop("depthwise", False)
185
+ # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
186
+ groups = in_channels if depthwise else kwargs.pop("groups", 1)
187
+
188
+ m = create_conv3d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
189
+ return m
190
+
191
+
192
+ def conv3d_same(
193
+ x,
194
+ weight: torch.Tensor,
195
+ bias: Optional[torch.Tensor] = None,
196
+ stride: Tuple[int, int] = (1, 1, 1),
197
+ padding: Tuple[int, int] = (0, 0, 0),
198
+ dilation: Tuple[int, int] = (1, 1, 1),
199
+ groups: int = 1,
200
+ ):
201
+ x = pad_same(x, weight.shape[-3:], stride, dilation)
202
+ return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)
203
+
204
+
205
+ class Conv3dSame(nn.Conv2d):
206
+ """Tensorflow like 'SAME' convolution wrapper for 2D convolutions"""
207
+
208
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
209
+ super(Conv3dSame, self).__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
210
+
211
+ def forward(self, x):
212
+ return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
213
+
214
+
215
+ def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
216
+ padding = kwargs.pop("padding", "")
217
+ kwargs.setdefault("bias", False)
218
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
219
+ if is_dynamic:
220
+ return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
221
+ else:
222
+ return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
223
+
224
+
225
+ # Calculate symmetric padding for a convolution
226
+ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
227
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
228
+ return padding
229
+
230
+
231
+ # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
232
+ def get_same_padding(x: int, k: int, s: int, d: int):
233
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
234
+
235
+
236
+ # Can SAME padding for given args be done statically?
237
+ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
238
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
239
+
240
+
241
+ # Dynamically pad input x with 'SAME' padding for conv with specified args
242
+ def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
243
+ id, ih, iw = x.size()[-3:]
244
+ pad_d, pad_h, pad_w = (
245
+ get_same_padding(id, k[0], s[0], d[0]),
246
+ get_same_padding(ih, k[1], s[1], d[1]),
247
+ get_same_padding(iw, k[2], s[2], d[2]),
248
+ )
249
+ if pad_d > 0 or pad_h > 0 or pad_w > 0:
250
+ x = F.pad(
251
+ x,
252
+ [pad_d // 2, pad_d - pad_d // 2, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
253
+ value=value,
254
+ )
255
+ return x
256
+
257
+
258
+ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
259
+ dynamic = False
260
+ if isinstance(padding, str):
261
+ # for any string padding, the padding will be calculated for you, one of three ways
262
+ padding = padding.lower()
263
+ if padding == "same":
264
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
265
+ if is_static_pad(kernel_size, **kwargs):
266
+ # static case, no extra overhead
267
+ padding = get_padding(kernel_size, **kwargs)
268
+ else:
269
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
270
+ padding = 0
271
+ dynamic = True
272
+ elif padding == "valid":
273
+ # 'VALID' padding, same as padding=0
274
+ padding = 0
275
+ else:
276
+ # Default to PyTorch style 'same'-ish symmetric padding
277
+ padding = get_padding(kernel_size, **kwargs)
278
+ return padding, dynamic
279
+
280
+
281
+ # From PyTorch internals
282
+ def _ntuple(n):
283
+ def parse(x):
284
+ if isinstance(x, collections.abc.Iterable):
285
+ return x
286
+ return tuple(repeat(x, n))
287
+
288
+ return parse
289
+
290
+
291
+ to_1tuple = _ntuple(1)
292
+ to_2tuple = _ntuple(2)
293
+ to_3tuple = _ntuple(3)
294
+ to_4tuple = _ntuple(4)
295
+ to_ntuple = _ntuple
296
+
297
+
298
+ def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
299
+ min_value = min_value or divisor
300
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
301
+ # Make sure that round down does not go down by more than 10%.
302
+ if new_v < round_limit * v:
303
+ new_v += divisor
304
+ return new_v
305
+
306
+
307
+ class Linear(nn.Linear):
308
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
309
+
310
+ Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
311
+ weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
312
+ """
313
+
314
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
315
+ if torch.jit.is_scripting():
316
+ bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
317
+ return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
318
+ else:
319
+ return F.linear(input, self.weight, self.bias)
320
+
321
+
322
+ class Mlp(nn.Module):
323
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
324
+
325
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
326
+ super().__init__()
327
+ out_features = out_features or in_features
328
+ hidden_features = hidden_features or in_features
329
+ drop_probs = to_2tuple(drop)
330
+
331
+ self.fc1 = nn.Linear(in_features, hidden_features)
332
+ self.act = act_layer()
333
+ self.drop1 = nn.Dropout(drop_probs[0])
334
+ self.fc2 = nn.Linear(hidden_features, out_features)
335
+ self.drop2 = nn.Dropout(drop_probs[1])
336
+
337
+ def forward(self, x):
338
+ x = self.fc1(x)
339
+ x = self.act(x)
340
+ x = self.drop1(x)
341
+ x = self.fc2(x)
342
+ x = self.drop2(x)
343
+ return x
344
+
345
+
346
+ def avg_pool3d_same(
347
+ x,
348
+ kernel_size: List[int],
349
+ stride: List[int],
350
+ padding: List[int] = (0, 0, 0),
351
+ ceil_mode: bool = False,
352
+ count_include_pad: bool = True,
353
+ ):
354
+ # FIXME how to deal with count_include_pad vs not for external padding?
355
+ x = pad_same(x, kernel_size, stride)
356
+ return F.avg_pool3d(x, kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
357
+
358
+
359
+ class AvgPool3dSame(nn.AvgPool2d):
360
+ """Tensorflow like 'SAME' wrapper for 2D average pooling"""
361
+
362
+ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
363
+ kernel_size = to_2tuple(kernel_size)
364
+ stride = to_2tuple(stride)
365
+ super(AvgPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
366
+
367
+ def forward(self, x):
368
+ x = pad_same(x, self.kernel_size, self.stride)
369
+ return F.avg_pool3d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
370
+
371
+
372
+ def max_pool3d_same(
373
+ x,
374
+ kernel_size: List[int],
375
+ stride: List[int],
376
+ padding: List[int] = (0, 0, 0),
377
+ dilation: List[int] = (1, 1, 1),
378
+ ceil_mode: bool = False,
379
+ ):
380
+ x = pad_same(x, kernel_size, stride, value=-float("inf"))
381
+ return F.max_pool3d(x, kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
382
+
383
+
384
+ class MaxPool3dSame(nn.MaxPool2d):
385
+ """Tensorflow like 'SAME' wrapper for 3D max pooling"""
386
+
387
+ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
388
+ kernel_size = to_2tuple(kernel_size)
389
+ stride = to_2tuple(stride)
390
+ dilation = to_2tuple(dilation)
391
+ super(MaxPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
392
+
393
+ def forward(self, x):
394
+ x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
395
+ return F.max_pool3d(x, self.kernel_size, self.stride, (0, 0, 0), self.dilation, self.ceil_mode)
396
+
397
+
398
+ def create_pool3d(pool_type, kernel_size, stride=None, **kwargs):
399
+ stride = stride or kernel_size
400
+ padding = kwargs.pop("padding", "")
401
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
402
+ if is_dynamic:
403
+ if pool_type == "avg":
404
+ return AvgPool3dSame(kernel_size, stride=stride, **kwargs)
405
+ elif pool_type == "max":
406
+ return MaxPool3dSame(kernel_size, stride=stride, **kwargs)
407
+ else:
408
+ raise AssertionError()
409
+
410
+ # assert False, f"Unsupported pool type {pool_type}"
411
+ else:
412
+ if pool_type == "avg":
413
+ return nn.AvgPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
414
+ elif pool_type == "max":
415
+ return nn.MaxPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
416
+ else:
417
+ raise AssertionError()
418
+
419
+ # assert False, f"Unsupported pool type {pool_type}"
420
+
421
+
422
+ def _float_to_int(x: float) -> int:
423
+ """
424
+ Symbolic tracing helper to substitute for inbuilt `int`.
425
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy`
426
+ """
427
+ return int(x)
428
+
429
+
430
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
431
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
432
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
433
+ def norm_cdf(x):
434
+ # Computes standard normal cumulative distribution function
435
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
436
+
437
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
438
+ warnings.warn(
439
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
440
+ "The distribution of values may be incorrect.",
441
+ stacklevel=2,
442
+ )
443
+
444
+ with torch.no_grad():
445
+ # Values are generated by using a truncated uniform distribution and
446
+ # then using the inverse CDF for the normal distribution.
447
+ # Get upper and lower cdf values
448
+ l = norm_cdf((a - mean) / std)
449
+ u = norm_cdf((b - mean) / std)
450
+
451
+ # Uniformly fill tensor with values from [l, u], then translate to
452
+ # [2l-1, 2u-1].
453
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
454
+
455
+ # Use inverse cdf transform for normal distribution to get truncated
456
+ # standard normal
457
+ tensor.erfinv_()
458
+
459
+ # Transform to proper mean, std
460
+ tensor.mul_(std * math.sqrt(2.0))
461
+ tensor.add_(mean)
462
+
463
+ # Clamp to ensure it's in the proper range
464
+ tensor.clamp_(min=a, max=b)
465
+ return tensor
466
+
467
+
468
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
469
+ r"""Fills the input Tensor with values drawn from a truncated
470
+ normal distribution. The values are effectively drawn from the
471
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
472
+ with values outside :math:`[a, b]` redrawn until they are within
473
+ the bounds. The method used for generating the random values works
474
+ best when :math:`a \leq \text{mean} \leq b`.
475
+ Args:
476
+ tensor: an n-dimensional `torch.Tensor`
477
+ mean: the mean of the normal distribution
478
+ std: the standard deviation of the normal distribution
479
+ a: the minimum cutoff value
480
+ b: the maximum cutoff value
481
+ Examples:
482
+ >>> w = torch.empty(3, 5)
483
+ >>> nn.init.trunc_normal_(w)
484
+ """
485
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
scripts/networks/nest_transformer_3D.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # =========================================================================
4
+ # Adapted from https://github.com/google-research/nested-transformer.
5
+ # which has the following license...
6
+ # https://github.com/pytorch/vision/blob/main/LICENSE
7
+ #
8
+ # BSD 3-Clause License
9
+
10
+
11
+ # Redistribution and use in source and binary forms, with or without
12
+ # modification, are permitted provided that the following conditions are met:
13
+
14
+ # * Redistributions of source code must retain the above copyright notice, this
15
+ # list of conditions and the following disclaimer.
16
+
17
+ # * Redistributions in binary form must reproduce the above copyright notice,
18
+ # this list of conditions and the following disclaimer in the documentation
19
+ # and/or other materials provided with the distribution.
20
+
21
+ # * Neither the name of the copyright holder nor the names of its
22
+ # contributors may be used to endorse or promote products derived from
23
+ # this software without specific prior written permission.
24
+
25
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+
36
+ """ Nested Transformer (NesT) in PyTorch
37
+ A PyTorch implement of Aggregating Nested Transformers as described in:
38
+ 'Aggregating Nested Transformers'
39
+ - https://arxiv.org/abs/2105.12723
40
+ The official Jax code is released and available at https://github.com/google-research/nested-transformer.
41
+ The weights have been converted with convert/convert_nest_flax.py
42
+ Acknowledgments:
43
+ * The paper authors for sharing their research, code, and model weights
44
+ * Ross Wightman's existing code off which I based this
45
+ Copyright 2021 Alexander Soare
46
+
47
+ """
48
+
49
+ import collections.abc
50
+ import logging
51
+ import math
52
+ from functools import partial
53
+ from typing import Callable, Sequence
54
+
55
+ import torch
56
+ import torch.nn.functional as F
57
+ from torch import nn
58
+
59
+ from .nest import DropPath, Mlp, _assert, create_conv3d, create_pool3d, to_ntuple, trunc_normal_
60
+ from .patchEmbed3D import PatchEmbed3D
61
+
62
+ _logger = logging.getLogger(__name__)
63
+
64
+
65
+ class Attention(nn.Module):
66
+ """
67
+ This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
68
+ an extra "image block" dim
69
+ """
70
+
71
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
72
+ super().__init__()
73
+ self.num_heads = num_heads
74
+ head_dim = dim // num_heads
75
+ self.scale = head_dim**-0.5
76
+
77
+ self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+
82
+ def forward(self, x):
83
+ """
84
+ x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
85
+ """
86
+ b, t, n, c = x.shape
87
+ # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
88
+ qkv = self.qkv(x).reshape(b, t, n, 3, self.num_heads, c // self.num_heads).permute(3, 0, 4, 1, 2, 5)
89
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
90
+
91
+ attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
92
+ attn = attn.softmax(dim=-1)
93
+ attn = self.attn_drop(attn)
94
+
95
+ x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(b, t, n, c)
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x # (B, T, N, C)
99
+
100
+
101
+ class TransformerLayer(nn.Module):
102
+ """
103
+ This is much like `.vision_transformer.Block` but:
104
+ - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
105
+ - Uses modified Attention layer that handles the "block" dimension
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim,
111
+ num_heads,
112
+ mlp_ratio=4.0,
113
+ qkv_bias=False,
114
+ drop=0.0,
115
+ attn_drop=0.0,
116
+ drop_path=0.0,
117
+ act_layer=nn.GELU,
118
+ norm_layer=nn.LayerNorm,
119
+ ):
120
+ super().__init__()
121
+ self.norm1 = norm_layer(dim)
122
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
123
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
124
+ self.norm2 = norm_layer(dim)
125
+ mlp_hidden_dim = int(dim * mlp_ratio)
126
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
127
+
128
+ def forward(self, x):
129
+ y = self.norm1(x)
130
+ x = x + self.drop_path(self.attn(y))
131
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
132
+ return x
133
+
134
+
135
+ class ConvPool(nn.Module):
136
+ def __init__(self, in_channels, out_channels, norm_layer, pad_type=""):
137
+ super().__init__()
138
+ self.conv = create_conv3d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
139
+ self.norm = norm_layer(out_channels)
140
+ self.pool = create_pool3d("max", kernel_size=3, stride=2, padding=pad_type)
141
+
142
+ def forward(self, x):
143
+ """
144
+ x is expected to have shape (B, C, D, H, W)
145
+ """
146
+ _assert(x.shape[-3] % 2 == 0, "BlockAggregation requires even input spatial dims")
147
+ _assert(x.shape[-2] % 2 == 0, "BlockAggregation requires even input spatial dims")
148
+ _assert(x.shape[-1] % 2 == 0, "BlockAggregation requires even input spatial dims")
149
+
150
+ # print('In ConvPool x : {}'.format(x.shape))
151
+ x = self.conv(x)
152
+ # Layer norm done over channel dim only
153
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
154
+ x = self.pool(x)
155
+ return x # (B, C, D//2, H//2, W//2)
156
+
157
+
158
+ def blockify(x, block_size: int):
159
+ """image to blocks
160
+ Args:
161
+ x (Tensor): with shape (B, D, H, W, C)
162
+ block_size (int): edge length of a single square block in units of D, H, W
163
+ """
164
+ b, d, h, w, c = x.shape
165
+ _assert(d % block_size == 0, "`block_size` must divide input depth evenly")
166
+ _assert(h % block_size == 0, "`block_size` must divide input height evenly")
167
+ _assert(w % block_size == 0, "`block_size` must divide input width evenly")
168
+ grid_depth = d // block_size
169
+ grid_height = h // block_size
170
+ grid_width = w // block_size
171
+ x = x.reshape(b, grid_depth, block_size, grid_height, block_size, grid_width, block_size, c)
172
+
173
+ x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
174
+ b, grid_depth * grid_height * grid_width, -1, c
175
+ ) # shape [2, 512, 27, 128]
176
+
177
+ return x # (B, T, N, C)
178
+
179
+
180
+ # @register_notrace_function # reason: int receives Proxy
181
+ def deblockify(x, block_size: int):
182
+ """blocks to image
183
+ Args:
184
+ x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
185
+ block_size (int): edge length of a single square block in units of desired D, H, W
186
+ """
187
+ b, t, _, c = x.shape
188
+ grid_size = round(math.pow(t, 1 / 3))
189
+ depth = height = width = grid_size * block_size
190
+ x = x.reshape(b, grid_size, grid_size, grid_size, block_size, block_size, block_size, c)
191
+
192
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, depth, height, width, c)
193
+
194
+ return x # (B, D, H, W, C)
195
+
196
+
197
+ class NestLevel(nn.Module):
198
+ """Single hierarchical level of a Nested Transformer"""
199
+
200
+ def __init__(
201
+ self,
202
+ num_blocks,
203
+ block_size,
204
+ seq_length,
205
+ num_heads,
206
+ depth,
207
+ embed_dim,
208
+ prev_embed_dim=None,
209
+ mlp_ratio=4.0,
210
+ qkv_bias=True,
211
+ drop_rate=0.0,
212
+ attn_drop_rate=0.0,
213
+ drop_path_rates: Sequence[int] = (),
214
+ norm_layer=None,
215
+ act_layer=None,
216
+ pad_type="",
217
+ ):
218
+ super().__init__()
219
+ self.block_size = block_size
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
221
+
222
+ if prev_embed_dim is not None:
223
+ self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
224
+ else:
225
+ self.pool = nn.Identity()
226
+
227
+ # Transformer encoder
228
+ if len(drop_path_rates):
229
+ assert len(drop_path_rates) == depth, "Must provide as many drop path rates as there are transformer layers"
230
+ self.transformer_encoder = nn.Sequential(
231
+ *[
232
+ TransformerLayer(
233
+ dim=embed_dim,
234
+ num_heads=num_heads,
235
+ mlp_ratio=mlp_ratio,
236
+ qkv_bias=qkv_bias,
237
+ drop=drop_rate,
238
+ attn_drop=attn_drop_rate,
239
+ drop_path=drop_path_rates[i],
240
+ norm_layer=norm_layer,
241
+ act_layer=act_layer,
242
+ )
243
+ for i in range(depth)
244
+ ]
245
+ )
246
+
247
+ def forward(self, x):
248
+ """
249
+ expects x as (B, C, D, H, W)
250
+ """
251
+ x = self.pool(x)
252
+ x = x.permute(0, 2, 3, 4, 1) # (B, H', W', C), switch to channels last for transformer
253
+
254
+ x = blockify(x, self.block_size) # (B, T, N, C')
255
+ x = x + self.pos_embed
256
+
257
+ x = self.transformer_encoder(x) # (B, ,T, N, C')
258
+
259
+ x = deblockify(x, self.block_size) # (B, D', H', W', C') [2, 24, 24, 24, 128]
260
+ # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
261
+ return x.permute(0, 4, 1, 2, 3) # (B, C, D', H', W')
262
+
263
+
264
+ class NestTransformer3D(nn.Module):
265
+ """Nested Transformer (NesT)
266
+ A PyTorch impl of : `Aggregating Nested Transformers`
267
+ - https://arxiv.org/abs/2105.12723
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ img_size=96,
273
+ in_chans=1,
274
+ patch_size=2,
275
+ num_levels=3,
276
+ embed_dims=(128, 256, 512),
277
+ num_heads=(4, 8, 16),
278
+ depths=(2, 2, 20),
279
+ num_classes=1000,
280
+ mlp_ratio=4.0,
281
+ qkv_bias=True,
282
+ drop_rate=0.0,
283
+ attn_drop_rate=0.0,
284
+ drop_path_rate=0.5,
285
+ norm_layer=None,
286
+ act_layer=None,
287
+ pad_type="",
288
+ weight_init="",
289
+ global_pool="avg",
290
+ ):
291
+ """
292
+ Args:
293
+ img_size (int, tuple): input image size
294
+ in_chans (int): number of input channels
295
+ patch_size (int): patch size
296
+ num_levels (int): number of block hierarchies (T_d in the paper)
297
+ embed_dims (int, tuple): embedding dimensions of each level
298
+ num_heads (int, tuple): number of attention heads for each level
299
+ depths (int, tuple): number of transformer layers for each level
300
+ num_classes (int): number of classes for classification head
301
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
302
+ qkv_bias (bool): enable bias for qkv if True
303
+ drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
304
+ attn_drop_rate (float): attention dropout rate
305
+ drop_path_rate (float): stochastic depth rate
306
+ norm_layer: (nn.Module): normalization layer for transformer layers
307
+ act_layer: (nn.Module): activation layer in MLP of transformer layers
308
+ pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
309
+ weight_init: (str): weight init scheme
310
+ global_pool: (str): type of pooling operation to apply to final feature map
311
+ Notes:
312
+ - Default values follow NesT-B from the original Jax code.
313
+ - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
314
+ - For those following the paper, Table A1 may have errors!
315
+ - https://github.com/google-research/nested-transformer/issues/2
316
+ """
317
+ super().__init__()
318
+
319
+ for param_name in ["embed_dims", "num_heads", "depths"]:
320
+ param_value = locals()[param_name]
321
+ if isinstance(param_value, collections.abc.Sequence):
322
+ assert len(param_value) == num_levels, f"Require `len({param_name}) == num_levels`"
323
+
324
+ embed_dims = to_ntuple(num_levels)(embed_dims)
325
+ num_heads = to_ntuple(num_levels)(num_heads)
326
+ depths = to_ntuple(num_levels)(depths)
327
+ self.num_classes = num_classes
328
+ self.num_features = embed_dims[-1]
329
+ self.feature_info = []
330
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
331
+ act_layer = act_layer or nn.GELU
332
+ self.drop_rate = drop_rate
333
+ self.num_levels = num_levels
334
+ if isinstance(img_size, collections.abc.Sequence):
335
+ assert img_size[0] == img_size[1], "Model only handles square inputs"
336
+ img_size = img_size[0]
337
+ assert img_size % patch_size == 0, "`patch_size` must divide `img_size` evenly"
338
+ self.patch_size = patch_size
339
+
340
+ # Number of blocks at each level
341
+ self.num_blocks = (8 ** torch.arange(num_levels)).flip(0).tolist()
342
+ assert (img_size // patch_size) % round(
343
+ math.pow(self.num_blocks[0], 1 / 3)
344
+ ) == 0, "First level blocks don't fit evenly. Check `img_size`, `patch_size`, and `num_levels`"
345
+
346
+ # Block edge size in units of patches
347
+ # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
348
+ # number of blocks along edge of image
349
+ self.block_size = int((img_size // patch_size) // round(math.pow(self.num_blocks[0], 1 / 3)))
350
+
351
+ # Patch embedding
352
+ self.patch_embed = PatchEmbed3D(
353
+ img_size=[img_size, img_size, img_size],
354
+ patch_size=[patch_size, patch_size, patch_size],
355
+ in_chans=in_chans,
356
+ embed_dim=embed_dims[0],
357
+ )
358
+ self.num_patches = self.patch_embed.num_patches
359
+ self.seq_length = self.num_patches // self.num_blocks[0]
360
+ # Build up each hierarchical level
361
+ levels = []
362
+
363
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
364
+ prev_dim = None
365
+ curr_stride = 4
366
+ for i in range(len(self.num_blocks)):
367
+ dim = embed_dims[i]
368
+ levels.append(
369
+ NestLevel(
370
+ self.num_blocks[i],
371
+ self.block_size,
372
+ self.seq_length,
373
+ num_heads[i],
374
+ depths[i],
375
+ dim,
376
+ prev_dim,
377
+ mlp_ratio,
378
+ qkv_bias,
379
+ drop_rate,
380
+ attn_drop_rate,
381
+ dp_rates[i],
382
+ norm_layer,
383
+ act_layer,
384
+ pad_type=pad_type,
385
+ )
386
+ )
387
+ self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f"levels.{i}")]
388
+ prev_dim = dim
389
+ curr_stride *= 2
390
+
391
+ self.levels = nn.ModuleList([levels[i] for i in range(num_levels)])
392
+
393
+ # Final normalization layer
394
+ self.norm = norm_layer(embed_dims[-1])
395
+
396
+ self.init_weights(weight_init)
397
+
398
+ def init_weights(self, mode=""):
399
+ assert mode in ("nlhb", "")
400
+ head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
401
+ for level in self.levels:
402
+ trunc_normal_(level.pos_embed, std=0.02, a=-2, b=2)
403
+ named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
404
+
405
+ @torch.jit.ignore
406
+ def no_weight_decay(self):
407
+ return {f"level.{i}.pos_embed" for i in range(len(self.levels))}
408
+
409
+ def get_classifier(self):
410
+ return self.head
411
+
412
+ def forward_features(self, x):
413
+ """x shape (B, C, D, H, W)"""
414
+ x = self.patch_embed(x)
415
+
416
+ hidden_states_out = [x]
417
+
418
+ for _, level in enumerate(self.levels):
419
+ x = level(x)
420
+ hidden_states_out.append(x)
421
+ # Layer norm done over channel dim only (to NDHWC and back)
422
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
423
+ return x, hidden_states_out
424
+
425
+ def forward(self, x):
426
+ """x shape (B, C, D, H, W)"""
427
+ x = self.forward_features(x)
428
+
429
+ if self.drop_rate > 0.0:
430
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
431
+ return x
432
+
433
+
434
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
435
+ if not depth_first and include_root:
436
+ fn(module=module, name=name)
437
+ for child_name, child_module in module.named_children():
438
+ child_name = ".".join((name, child_name)) if name else child_name
439
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
440
+ if depth_first and include_root:
441
+ fn(module=module, name=name)
442
+ return module
443
+
444
+
445
+ def _init_nest_weights(module: nn.Module, name: str = "", head_bias: float = 0.0):
446
+ """NesT weight initialization
447
+ Can replicate Jax implementation. Otherwise follows vision_transformer.py
448
+ """
449
+ if isinstance(module, nn.Linear):
450
+ if name.startswith("head"):
451
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
452
+ nn.init.constant_(module.bias, head_bias)
453
+ else:
454
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
455
+ if module.bias is not None:
456
+ nn.init.zeros_(module.bias)
457
+ elif isinstance(module, nn.Conv2d):
458
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
459
+ if module.bias is not None:
460
+ nn.init.zeros_(module.bias)
461
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
462
+ nn.init.zeros_(module.bias)
463
+ nn.init.ones_(module.weight)
464
+
465
+
466
+ def resize_pos_embed(posemb, posemb_new):
467
+ """
468
+ Rescale the grid of position embeddings when loading from state_dict
469
+ Expected shape of position embeddings is (1, T, N, C), and considers only square images
470
+ """
471
+ _logger.info("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
472
+ seq_length_old = posemb.shape[2]
473
+ num_blocks_new, seq_length_new = posemb_new.shape[1:3]
474
+ size_new = int(math.sqrt(num_blocks_new * seq_length_new))
475
+ # First change to (1, C, H, W)
476
+ posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
477
+ posemb = F.interpolate(posemb, size=[size_new, size_new], mode="bicubic", align_corners=False)
478
+ # Now change to new (1, T, N, C)
479
+ posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
480
+ return posemb
481
+
482
+
483
+ def checkpoint_filter_fn(state_dict, model):
484
+ """resize positional embeddings of pretrained weights"""
485
+ pos_embed_keys = [k for k in state_dict.keys() if k.startswith("pos_embed_")]
486
+ for k in pos_embed_keys:
487
+ if state_dict[k].shape != getattr(model, k).shape:
488
+ state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
489
+ return state_dict
scripts/networks/patchEmbed3D.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 - 2021 MONAI Consortium
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+
15
+ import math
16
+ from typing import Sequence, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from monai.utils import optional_import
22
+
23
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
24
+
25
+
26
+ class PatchEmbeddingBlock(nn.Module):
27
+ """
28
+ A patch embedding block, based on: "Dosovitskiy et al.,
29
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ img_size: Tuple[int, int, int],
36
+ patch_size: Tuple[int, int, int],
37
+ hidden_size: int,
38
+ num_heads: int,
39
+ pos_embed: str,
40
+ dropout_rate: float = 0.0,
41
+ ) -> None:
42
+ """
43
+ Args:
44
+ in_channels: dimension of input channels.
45
+ img_size: dimension of input image.
46
+ patch_size: dimension of patch size.
47
+ hidden_size: dimension of hidden layer.
48
+ num_heads: number of attention heads.
49
+ pos_embed: position embedding layer type.
50
+ dropout_rate: faction of the input units to drop.
51
+
52
+ """
53
+
54
+ super().__init__()
55
+
56
+ if not (0 <= dropout_rate <= 1):
57
+ raise AssertionError("dropout_rate should be between 0 and 1.")
58
+
59
+ if hidden_size % num_heads != 0:
60
+ raise AssertionError("hidden size should be divisible by num_heads.")
61
+
62
+ for m, p in zip(img_size, patch_size):
63
+ if m < p:
64
+ raise AssertionError("patch_size should be smaller than img_size.")
65
+
66
+ if pos_embed not in ["conv", "perceptron"]:
67
+ raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
68
+
69
+ if pos_embed == "perceptron":
70
+ if img_size[0] % patch_size[0] != 0:
71
+ raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
72
+
73
+ self.n_patches = (
74
+ (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
75
+ )
76
+ self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
77
+
78
+ self.pos_embed = pos_embed
79
+ self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
80
+ if self.pos_embed == "conv":
81
+ self.patch_embeddings = nn.Conv3d(
82
+ in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
83
+ )
84
+ elif self.pos_embed == "perceptron":
85
+ self.patch_embeddings = nn.Sequential(
86
+ Rearrange(
87
+ "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
88
+ p1=patch_size[0],
89
+ p2=patch_size[1],
90
+ p3=patch_size[2],
91
+ ),
92
+ nn.Linear(self.patch_dim, hidden_size),
93
+ )
94
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
95
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
96
+ self.dropout = nn.Dropout(dropout_rate)
97
+ self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
98
+ self.apply(self._init_weights)
99
+
100
+ def _init_weights(self, m):
101
+ if isinstance(m, nn.Linear):
102
+ self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
103
+ if isinstance(m, nn.Linear) and m.bias is not None:
104
+ nn.init.constant_(m.bias, 0)
105
+ elif isinstance(m, nn.LayerNorm):
106
+ nn.init.constant_(m.bias, 0)
107
+ nn.init.constant_(m.weight, 1.0)
108
+
109
+ def trunc_normal_(self, tensor, mean, std, a, b):
110
+ # From PyTorch official master until it's in a few official releases - RW
111
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
112
+ def norm_cdf(x):
113
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
114
+
115
+ with torch.no_grad():
116
+ l = norm_cdf((a - mean) / std)
117
+ u = norm_cdf((b - mean) / std)
118
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
119
+ tensor.erfinv_()
120
+ tensor.mul_(std * math.sqrt(2.0))
121
+ tensor.add_(mean)
122
+ tensor.clamp_(min=a, max=b)
123
+ return tensor
124
+
125
+ def forward(self, x):
126
+ if self.pos_embed == "conv":
127
+ x = self.patch_embeddings(x)
128
+ x = x.flatten(2)
129
+ x = x.transpose(-1, -2)
130
+ elif self.pos_embed == "perceptron":
131
+ x = self.patch_embeddings(x)
132
+ embeddings = x + self.position_embeddings
133
+ embeddings = self.dropout(embeddings)
134
+ return embeddings
135
+
136
+
137
+ class PatchEmbed3D(nn.Module):
138
+ """Video to Patch Embedding.
139
+
140
+ Args:
141
+ patch_size (int): Patch token size. Default: (2,4,4).
142
+ in_chans (int): Number of input video channels. Default: 3.
143
+ embed_dim (int): Number of linear projection output channels. Default: 96.
144
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ img_size: Sequence[int] = (96, 96, 96),
150
+ patch_size=(4, 4, 4),
151
+ in_chans: int = 1,
152
+ embed_dim: int = 96,
153
+ norm_layer=None,
154
+ ):
155
+ super().__init__()
156
+ self.patch_size = patch_size
157
+
158
+ self.in_chans = in_chans
159
+ self.embed_dim = embed_dim
160
+
161
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
162
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
163
+
164
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
165
+
166
+ if norm_layer is not None:
167
+ self.norm = norm_layer(embed_dim)
168
+ else:
169
+ self.norm = None
170
+
171
+ def forward(self, x):
172
+ """Forward function."""
173
+ # padding
174
+ _, _, d, h, w = x.size()
175
+ if w % self.patch_size[2] != 0:
176
+ x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
177
+ if h % self.patch_size[1] != 0:
178
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
179
+ if d % self.patch_size[0] != 0:
180
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))
181
+
182
+ x = self.proj(x) # B C D Wh Ww
183
+ if self.norm is not None:
184
+ d, wh, ww = x.size(2), x.size(3), x.size(4)
185
+ x = x.flatten(2).transpose(1, 2)
186
+ x = self.norm(x)
187
+ x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
188
+ # pdb.set_trace()
189
+
190
+ return x
scripts/networks/unest_base_patch_4.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # limitations under the License.
3
+ """
4
+ The 3D NEST transformer based segmentation model
5
+
6
+ MASI Lab, Vanderbilty University
7
+
8
+
9
+ Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
10
+
11
+
12
+ The NEST code is partly from
13
+
14
+ Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
15
+ Interpretable Visual Understanding
16
+ https://arxiv.org/pdf/2105.12723.pdf
17
+
18
+ """
19
+ from typing import Sequence, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from monai.networks.blocks import Convolution
24
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
25
+ from scripts.networks.nest_transformer_3D import NestTransformer3D
26
+ from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
27
+
28
+
29
+ class UNesT(nn.Module):
30
+ """
31
+ UNesT model implementation
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ img_size: Sequence[int] = (96, 96, 96),
39
+ feature_size: int = 16,
40
+ patch_size: int = 2,
41
+ depths: Sequence[int] = (2, 2, 2, 2),
42
+ num_heads: Sequence[int] = (3, 6, 12, 24),
43
+ embed_dim: Sequence[int] = (128, 256, 512),
44
+ window_size: Sequence[int] = (7, 7, 7),
45
+ norm_name: Union[Tuple, str] = "instance",
46
+ conv_block: bool = False,
47
+ res_block: bool = True,
48
+ dropout_rate: float = 0.0,
49
+ ) -> None:
50
+ """
51
+ Args:
52
+ in_channels: dimension of input channels.
53
+ out_channels: dimension of output channels.
54
+ img_size: dimension of input image.
55
+ feature_size: dimension of network feature size.
56
+ hidden_size: dimension of hidden layer.
57
+ mlp_dim: dimension of feedforward layer.
58
+ num_heads: number of attention heads.
59
+ pos_embed: position embedding layer type.
60
+ norm_name: feature normalization type and arguments.
61
+ conv_block: bool argument to determine if convolutional block is used.
62
+ res_block: bool argument to determine if residual block is used.
63
+ dropout_rate: faction of the input units to drop.
64
+
65
+ Examples:
66
+
67
+ # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
68
+ >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
69
+
70
+ # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
71
+ >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
72
+
73
+ """
74
+
75
+ super().__init__()
76
+ if not (0 <= dropout_rate <= 1):
77
+ raise AssertionError("dropout_rate should be between 0 and 1.")
78
+ self.embed_dim = embed_dim
79
+ self.nestViT = NestTransformer3D(
80
+ img_size=96,
81
+ in_chans=1,
82
+ patch_size=patch_size,
83
+ num_levels=3,
84
+ embed_dims=embed_dim,
85
+ num_heads=num_heads,
86
+ depths=depths,
87
+ num_classes=1000,
88
+ mlp_ratio=4.0,
89
+ qkv_bias=True,
90
+ drop_rate=0.0,
91
+ attn_drop_rate=0.0,
92
+ drop_path_rate=0.5,
93
+ norm_layer=None,
94
+ act_layer=None,
95
+ pad_type="",
96
+ weight_init="",
97
+ global_pool="avg",
98
+ )
99
+ self.encoder1 = UNesTConvBlock(
100
+ spatial_dims=3,
101
+ in_channels=1,
102
+ out_channels=feature_size * 2,
103
+ kernel_size=3,
104
+ stride=1,
105
+ norm_name=norm_name,
106
+ res_block=res_block,
107
+ )
108
+ self.encoder2 = UNestUpBlock(
109
+ spatial_dims=3,
110
+ in_channels=self.embed_dim[0],
111
+ out_channels=feature_size * 4,
112
+ num_layer=1,
113
+ kernel_size=3,
114
+ stride=1,
115
+ upsample_kernel_size=2,
116
+ norm_name=norm_name,
117
+ conv_block=False,
118
+ res_block=False,
119
+ )
120
+
121
+ self.encoder3 = UNesTConvBlock(
122
+ spatial_dims=3,
123
+ in_channels=self.embed_dim[0],
124
+ out_channels=8 * feature_size,
125
+ kernel_size=3,
126
+ stride=1,
127
+ norm_name=norm_name,
128
+ res_block=res_block,
129
+ )
130
+ self.encoder4 = UNesTConvBlock(
131
+ spatial_dims=3,
132
+ in_channels=self.embed_dim[1],
133
+ out_channels=16 * feature_size,
134
+ kernel_size=3,
135
+ stride=1,
136
+ norm_name=norm_name,
137
+ res_block=res_block,
138
+ )
139
+ self.decoder5 = UNesTBlock(
140
+ spatial_dims=3,
141
+ in_channels=2 * self.embed_dim[2],
142
+ out_channels=feature_size * 32,
143
+ stride=1,
144
+ kernel_size=3,
145
+ upsample_kernel_size=2,
146
+ norm_name=norm_name,
147
+ res_block=res_block,
148
+ )
149
+ self.decoder4 = UNesTBlock(
150
+ spatial_dims=3,
151
+ in_channels=self.embed_dim[2],
152
+ out_channels=feature_size * 16,
153
+ stride=1,
154
+ kernel_size=3,
155
+ upsample_kernel_size=2,
156
+ norm_name=norm_name,
157
+ res_block=res_block,
158
+ )
159
+ self.decoder3 = UNesTBlock(
160
+ spatial_dims=3,
161
+ in_channels=feature_size * 16,
162
+ out_channels=feature_size * 8,
163
+ stride=1,
164
+ kernel_size=3,
165
+ upsample_kernel_size=2,
166
+ norm_name=norm_name,
167
+ res_block=res_block,
168
+ )
169
+ self.decoder2 = UNesTBlock(
170
+ spatial_dims=3,
171
+ in_channels=feature_size * 8,
172
+ out_channels=feature_size * 4,
173
+ stride=1,
174
+ kernel_size=3,
175
+ upsample_kernel_size=2,
176
+ norm_name=norm_name,
177
+ res_block=res_block,
178
+ )
179
+ self.decoder1 = UNesTBlock(
180
+ spatial_dims=3,
181
+ in_channels=feature_size * 4,
182
+ out_channels=feature_size * 2,
183
+ stride=1,
184
+ kernel_size=3,
185
+ upsample_kernel_size=2,
186
+ norm_name=norm_name,
187
+ res_block=res_block,
188
+ )
189
+ self.encoder10 = Convolution(
190
+ dimensions=3,
191
+ in_channels=32 * feature_size,
192
+ out_channels=64 * feature_size,
193
+ strides=2,
194
+ adn_ordering="ADN",
195
+ dropout=0.0,
196
+ )
197
+ self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore
198
+
199
+ def proj_feat(self, x, hidden_size, feat_size):
200
+ x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
201
+ x = x.permute(0, 4, 1, 2, 3).contiguous()
202
+ return x
203
+
204
+ def load_from(self, weights):
205
+ with torch.no_grad():
206
+ # copy weights from patch embedding
207
+ for i in weights["state_dict"]:
208
+ print(i)
209
+ self.vit.patch_embedding.position_embeddings.copy_(
210
+ weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
211
+ )
212
+ self.vit.patch_embedding.cls_token.copy_(
213
+ weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
214
+ )
215
+ self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
216
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
217
+ )
218
+ self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
219
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
220
+ )
221
+
222
+ # copy weights from encoding blocks (default: num of blocks: 12)
223
+ for bname, block in self.vit.blocks.named_children():
224
+ print(block)
225
+ block.loadFrom(weights, n_block=bname)
226
+ # last norm layer of transformer
227
+ self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
228
+ self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
229
+
230
+ def forward(self, x_in):
231
+ x, hidden_states_out = self.nestViT(x_in)
232
+ enc0 = self.encoder1(x_in) # 2, 32, 96, 96, 96
233
+ x1 = hidden_states_out[0] # 2, 128, 24, 24, 24 2, 128, 12, 12, 12
234
+ enc1 = self.encoder2(x1) # 2, 64, 48, 48, 48 torch.Size([2, 64, 24, 24, 24])
235
+ x2 = hidden_states_out[1] # 2, 128, 24, 24, 24
236
+ enc2 = self.encoder3(x2) # 2, 128, 24, 24, 24 torch.Size([2, 128, 12, 12, 12])
237
+ x3 = hidden_states_out[2] # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
238
+ enc3 = self.encoder4(x3) # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
239
+ x4 = hidden_states_out[3]
240
+ enc4 = x4 # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
241
+ dec4 = x # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
242
+ dec4 = self.encoder10(dec4) # 2, 1024, 3, 3, 3 torch.Size([2, 1024, 2, 2, 2])
243
+ dec3 = self.decoder5(dec4, enc4) # 2, 512, 6, 6, 6
244
+ dec2 = self.decoder4(dec3, enc3) # 2, 256, 12, 12, 12
245
+ dec1 = self.decoder3(dec2, enc2) # 2, 128, 24, 24, 24
246
+ dec0 = self.decoder2(dec1, enc1) # 2, 64, 48, 48, 48
247
+ out = self.decoder1(dec0, enc0) # 2, 32, 96, 96, 96
248
+ logits = self.out(out)
249
+ return logits
scripts/networks/unest_block.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
8
+
9
+
10
+ class UNesTBlock(nn.Module):
11
+ """ """
12
+
13
+ def __init__(
14
+ self,
15
+ spatial_dims: int,
16
+ in_channels: int,
17
+ out_channels: int, # type: ignore
18
+ kernel_size: Union[Sequence[int], int],
19
+ stride: Union[Sequence[int], int],
20
+ upsample_kernel_size: Union[Sequence[int], int],
21
+ norm_name: Union[Tuple, str],
22
+ res_block: bool = False,
23
+ ) -> None:
24
+ """
25
+ Args:
26
+ spatial_dims: number of spatial dimensions.
27
+ in_channels: number of input channels.
28
+ out_channels: number of output channels.
29
+ kernel_size: convolution kernel size.
30
+ stride: convolution stride.
31
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
32
+ norm_name: feature normalization type and arguments.
33
+ res_block: bool argument to determine if residual block is used.
34
+
35
+ """
36
+
37
+ super(UNesTBlock, self).__init__()
38
+ upsample_stride = upsample_kernel_size
39
+ self.transp_conv = get_conv_layer(
40
+ spatial_dims,
41
+ in_channels,
42
+ out_channels,
43
+ kernel_size=upsample_kernel_size,
44
+ stride=upsample_stride,
45
+ conv_only=True,
46
+ is_transposed=True,
47
+ )
48
+
49
+ if res_block:
50
+ self.conv_block = UnetResBlock(
51
+ spatial_dims,
52
+ out_channels + out_channels,
53
+ out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=1,
56
+ norm_name=norm_name,
57
+ )
58
+ else:
59
+ self.conv_block = UnetBasicBlock( # type: ignore
60
+ spatial_dims,
61
+ out_channels + out_channels,
62
+ out_channels,
63
+ kernel_size=kernel_size,
64
+ stride=1,
65
+ norm_name=norm_name,
66
+ )
67
+
68
+ def forward(self, inp, skip):
69
+ # number of channels for skip should equals to out_channels
70
+ out = self.transp_conv(inp)
71
+ # print(out.shape)
72
+ # print(skip.shape)
73
+ out = torch.cat((out, skip), dim=1)
74
+ out = self.conv_block(out)
75
+ return out
76
+
77
+
78
+ class UNestUpBlock(nn.Module):
79
+ """ """
80
+
81
+ def __init__(
82
+ self,
83
+ spatial_dims: int,
84
+ in_channels: int,
85
+ out_channels: int,
86
+ num_layer: int,
87
+ kernel_size: Union[Sequence[int], int],
88
+ stride: Union[Sequence[int], int],
89
+ upsample_kernel_size: Union[Sequence[int], int],
90
+ norm_name: Union[Tuple, str],
91
+ conv_block: bool = False,
92
+ res_block: bool = False,
93
+ ) -> None:
94
+ """
95
+ Args:
96
+ spatial_dims: number of spatial dimensions.
97
+ in_channels: number of input channels.
98
+ out_channels: number of output channels.
99
+ num_layer: number of upsampling blocks.
100
+ kernel_size: convolution kernel size.
101
+ stride: convolution stride.
102
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
103
+ norm_name: feature normalization type and arguments.
104
+ conv_block: bool argument to determine if convolutional block is used.
105
+ res_block: bool argument to determine if residual block is used.
106
+
107
+ """
108
+
109
+ super().__init__()
110
+
111
+ upsample_stride = upsample_kernel_size
112
+ self.transp_conv_init = get_conv_layer(
113
+ spatial_dims,
114
+ in_channels,
115
+ out_channels,
116
+ kernel_size=upsample_kernel_size,
117
+ stride=upsample_stride,
118
+ conv_only=True,
119
+ is_transposed=True,
120
+ )
121
+ if conv_block:
122
+ if res_block:
123
+ self.blocks = nn.ModuleList(
124
+ [
125
+ nn.Sequential(
126
+ get_conv_layer(
127
+ spatial_dims,
128
+ out_channels,
129
+ out_channels,
130
+ kernel_size=upsample_kernel_size,
131
+ stride=upsample_stride,
132
+ conv_only=True,
133
+ is_transposed=True,
134
+ ),
135
+ UnetResBlock(
136
+ spatial_dims=3,
137
+ in_channels=out_channels,
138
+ out_channels=out_channels,
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ norm_name=norm_name,
142
+ ),
143
+ )
144
+ for i in range(num_layer)
145
+ ]
146
+ )
147
+ else:
148
+ self.blocks = nn.ModuleList(
149
+ [
150
+ nn.Sequential(
151
+ get_conv_layer(
152
+ spatial_dims,
153
+ out_channels,
154
+ out_channels,
155
+ kernel_size=upsample_kernel_size,
156
+ stride=upsample_stride,
157
+ conv_only=True,
158
+ is_transposed=True,
159
+ ),
160
+ UnetBasicBlock(
161
+ spatial_dims=3,
162
+ in_channels=out_channels,
163
+ out_channels=out_channels,
164
+ kernel_size=kernel_size,
165
+ stride=stride,
166
+ norm_name=norm_name,
167
+ ),
168
+ )
169
+ for i in range(num_layer)
170
+ ]
171
+ )
172
+ else:
173
+ self.blocks = nn.ModuleList(
174
+ [
175
+ get_conv_layer(
176
+ spatial_dims,
177
+ out_channels,
178
+ out_channels,
179
+ kernel_size=1,
180
+ stride=1,
181
+ conv_only=True,
182
+ is_transposed=True,
183
+ )
184
+ for i in range(num_layer)
185
+ ]
186
+ )
187
+
188
+ def forward(self, x):
189
+ x = self.transp_conv_init(x)
190
+ for blk in self.blocks:
191
+ x = blk(x)
192
+ return x
193
+
194
+
195
+ class UNesTConvBlock(nn.Module):
196
+ """
197
+ UNesT block with skip connections
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ spatial_dims: int,
203
+ in_channels: int,
204
+ out_channels: int,
205
+ kernel_size: Union[Sequence[int], int],
206
+ stride: Union[Sequence[int], int],
207
+ norm_name: Union[Tuple, str],
208
+ res_block: bool = False,
209
+ ) -> None:
210
+ """
211
+ Args:
212
+ spatial_dims: number of spatial dimensions.
213
+ in_channels: number of input channels.
214
+ out_channels: number of output channels.
215
+ kernel_size: convolution kernel size.
216
+ stride: convolution stride.
217
+ norm_name: feature normalization type and arguments.
218
+ res_block: bool argument to determine if residual block is used.
219
+
220
+ """
221
+
222
+ super().__init__()
223
+
224
+ if res_block:
225
+ self.layer = UnetResBlock(
226
+ spatial_dims=spatial_dims,
227
+ in_channels=in_channels,
228
+ out_channels=out_channels,
229
+ kernel_size=kernel_size,
230
+ stride=stride,
231
+ norm_name=norm_name,
232
+ )
233
+ else:
234
+ self.layer = UnetBasicBlock( # type: ignore
235
+ spatial_dims=spatial_dims,
236
+ in_channels=in_channels,
237
+ out_channels=out_channels,
238
+ kernel_size=kernel_size,
239
+ stride=stride,
240
+ norm_name=norm_name,
241
+ )
242
+
243
+ def forward(self, inp):
244
+ out = self.layer(inp)
245
+ return out