monai
medical
katielink committed on
Commit
9384dae
1 Parent(s): b0b0fff

Initial release

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/model_autoencoder.ts filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,188 @@
1
+ ---
2
+ tags:
3
+ - monai
4
+ - medical
5
+ library_name: monai
6
+ license: apache-2.0
7
+ ---
8
+ # Model Overview
9
+ A pre-trained 3D latent diffusion generative model for volumetric (3D) BraTS MRI synthesis.
10
+
11
+ This model is trained on BraTS 2016 and 2017 data from the [Medical Decathlon](http://medicaldecathlon.com/), using the latent diffusion model [1].
12
+
13
+ ![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_network.png)
14
+
15
+ The model is a generator that produces images resembling the FLAIR MRIs in the BraTS 2016 and 2017 data. It was trained as a 3D latent diffusion model and accepts Gaussian random noise as input to produce an image output. The `train_autoencoder.json` file describes the training process of the variational autoencoder with GAN loss, and the `train_diffusion.json` file describes the training process of the 3D latent diffusion model.
16
+
17
+ In this bundle, the autoencoder uses a perceptual loss based on a ResNet50 network with pre-trained weights (the network is frozen and is not trained in the bundle). By default, the `pretrained` parameter is set to `False` in `train_autoencoder.json`, so this default must be changed to reproduce the intended training. There are two ways to use pretrained weights:
18
+ 1. Set `pretrained` to `True` to use the ImageNet pretrained weights from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#ResNet50_Weights). Note that these weights are licensed for non-commercial use only.
19
+ 2. Set `pretrained` to `True` and specify the `perceptual_loss_model_weights_path` parameter to load weights from a local path (see the example command after this list). This is how this bundle was trained, using pre-trained weights derived from internal data.
20
+
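+ For example, assuming a locally available ResNet50 checkpoint (the path below is only a placeholder), both parameters can be overridden at the command line via the bundle's `--<key> <value>` override syntax instead of editing the config file:
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --pretrained True --perceptual_loss_model_weights_path <path to local ResNet50 weights>
+ ```
+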
21
+ Please note that each user is responsible for checking the data sources of the pre-trained models, the applicable licenses, and whether they are suitable for the intended use.
22
+
23
+ #### Example synthetic image
24
+ An example result from inference is shown below:
25
+ ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_example_generation.png)
26
+
27
+ **This is a demonstration network meant to show the training process for this sort of network with MONAI. To achieve better performance, users should use a larger dataset such as [BraTS 2021](https://www.synapse.org/#!Synapse:syn25829067/wiki/610865) and a GPU with more than 32 GB of memory to enable larger networks and attention layers.**
28
+
29
+ ## MONAI Generative Model Dependencies
30
+ [MONAI generative models](https://github.com/Project-MONAI/GenerativeModels) can be installed with:
31
+ ```
32
+ pip install lpips==0.1.4
33
+ git clone https://github.com/Project-MONAI/GenerativeModels.git
34
+ cd GenerativeModels/
35
+ git checkout f969c24f88d013dc0045fb7b2885a01fb219992b
36
+ python setup.py install
37
+ cd ..
38
+ ```
39
+
40
+ ## Data
41
+ The training data is BraTS 2016 and 2017 from the Medical Segmentation Decathlon. Users can find more details on the dataset (`Task01_BrainTumour`) at http://medicaldecathlon.com/.
42
+
43
+ - Target: Image Generation
44
+ - Task: Synthesis
45
+ - Modality: MRI
46
+ - Size: 388 3D volumes (1 channel used)
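+
+ The bundle configs expect the extracted `Task01_BrainTumour` folder under `dataset_dir` and do not download it themselves (`download` is set to `false`). One way to fetch and extract the task, assuming MONAI is installed and network access is available (`<dataset root>` is a placeholder for your data directory):
+
+ ```
+ python -c "from monai.apps import DecathlonDataset; DecathlonDataset(root_dir='<dataset root>', task='Task01_BrainTumour', section='training', download=True, cache_rate=0.0)"
+ ```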
47
+
48
+ ## Training Configuration
49
+ If you have a GPU with less than 32 GB of memory, you may need to decrease the batch size when training. To do so, modify the `train_batch_size` parameter in the [configs/train_autoencoder.json](configs/train_autoencoder.json) and [configs/train_diffusion.json](configs/train_diffusion.json) configuration files.
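+
+ Alternatively, the batch size can be overridden at the command line when launching training, for example (the value `1` is only an illustration):
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --train_batch_size 1
+ ```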
50
+
51
+ ### Training Configuration of Autoencoder
52
+ The autoencoder was trained using the following configuration:
53
+
54
+ - GPU: at least 32GB GPU memory
55
+ - Actual Model Input: 112 x 128 x 80
56
+ - AMP: False
57
+ - Optimizer: Adam
58
+ - Learning Rate: 1e-5
59
+ - Loss: L1 loss, perceptual loss, KL divergence loss, adversarial loss, GAN BCE loss
60
+
61
+ #### Input
62
+ 1 channel 3D FLAIR MRI patches
63
+
64
+ #### Output
65
+ - 1 channel 3D MRI reconstructed patches
66
+ - 8 channel mean of latent features
67
+ - 8 channel standard deviation of latent features
68
+
69
+ ### Training Configuration of Diffusion Model
70
+ The latent diffusion model was trained using the following configuration:
71
+
72
+ - GPU: at least 32GB GPU memory
73
+ - Actual Model Input: 36 x 44 x 28
74
+ - AMP: False
75
+ - Optimizer: Adam
76
+ - Learning Rate: 1e-5
77
+ - Loss: MSE loss
78
+
79
+ #### Training Input
80
+ - 8 channel noisy latent features
81
+ - an integer that indicates the time step
82
+
83
+ #### Training Output
84
+ 8 channel predicted added noise
85
+
86
+ #### Inference Input
87
+ 8 channel noise
88
+
89
+ #### Inference Output
90
+ 8 channel denoised latent features
91
+
92
+ ### Memory Consumption Warning
93
+
94
+ If you face memory issues during data loading, you can lower the `cache_rate` value (within the range [0, 1]) in the configurations to reduce the system RAM requirements.
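+
+ For example, the cache rate of the autoencoder training dataset can be lowered at runtime with the nested `#` override syntax (the value `0.5` is only an illustration):
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --train#dataset#cache_rate 0.5
+ ```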
95
+
96
+ ## Performance
97
+
98
+ #### Training Loss
99
+ ![A graph showing the autoencoder training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_autoencoder_loss.png)
100
+
101
+ ![A graph showing the latent diffusion training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_diffusion_loss.png)
102
+
103
+ ## MONAI Bundle Commands
104
+
105
+ In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.
106
+
107
+ For more detailed usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html).
108
+
109
+ ### Execute Autoencoder Training
110
+
111
+ #### Execute Autoencoder Training on a single GPU
112
+
113
+ ```
114
+ python -m monai.bundle run --config_file configs/train_autoencoder.json
115
+ ```
116
+
117
+ Please note that if the default dataset path in the bundle config files has not been changed to the actual path (the directory that contains `Task01_BrainTumour`), you can override it at runtime with `--dataset_dir`:
118
+
119
+ ```
120
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --dataset_dir <actual dataset path>
121
+ ```
122
+
123
+ #### Override the `train` config to execute multi-GPU training for Autoencoder
124
+ To train with multiple GPUs, use the following command; the learning rate should be scaled up according to the number of GPUs (here 8 GPUs, hence `--lr 8e-5`).
125
+
126
+ ```
127
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/multi_gpu_train_autoencoder.json']" --lr 8e-5
128
+ ```
129
+
130
+ #### Check the Autoencoder Training result
131
+ The following command generates a reconstructed image from a randomly selected input image.
132
+ Visualizing it is a quick way to check whether the autoencoder has been trained correctly.
133
+ ```
134
+ python -m monai.bundle run --config_file configs/inference_autoencoder.json
135
+ ```
136
+
137
+ An example of a reconstructed image from inference is shown below. If the autoencoder is trained correctly, the reconstructed image should look similar to the original image.
138
+
139
+ ![Example reconstructed image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_recon_example.jpg)
140
+
141
+
142
+ ### Execute Latent Diffusion Training
143
+
144
+ #### Execute Latent Diffusion Model Training on a single GPU
145
+ After training the autoencoder, run the following command to train the latent diffusion model. This command will print out the scale factor of the latent feature space. If your autoencoder is well trained, this value should be close to 1.0.
146
+
147
+ ```
148
+ python -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json']"
149
+ ```
150
+
151
+ #### Override the `train` config to execute multi-GPU training for Latent Diffusion Model
152
+ To train with multiple GPUs, use the following command; the learning rate should be scaled up according to the number of GPUs (here 8 GPUs, hence `--lr 8e-5`).
153
+
154
+ ```
155
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json','configs/multi_gpu_train_autoencoder.json','configs/multi_gpu_train_diffusion.json']" --lr 8e-5
156
+ ```
157
+
158
+ #### Execute inference
159
+ The following command generates a synthetic image from randomly sampled noise.
160
+ ```
161
+ python -m monai.bundle run --config_file configs/inference.json
162
+ ```
163
+
164
+ #### Export checkpoint to TorchScript file
165
+
166
+ The autoencoder can be exported to a TorchScript file.
167
+
168
+ ```
169
+ python -m monai.bundle ckpt_export autoencoder_def --filepath models/model_autoencoder.ts --ckpt_file models/model_autoencoder.pt --meta_file configs/metadata.json --config_file configs/inference.json
170
+ ```
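+
+ As a quick sanity check (assuming PyTorch is available in the same environment), the exported file can be loaded back with `torch.jit.load` and its module structure printed:
+
+ ```
+ python -c "import torch; print(torch.jit.load('models/model_autoencoder.ts'))"
+ ```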
171
+
172
+ # References
173
+ [1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf
174
+
175
+ # License
176
+ Copyright (c) MONAI Consortium
177
+
178
+ Licensed under the Apache License, Version 2.0 (the "License");
179
+ you may not use this file except in compliance with the License.
180
+ You may obtain a copy of the License at
181
+
182
+ http://www.apache.org/licenses/LICENSE-2.0
183
+
184
+ Unless required by applicable law or agreed to in writing, software
185
+ distributed under the License is distributed on an "AS IS" BASIS,
186
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
187
+ See the License for the specific language governing permissions and
188
+ limitations under the License.
configs/inference.json ADDED
@@ -0,0 +1,101 @@
1
+ {
2
+ "imports": [
3
+ "$import torch",
4
+ "$from datetime import datetime",
5
+ "$from pathlib import Path"
6
+ ],
7
+ "bundle_root": ".",
8
+ "model_dir": "$@bundle_root + '/models'",
9
+ "output_dir": "$@bundle_root + '/output'",
10
+ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
11
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
12
+ "output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')",
13
+ "spatial_dims": 3,
14
+ "image_channels": 1,
15
+ "latent_channels": 8,
16
+ "latent_shape": [
17
+ 8,
18
+ 36,
19
+ 44,
20
+ 28
21
+ ],
22
+ "autoencoder_def": {
23
+ "_target_": "generative.networks.nets.AutoencoderKL",
24
+ "spatial_dims": "@spatial_dims",
25
+ "in_channels": "@image_channels",
26
+ "out_channels": "@image_channels",
27
+ "latent_channels": "@latent_channels",
28
+ "num_channels": [
29
+ 64,
30
+ 128,
31
+ 256
32
+ ],
33
+ "num_res_blocks": 2,
34
+ "norm_num_groups": 32,
35
+ "norm_eps": 1e-06,
36
+ "attention_levels": [
37
+ false,
38
+ false,
39
+ false
40
+ ],
41
+ "with_encoder_nonlocal_attn": false,
42
+ "with_decoder_nonlocal_attn": false
43
+ },
44
+ "network_def": {
45
+ "_target_": "generative.networks.nets.DiffusionModelUNet",
46
+ "spatial_dims": "@spatial_dims",
47
+ "in_channels": "@latent_channels",
48
+ "out_channels": "@latent_channels",
49
+ "num_channels": [
50
+ 256,
51
+ 256,
52
+ 512
53
+ ],
54
+ "attention_levels": [
55
+ false,
56
+ true,
57
+ true
58
+ ],
59
+ "num_head_channels": [
60
+ 0,
61
+ 64,
62
+ 64
63
+ ],
64
+ "num_res_blocks": 2
65
+ },
66
+ "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
67
+ "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
68
+ "autoencoder": "$@autoencoder_def.to(@device)",
69
+ "load_diffusion_path": "$@model_dir + '/model.pt'",
70
+ "load_diffusion": "$@network_def.load_state_dict(torch.load(@load_diffusion_path))",
71
+ "diffusion": "$@network_def.to(@device)",
72
+ "noise_scheduler": {
73
+ "_target_": "generative.networks.schedulers.DDIMScheduler",
74
+ "_requires_": [
75
+ "@load_diffusion",
76
+ "@load_autoencoder"
77
+ ],
78
+ "num_train_timesteps": 1000,
79
+ "beta_start": 0.0015,
80
+ "beta_end": 0.0195,
81
+ "beta_schedule": "scaled_linear",
82
+ "clip_sample": false
83
+ },
84
+ "noise": "$torch.randn([1]+@latent_shape).to(@device)",
85
+ "set_timesteps": "$@noise_scheduler.set_timesteps(num_inference_steps=50)",
86
+ "inferer": {
87
+ "_target_": "scripts.ldm_sampler.LDMSampler",
88
+ "_requires_": "@set_timesteps"
89
+ },
90
+ "sample": "$@inferer.sampling_fn(@noise, @autoencoder, @diffusion, @noise_scheduler)",
91
+ "saver": {
92
+ "_target_": "SaveImage",
93
+ "_requires_": "@create_output_dir",
94
+ "output_dir": "@output_dir",
95
+ "output_postfix": "@output_postfix"
96
+ },
97
+ "generated_image": "$@sample",
98
+ "run": [
99
+ "$@saver(@generated_image[0])"
100
+ ]
101
+ }
configs/inference_autoencoder.json ADDED
@@ -0,0 +1,149 @@
1
+ {
2
+ "imports": [
3
+ "$import torch",
4
+ "$from datetime import datetime",
5
+ "$from pathlib import Path"
6
+ ],
7
+ "bundle_root": ".",
8
+ "model_dir": "$@bundle_root + '/models'",
9
+ "dataset_dir": "@bundle_root",
10
+ "output_dir": "$@bundle_root + '/output'",
11
+ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
12
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
13
+ "output_orig_postfix": "orig",
14
+ "output_recon_postfix": "recon",
15
+ "channel": 0,
16
+ "spacing": [
17
+ 1.1,
18
+ 1.1,
19
+ 1.1
20
+ ],
21
+ "spatial_dims": 3,
22
+ "image_channels": 1,
23
+ "latent_channels": 8,
24
+ "infer_patch_size": [
25
+ 144,
26
+ 176,
27
+ 112
28
+ ],
29
+ "autoencoder_def": {
30
+ "_target_": "generative.networks.nets.AutoencoderKL",
31
+ "spatial_dims": "@spatial_dims",
32
+ "in_channels": "@image_channels",
33
+ "out_channels": "@image_channels",
34
+ "latent_channels": "@latent_channels",
35
+ "num_channels": [
36
+ 64,
37
+ 128,
38
+ 256
39
+ ],
40
+ "num_res_blocks": 2,
41
+ "norm_num_groups": 32,
42
+ "norm_eps": 1e-06,
43
+ "attention_levels": [
44
+ false,
45
+ false,
46
+ false
47
+ ],
48
+ "with_encoder_nonlocal_attn": false,
49
+ "with_decoder_nonlocal_attn": false
50
+ },
51
+ "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
52
+ "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
53
+ "autoencoder": "$@autoencoder_def.to(@device)",
54
+ "preprocessing_transforms": [
55
+ {
56
+ "_target_": "LoadImaged",
57
+ "keys": "image"
58
+ },
59
+ {
60
+ "_target_": "EnsureChannelFirstd",
61
+ "keys": "image"
62
+ },
63
+ {
64
+ "_target_": "Lambdad",
65
+ "keys": "image",
66
+ "func": "$lambda x: x[@channel, :, :, :]"
67
+ },
68
+ {
69
+ "_target_": "AddChanneld",
70
+ "keys": "image"
71
+ },
72
+ {
73
+ "_target_": "EnsureTyped",
74
+ "keys": "image"
75
+ },
76
+ {
77
+ "_target_": "Orientationd",
78
+ "keys": "image",
79
+ "axcodes": "RAS"
80
+ },
81
+ {
82
+ "_target_": "Spacingd",
83
+ "keys": "image",
84
+ "pixdim": "@spacing",
85
+ "mode": "bilinear"
86
+ }
87
+ ],
88
+ "crop_transforms": [
89
+ {
90
+ "_target_": "CenterSpatialCropd",
91
+ "keys": "image",
92
+ "roi_size": "@infer_patch_size"
93
+ }
94
+ ],
95
+ "final_transforms": [
96
+ {
97
+ "_target_": "ScaleIntensityRangePercentilesd",
98
+ "keys": "image",
99
+ "lower": 0,
100
+ "upper": 99.5,
101
+ "b_min": 0,
102
+ "b_max": 1
103
+ }
104
+ ],
105
+ "preprocessing": {
106
+ "_target_": "Compose",
107
+ "transforms": "$@preprocessing_transforms + @crop_transforms + @final_transforms"
108
+ },
109
+ "dataset": {
110
+ "_target_": "monai.apps.DecathlonDataset",
111
+ "root_dir": "@dataset_dir",
112
+ "task": "Task01_BrainTumour",
113
+ "section": "validation",
114
+ "cache_rate": 0.0,
115
+ "num_workers": 8,
116
+ "download": false,
117
+ "transform": "@preprocessing"
118
+ },
119
+ "dataloader": {
120
+ "_target_": "DataLoader",
121
+ "dataset": "@dataset",
122
+ "batch_size": 1,
123
+ "shuffle": true,
124
+ "num_workers": 0
125
+ },
126
+ "saver_orig": {
127
+ "_target_": "SaveImage",
128
+ "_requires_": "@create_output_dir",
129
+ "output_dir": "@output_dir",
130
+ "output_postfix": "@output_orig_postfix",
131
+ "resample": false,
132
+ "padding_mode": "zeros"
133
+ },
134
+ "saver_recon": {
135
+ "_target_": "SaveImage",
136
+ "_requires_": "@create_output_dir",
137
+ "output_dir": "@output_dir",
138
+ "output_postfix": "@output_recon_postfix",
139
+ "resample": false,
140
+ "padding_mode": "zeros"
141
+ },
142
+ "input_img": "$monai.utils.first(@dataloader)['image'].to(@device)",
143
+ "recon_img": "$@autoencoder(@input_img)[0][0]",
144
+ "run": [
145
+ "$@load_autoencoder",
146
+ "$@saver_orig(@input_img[0][0])",
147
+ "$@saver_recon(@recon_img)"
148
+ ]
149
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,119 @@
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20230507.json",
3
+ "version": "1.0.0",
4
+ "changelog": {
5
+ "1.0.0": "Initial release"
6
+ },
7
+ "monai_version": "1.2.0rc5",
8
+ "pytorch_version": "1.13.1",
9
+ "numpy_version": "1.22.2",
10
+ "optional_packages_version": {
11
+ "nibabel": "5.1.0",
12
+ "lpips": "0.1.4"
13
+ },
14
+ "name": "BraTS MRI image latent diffusion generation",
15
+ "task": "BraTS MRI image synthesis",
16
+ "description": "A generative model for creating 3D brain MRI from Gaussian noise based on BraTS dataset",
17
+ "authors": "MONAI team",
18
+ "copyright": "Copyright (c) MONAI Consortium",
19
+ "data_source": "http://medicaldecathlon.com/",
20
+ "data_type": "nibabel",
21
+ "image_classes": "Flair brain MRI with 1.1x1.1x1.1 mm voxel size",
22
+ "eval_metrics": {},
23
+ "intended_use": "This is a research tool/prototype and not to be used clinically",
24
+ "references": [],
25
+ "autoencoder_data_format": {
26
+ "inputs": {
27
+ "image": {
28
+ "type": "image",
29
+ "format": "image",
30
+ "num_channels": 1,
31
+ "spatial_shape": [
32
+ 112,
33
+ 128,
34
+ 80
35
+ ],
36
+ "dtype": "float32",
37
+ "value_range": [
38
+ 0,
39
+ 1
40
+ ],
41
+ "is_patch_data": true
42
+ }
43
+ },
44
+ "outputs": {
45
+ "pred": {
46
+ "type": "image",
47
+ "format": "image",
48
+ "num_channels": 1,
49
+ "spatial_shape": [
50
+ 112,
51
+ 128,
52
+ 80
53
+ ],
54
+ "dtype": "float32",
55
+ "value_range": [
56
+ 0,
57
+ 1
58
+ ],
59
+ "is_patch_data": true,
60
+ "channel_def": {
61
+ "0": "image"
62
+ }
63
+ }
64
+ }
65
+ },
66
+ "generator_data_format": {
67
+ "inputs": {
68
+ "latent": {
69
+ "type": "noise",
70
+ "format": "image",
71
+ "num_channels": 8,
72
+ "spatial_shape": [
73
+ 36,
74
+ 44,
75
+ 28
76
+ ],
77
+ "dtype": "float32",
78
+ "value_range": [
79
+ 0,
80
+ 1
81
+ ],
82
+ "is_patch_data": true
83
+ },
84
+ "condition": {
85
+ "type": "timesteps",
86
+ "format": "timesteps",
87
+ "num_channels": 1,
88
+ "spatial_shape": [],
89
+ "dtype": "long",
90
+ "value_range": [
91
+ 0,
92
+ 1000
93
+ ],
94
+ "is_patch_data": false
95
+ }
96
+ },
97
+ "outputs": {
98
+ "pred": {
99
+ "type": "feature",
100
+ "format": "image",
101
+ "num_channels": 8,
102
+ "spatial_shape": [
103
+ 36,
104
+ 44,
105
+ 28
106
+ ],
107
+ "dtype": "float32",
108
+ "value_range": [
109
+ 0,
110
+ 1
111
+ ],
112
+ "is_patch_data": true,
113
+ "channel_def": {
114
+ "0": "image"
115
+ }
116
+ }
117
+ }
118
+ }
119
+ }
configs/multi_gpu_train_autoencoder.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
3
+ "gnetwork": {
4
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
5
+ "module": "$@autoencoder_def.to(@device)",
6
+ "device_ids": [
7
+ "@device"
8
+ ]
9
+ },
10
+ "dnetwork": {
11
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
12
+ "module": "$@discriminator_def.to(@device)",
13
+ "device_ids": [
14
+ "@device"
15
+ ]
16
+ },
17
+ "train#sampler": {
18
+ "_target_": "DistributedSampler",
19
+ "dataset": "@train#dataset",
20
+ "even_divisible": true,
21
+ "shuffle": true
22
+ },
23
+ "train#dataloader#sampler": "@train#sampler",
24
+ "train#dataloader#shuffle": false,
25
+ "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
26
+ "initialize": [
27
+ "$import torch.distributed as dist",
28
+ "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
29
+ "$torch.cuda.set_device(@device)",
30
+ "$monai.utils.set_determinism(seed=123)",
31
+ "$import logging",
32
+ "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)"
33
+ ],
34
+ "run": [
35
+ "$@train#trainer.run()"
36
+ ],
37
+ "finalize": [
38
+ "$dist.is_initialized() and dist.destroy_process_group()"
39
+ ]
40
+ }
configs/multi_gpu_train_diffusion.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "diffusion": {
3
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
4
+ "module": "$@network_def.to(@device)",
5
+ "device_ids": [
6
+ "@device"
7
+ ],
8
+ "find_unused_parameters": true
9
+ },
10
+ "run": [
11
+ "@load_autoencoder",
12
+ "$@autoencoder.eval()",
13
+ "$print('scale factor:',@scale_factor)",
14
+ "$@train#trainer.run()"
15
+ ]
16
+ }
configs/train_autoencoder.json ADDED
@@ -0,0 +1,206 @@
1
+ {
2
+ "imports": [
3
+ "$import functools",
4
+ "$import glob",
5
+ "$import scripts"
6
+ ],
7
+ "bundle_root": ".",
8
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
9
+ "ckpt_dir": "$@bundle_root + '/models'",
10
+ "tf_dir": "$@bundle_root + '/eval'",
11
+ "dataset_dir": "/workspace/data/medical",
12
+ "pretrained": false,
13
+ "perceptual_loss_model_weights_path": null,
14
+ "train_batch_size": 2,
15
+ "lr": 1e-05,
16
+ "train_patch_size": [
17
+ 112,
18
+ 128,
19
+ 80
20
+ ],
21
+ "channel": 0,
22
+ "spacing": [
23
+ 1.1,
24
+ 1.1,
25
+ 1.1
26
+ ],
27
+ "spatial_dims": 3,
28
+ "image_channels": 1,
29
+ "latent_channels": 8,
30
+ "discriminator_def": {
31
+ "_target_": "generative.networks.nets.PatchDiscriminator",
32
+ "spatial_dims": "@spatial_dims",
33
+ "num_layers_d": 3,
34
+ "num_channels": 32,
35
+ "in_channels": 1,
36
+ "out_channels": 1,
37
+ "norm": "INSTANCE"
38
+ },
39
+ "autoencoder_def": {
40
+ "_target_": "generative.networks.nets.AutoencoderKL",
41
+ "spatial_dims": "@spatial_dims",
42
+ "in_channels": "@image_channels",
43
+ "out_channels": "@image_channels",
44
+ "latent_channels": "@latent_channels",
45
+ "num_channels": [
46
+ 64,
47
+ 128,
48
+ 256
49
+ ],
50
+ "num_res_blocks": 2,
51
+ "norm_num_groups": 32,
52
+ "norm_eps": 1e-06,
53
+ "attention_levels": [
54
+ false,
55
+ false,
56
+ false
57
+ ],
58
+ "with_encoder_nonlocal_attn": false,
59
+ "with_decoder_nonlocal_attn": false
60
+ },
61
+ "perceptual_loss_def": {
62
+ "_target_": "generative.losses.PerceptualLoss",
63
+ "spatial_dims": "@spatial_dims",
64
+ "network_type": "resnet50",
65
+ "is_fake_3d": true,
66
+ "fake_3d_ratio": 0.2,
67
+ "pretrained": "@pretrained",
68
+ "pretrained_path": "@perceptual_loss_model_weights_path",
69
+ "pretrained_state_dict_key": "state_dict"
70
+ },
71
+ "dnetwork": "$@discriminator_def.to(@device)",
72
+ "gnetwork": "$@autoencoder_def.to(@device)",
73
+ "loss_perceptual": "$@perceptual_loss_def.to(@device)",
74
+ "doptimizer": {
75
+ "_target_": "torch.optim.Adam",
76
+ "params": "$@dnetwork.parameters()",
77
+ "lr": "@lr"
78
+ },
79
+ "goptimizer": {
80
+ "_target_": "torch.optim.Adam",
81
+ "params": "$@gnetwork.parameters()",
82
+ "lr": "@lr"
83
+ },
84
+ "preprocessing_transforms": [
85
+ {
86
+ "_target_": "LoadImaged",
87
+ "keys": "image"
88
+ },
89
+ {
90
+ "_target_": "EnsureChannelFirstd",
91
+ "keys": "image"
92
+ },
93
+ {
94
+ "_target_": "Lambdad",
95
+ "keys": "image",
96
+ "func": "$lambda x: x[@channel, :, :, :]"
97
+ },
98
+ {
99
+ "_target_": "AddChanneld",
100
+ "keys": "image"
101
+ },
102
+ {
103
+ "_target_": "EnsureTyped",
104
+ "keys": "image"
105
+ },
106
+ {
107
+ "_target_": "Orientationd",
108
+ "keys": "image",
109
+ "axcodes": "RAS"
110
+ },
111
+ {
112
+ "_target_": "Spacingd",
113
+ "keys": "image",
114
+ "pixdim": "@spacing",
115
+ "mode": "bilinear"
116
+ }
117
+ ],
118
+ "final_transforms": [
119
+ {
120
+ "_target_": "ScaleIntensityRangePercentilesd",
121
+ "keys": "image",
122
+ "lower": 0,
123
+ "upper": 99.5,
124
+ "b_min": 0,
125
+ "b_max": 1
126
+ }
127
+ ],
128
+ "train": {
129
+ "crop_transforms": [
130
+ {
131
+ "_target_": "RandSpatialCropd",
132
+ "keys": "image",
133
+ "roi_size": "@train_patch_size",
134
+ "random_size": false
135
+ }
136
+ ],
137
+ "preprocessing": {
138
+ "_target_": "Compose",
139
+ "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
140
+ },
141
+ "dataset": {
142
+ "_target_": "monai.apps.DecathlonDataset",
143
+ "root_dir": "@dataset_dir",
144
+ "task": "Task01_BrainTumour",
145
+ "section": "training",
146
+ "cache_rate": 1.0,
147
+ "num_workers": 8,
148
+ "download": false,
149
+ "transform": "@train#preprocessing"
150
+ },
151
+ "dataloader": {
152
+ "_target_": "DataLoader",
153
+ "dataset": "@train#dataset",
154
+ "batch_size": "@train_batch_size",
155
+ "shuffle": true,
156
+ "num_workers": 0
157
+ },
158
+ "handlers": [
159
+ {
160
+ "_target_": "CheckpointSaver",
161
+ "save_dir": "@ckpt_dir",
162
+ "save_dict": {
163
+ "model": "@gnetwork"
164
+ },
165
+ "save_interval": 0,
166
+ "save_final": true,
167
+ "epoch_level": true,
168
+ "final_filename": "model_autoencoder.pt"
169
+ },
170
+ {
171
+ "_target_": "StatsHandler",
172
+ "tag_name": "train_loss",
173
+ "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
174
+ },
175
+ {
176
+ "_target_": "TensorBoardStatsHandler",
177
+ "log_dir": "@tf_dir",
178
+ "tag_name": "train_loss",
179
+ "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
180
+ }
181
+ ],
182
+ "trainer": {
183
+ "_target_": "scripts.ldm_trainer.VaeGanTrainer",
184
+ "device": "@device",
185
+ "max_epochs": 1500,
186
+ "train_data_loader": "@train#dataloader",
187
+ "g_network": "@gnetwork",
188
+ "g_optimizer": "@goptimizer",
189
+ "g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork, loss_perceptual=@loss_perceptual)",
190
+ "d_network": "@dnetwork",
191
+ "d_optimizer": "@doptimizer",
192
+ "d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
193
+ "d_train_steps": 5,
194
+ "g_update_latents": true,
195
+ "latent_shape": "@latent_channels",
196
+ "key_train_metric": "$None",
197
+ "train_handlers": "@train#handlers"
198
+ }
199
+ },
200
+ "initialize": [
201
+ "$monai.utils.set_determinism(seed=0)"
202
+ ],
203
+ "run": [
204
+ "$@train#trainer.run()"
205
+ ]
206
+ }
configs/train_diffusion.json ADDED
@@ -0,0 +1,157 @@
1
+ {
2
+ "ckpt_dir": "$@bundle_root + '/models'",
3
+ "train_batch_size": 4,
4
+ "lr": 1e-05,
5
+ "train_patch_size": [
6
+ 144,
7
+ 176,
8
+ 112
9
+ ],
10
+ "latent_shape": [
11
+ "@latent_channels",
12
+ 36,
13
+ 44,
14
+ 28
15
+ ],
16
+ "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
17
+ "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
18
+ "autoencoder": "$@autoencoder_def.to(@device)",
19
+ "network_def": {
20
+ "_target_": "generative.networks.nets.DiffusionModelUNet",
21
+ "spatial_dims": "@spatial_dims",
22
+ "in_channels": "@latent_channels",
23
+ "out_channels": "@latent_channels",
24
+ "num_channels": [
25
+ 256,
26
+ 256,
27
+ 512
28
+ ],
29
+ "attention_levels": [
30
+ false,
31
+ true,
32
+ true
33
+ ],
34
+ "num_head_channels": [
35
+ 0,
36
+ 64,
37
+ 64
38
+ ],
39
+ "num_res_blocks": 2
40
+ },
41
+ "diffusion": "$@network_def.to(@device)",
42
+ "optimizer": {
43
+ "_target_": "torch.optim.Adam",
44
+ "params": "$@diffusion.parameters()",
45
+ "lr": "@lr"
46
+ },
47
+ "lr_scheduler": {
48
+ "_target_": "torch.optim.lr_scheduler.MultiStepLR",
49
+ "optimizer": "@optimizer",
50
+ "milestones": [
51
+ 100,
52
+ 1000
53
+ ],
54
+ "gamma": 0.1
55
+ },
56
+ "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
57
+ "noise_scheduler": {
58
+ "_target_": "generative.networks.schedulers.DDPMScheduler",
59
+ "_requires_": [
60
+ "@load_autoencoder"
61
+ ],
62
+ "beta_schedule": "scaled_linear",
63
+ "num_train_timesteps": 1000,
64
+ "beta_start": 0.0015,
65
+ "beta_end": 0.0195
66
+ },
67
+ "inferer": {
68
+ "_target_": "generative.inferers.LatentDiffusionInferer",
69
+ "scheduler": "@noise_scheduler",
70
+ "scale_factor": "@scale_factor"
71
+ },
72
+ "loss": {
73
+ "_target_": "torch.nn.MSELoss"
74
+ },
75
+ "train": {
76
+ "crop_transforms": [
77
+ {
78
+ "_target_": "CenterSpatialCropd",
79
+ "keys": "image",
80
+ "roi_size": "@train_patch_size"
81
+ }
82
+ ],
83
+ "preprocessing": {
84
+ "_target_": "Compose",
85
+ "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
86
+ },
87
+ "dataset": {
88
+ "_target_": "monai.apps.DecathlonDataset",
89
+ "root_dir": "@dataset_dir",
90
+ "task": "Task01_BrainTumour",
91
+ "section": "training",
92
+ "cache_rate": 1.0,
93
+ "num_workers": 8,
94
+ "download": false,
95
+ "transform": "@train#preprocessing"
96
+ },
97
+ "dataloader": {
98
+ "_target_": "DataLoader",
99
+ "dataset": "@train#dataset",
100
+ "batch_size": "@train_batch_size",
101
+ "shuffle": true,
102
+ "num_workers": 0
103
+ },
104
+ "handlers": [
105
+ {
106
+ "_target_": "LrScheduleHandler",
107
+ "lr_scheduler": "@lr_scheduler",
108
+ "print_lr": true
109
+ },
110
+ {
111
+ "_target_": "CheckpointSaver",
112
+ "save_dir": "@ckpt_dir",
113
+ "save_dict": {
114
+ "model": "@diffusion"
115
+ },
116
+ "save_interval": 0,
117
+ "save_final": true,
118
+ "epoch_level": true,
119
+ "final_filename": "model.pt"
120
+ },
121
+ {
122
+ "_target_": "StatsHandler",
123
+ "tag_name": "train_diffusion_loss",
124
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
125
+ },
126
+ {
127
+ "_target_": "TensorBoardStatsHandler",
128
+ "log_dir": "@tf_dir",
129
+ "tag_name": "train_diffusion_loss",
130
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
131
+ }
132
+ ],
133
+ "trainer": {
134
+ "_target_": "scripts.ldm_trainer.LDMTrainer",
135
+ "device": "@device",
136
+ "max_epochs": 5000,
137
+ "train_data_loader": "@train#dataloader",
138
+ "network": "@diffusion",
139
+ "autoencoder_model": "@autoencoder",
140
+ "optimizer": "@optimizer",
141
+ "loss_function": "@loss",
142
+ "latent_shape": "@latent_shape",
143
+ "inferer": "@inferer",
144
+ "key_train_metric": "$None",
145
+ "train_handlers": "@train#handlers"
146
+ }
147
+ },
148
+ "initialize": [
149
+ "$monai.utils.set_determinism(seed=0)"
150
+ ],
151
+ "run": [
152
+ "@load_autoencoder",
153
+ "$@autoencoder.eval()",
154
+ "$print('scale factor:',@scale_factor)",
155
+ "$@train#trainer.run()"
156
+ ]
157
+ }
docs/README.md ADDED
@@ -0,0 +1,181 @@
1
+ # Model Overview
2
+ A pre-trained 3D latent diffusion generative model for volumetric (3D) BraTS MRI synthesis.
3
+
4
+ This model is trained on BraTS 2016 and 2017 data from the [Medical Decathlon](http://medicaldecathlon.com/), using the latent diffusion model [1].
5
+
6
+ ![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_network.png)
7
+
8
+ The model is a generator that produces images resembling the FLAIR MRIs in the BraTS 2016 and 2017 data. It was trained as a 3D latent diffusion model and accepts Gaussian random noise as input to produce an image output. The `train_autoencoder.json` file describes the training process of the variational autoencoder with GAN loss, and the `train_diffusion.json` file describes the training process of the 3D latent diffusion model.
9
+
10
+ In this bundle, the autoencoder uses a perceptual loss based on a ResNet50 network with pre-trained weights (the network is frozen and is not trained in the bundle). By default, the `pretrained` parameter is set to `False` in `train_autoencoder.json`, so this default must be changed to reproduce the intended training. There are two ways to use pretrained weights:
11
+ 1. Set `pretrained` to `True` to use the ImageNet pretrained weights from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#ResNet50_Weights). Note that these weights are licensed for non-commercial use only.
12
+ 2. Set `pretrained` to `True` and specify the `perceptual_loss_model_weights_path` parameter to load weights from a local path (see the example command after this list). This is how this bundle was trained, using pre-trained weights derived from internal data.
13
+
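+ For example, assuming a locally available ResNet50 checkpoint (the path below is only a placeholder), both parameters can be overridden at the command line via the bundle's `--<key> <value>` override syntax instead of editing the config file:
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --pretrained True --perceptual_loss_model_weights_path <path to local ResNet50 weights>
+ ```
+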
14
+ Please note that each user is responsible for checking the data sources of the pre-trained models, the applicable licenses, and whether they are suitable for the intended use.
15
+
16
+ #### Example synthetic image
17
+ An example result from inference is shown below:
18
+ ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_example_generation.png)
19
+
20
+ **This is a demonstration network meant to show the training process for this sort of network with MONAI. To achieve better performance, users should use a larger dataset such as [BraTS 2021](https://www.synapse.org/#!Synapse:syn25829067/wiki/610865) and a GPU with more than 32 GB of memory to enable larger networks and attention layers.**
21
+
22
+ ## MONAI Generative Model Dependencies
23
+ [MONAI generative models](https://github.com/Project-MONAI/GenerativeModels) can be installed with:
24
+ ```
25
+ pip install lpips==0.1.4
26
+ git clone https://github.com/Project-MONAI/GenerativeModels.git
27
+ cd GenerativeModels/
28
+ git checkout f969c24f88d013dc0045fb7b2885a01fb219992b
29
+ python setup.py install
30
+ cd ..
31
+ ```
32
+
33
+ ## Data
34
+ The training data is BraTS 2016 and 2017 from the Medical Segmentation Decathlon. Users can find more details on the dataset (`Task01_BrainTumour`) at http://medicaldecathlon.com/.
35
+
36
+ - Target: Image Generation
37
+ - Task: Synthesis
38
+ - Modality: MRI
39
+ - Size: 388 3D volumes (1 channel used)
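+
+ The bundle configs expect the extracted `Task01_BrainTumour` folder under `dataset_dir` and do not download it themselves (`download` is set to `false`). One way to fetch and extract the task, assuming MONAI is installed and network access is available (`<dataset root>` is a placeholder for your data directory):
+
+ ```
+ python -c "from monai.apps import DecathlonDataset; DecathlonDataset(root_dir='<dataset root>', task='Task01_BrainTumour', section='training', download=True, cache_rate=0.0)"
+ ```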
40
+
41
+ ## Training Configuration
42
+ If you have a GPU with less than 32 GB of memory, you may need to decrease the batch size when training. To do so, modify the `train_batch_size` parameter in the [configs/train_autoencoder.json](../configs/train_autoencoder.json) and [configs/train_diffusion.json](../configs/train_diffusion.json) configuration files.
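+
+ Alternatively, the batch size can be overridden at the command line when launching training, for example (the value `1` is only an illustration):
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --train_batch_size 1
+ ```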
43
+
44
+ ### Training Configuration of Autoencoder
45
+ The autoencoder was trained using the following configuration:
46
+
47
+ - GPU: at least 32GB GPU memory
48
+ - Actual Model Input: 112 x 128 x 80
49
+ - AMP: False
50
+ - Optimizer: Adam
51
+ - Learning Rate: 1e-5
52
+ - Loss: L1 loss, perceptual loss, KL divergence loss, adversarial loss, GAN BCE loss
53
+
54
+ #### Input
55
+ 1 channel 3D FLAIR MRI patches
56
+
57
+ #### Output
58
+ - 1 channel 3D MRI reconstructed patches
59
+ - 8 channel mean of latent features
60
+ - 8 channel standard deviation of latent features
61
+
62
+ ### Training Configuration of Diffusion Model
63
+ The latent diffusion model was trained using the following configuration:
64
+
65
+ - GPU: at least 32GB GPU memory
66
+ - Actual Model Input: 36 x 44 x 28
67
+ - AMP: False
68
+ - Optimizer: Adam
69
+ - Learning Rate: 1e-5
70
+ - Loss: MSE loss
71
+
72
+ #### Training Input
73
+ - 8 channel noisy latent features
74
+ - an integer that indicates the time step
75
+
76
+ #### Training Output
77
+ 8 channel predicted added noise
78
+
79
+ #### Inference Input
80
+ 8 channel noise
81
+
82
+ #### Inference Output
83
+ 8 channel denoised latent features
84
+
85
+ ### Memory Consumption Warning
86
+
87
+ If you face memory issues during data loading, you can lower the `cache_rate` value (within the range [0, 1]) in the configurations to reduce the system RAM requirements.
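+
+ For example, the cache rate of the autoencoder training dataset can be lowered at runtime with the nested `#` override syntax (the value `0.5` is only an illustration):
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --train#dataset#cache_rate 0.5
+ ```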
88
+
89
+ ## Performance
90
+
91
+ #### Training Loss
92
+ ![A graph showing the autoencoder training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_autoencoder_loss.png)
93
+
94
+ ![A graph showing the latent diffusion training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_diffusion_loss.png)
95
+
96
+ ## MONAI Bundle Commands
97
+
98
+ In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.
99
+
100
+ For more detailed usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html).
101
+
102
+ ### Execute Autoencoder Training
103
+
104
+ #### Execute Autoencoder Training on a single GPU
105
+
106
+ ```
107
+ python -m monai.bundle run --config_file configs/train_autoencoder.json
108
+ ```
109
+
110
+ Please note that if the default dataset path in the bundle config files has not been changed to the actual path (the directory that contains `Task01_BrainTumour`), you can override it at runtime with `--dataset_dir`:
111
+
112
+ ```
113
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --dataset_dir <actual dataset path>
114
+ ```
115
+
116
+ #### Override the `train` config to execute multi-GPU training for Autoencoder
117
+ To train with multiple GPUs, use the following command; the learning rate should be scaled up according to the number of GPUs (here 8 GPUs, hence `--lr 8e-5`).
118
+
119
+ ```
120
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/multi_gpu_train_autoencoder.json']" --lr 8e-5
121
+ ```
122
+
123
+ #### Check the Autoencoder Training result
124
+ The following command generates a reconstructed image from a randomly selected input image.
125
+ Visualizing it is a quick way to check whether the autoencoder has been trained correctly.
126
+ ```
127
+ python -m monai.bundle run --config_file configs/inference_autoencoder.json
128
+ ```
129
+
130
+ An example of a reconstructed image from inference is shown below. If the autoencoder is trained correctly, the reconstructed image should look similar to the original image.
131
+
132
+ ![Example reconstructed image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_recon_example.jpg)
133
+
134
+
135
+ ### Execute Latent Diffusion Training
136
+
137
+ #### Execute Latent Diffusion Model Training on a single GPU
138
+ After training the autoencoder, run the following command to train the latent diffusion model. This command will print out the scale factor of the latent feature space. If your autoencoder is well trained, this value should be close to 1.0.
139
+
140
+ ```
141
+ python -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json']"
142
+ ```
143
+
144
+ #### Override the `train` config to execute multi-GPU training for Latent Diffusion Model
145
+ To train with multiple GPUs, use the following command; the learning rate should be scaled up according to the number of GPUs (here 8 GPUs, hence `--lr 8e-5`).
146
+
147
+ ```
148
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json','configs/multi_gpu_train_autoencoder.json','configs/multi_gpu_train_diffusion.json']" --lr 8e-5
149
+ ```
150
+
151
+ #### Execute inference
152
+ The following command generates a synthetic image from randomly sampled noise.
153
+ ```
154
+ python -m monai.bundle run --config_file configs/inference.json
155
+ ```
156
+
157
+ #### Export checkpoint to TorchScript file
158
+
159
+ The autoencoder can be exported to a TorchScript file.
160
+
161
+ ```
162
+ python -m monai.bundle ckpt_export autoencoder_def --filepath models/model_autoencoder.ts --ckpt_file models/model_autoencoder.pt --meta_file configs/metadata.json --config_file configs/inference.json
163
+ ```
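+
+ As a quick sanity check (assuming PyTorch is available in the same environment), the exported file can be loaded back with `torch.jit.load` and its module structure printed:
+
+ ```
+ python -c "import torch; print(torch.jit.load('models/model_autoencoder.ts'))"
+ ```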
164
+
165
+ # References
166
+ [1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf
167
+
168
+ # License
169
+ Copyright (c) MONAI Consortium
170
+
171
+ Licensed under the Apache License, Version 2.0 (the "License");
172
+ you may not use this file except in compliance with the License.
173
+ You may obtain a copy of the License at
174
+
175
+ http://www.apache.org/licenses/LICENSE-2.0
176
+
177
+ Unless required by applicable law or agreed to in writing, software
178
+ distributed under the License is distributed on an "AS IS" BASIS,
179
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
180
+ See the License for the specific language governing permissions and
181
+ limitations under the License.
docs/data_license.txt ADDED
@@ -0,0 +1,49 @@
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i. Multimodal Brain Tumor Segmentation Challenge 2018
6
+ https://www.med.upenn.edu/sbia/brats2018/data.html
7
+ /*********************************************************************/
8
+
9
+ Data Usage Agreement / Citations
10
+
11
+ You are free to use and/or refer to the BraTS datasets in your own
12
+ research, provided that you always cite the following two manuscripts:
13
+
14
+ [1] Menze BH, Jakab A, Bauer S, Kalpathy-Cramer J, Farahani K, Kirby
15
+ J, Burren Y, Porz N, Slotboom J, Wiest R, Lanczi L, Gerstner E, Weber
16
+ MA, Arbel T, Avants BB, Ayache N, Buendia P, Collins DL, Cordier N,
17
+ Corso JJ, Criminisi A, Das T, Delingette H, Demiralp Ç, Durst CR,
18
+ Dojat M, Doyle S, Festa J, Forbes F, Geremia E, Glocker B, Golland P,
19
+ Guo X, Hamamci A, Iftekharuddin KM, Jena R, John NM, Konukoglu E,
20
+ Lashkari D, Mariz JA, Meier R, Pereira S, Precup D, Price SJ, Raviv
21
+ TR, Reza SM, Ryan M, Sarikaya D, Schwartz L, Shin HC, Shotton J,
22
+ Silva CA, Sousa N, Subbanna NK, Szekely G, Taylor TJ, Thomas OM,
23
+ Tustison NJ, Unal G, Vasseur F, Wintermark M, Ye DH, Zhao L, Zhao B,
24
+ Zikic D, Prastawa M, Reyes M, Van Leemput K. "The Multimodal Brain
25
+ Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on
26
+ Medical Imaging 34(10), 1993-2024 (2015) DOI:
27
+ 10.1109/TMI.2014.2377694
28
+
29
+ [2] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby JS,
30
+ Freymann JB, Farahani K, Davatzikos C. "Advancing The Cancer Genome
31
+ Atlas glioma MRI collections with expert segmentation labels and
32
+ radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:
33
+ 10.1038/sdata.2017.117
34
+
35
+ In addition, if there are no restrictions imposed from the
36
+ journal/conference you submit your paper about citing "Data
37
+ Citations", please be specific and also cite the following:
38
+
39
+ [3] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
40
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
41
+ Radiomic Features for the Pre-operative Scans of the TCGA-GBM
42
+ collection", The Cancer Imaging Archive, 2017. DOI:
43
+ 10.7937/K9/TCIA.2017.KLXWJJ1Q
44
+
45
+ [4] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
46
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
47
+ Radiomic Features for the Pre-operative Scans of the TCGA-LGG
48
+ collection", The Cancer Imaging Archive, 2017. DOI:
49
+ 10.7937/K9/TCIA.2017.GJQ7R0EF
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d780cd5fff4ec886226c5407391a5906c45e388c5d02efbf20da8729b7513e19
3
+ size 765042309
models/model_autoencoder.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a788c35df8e6d7b8c5baf2108bfc6e105a1e6685dfbf564d8d07a66194de8727
3
+ size 84050405
models/model_autoencoder.ts ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39fa19c3e4cd35298337ad4dd8684961d55539a7cd3e29155e12731b254aca3f
3
+ size 84147155
scripts/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from . import ldm_sampler, ldm_trainer, losses, utils
scripts/ldm_sampler.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from monai.utils import optional_import
17
+ from torch.cuda.amp import autocast
18
+
19
+ tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
20
+
21
+
22
+ class LDMSampler:
23
+ def __init__(self) -> None:
24
+ super().__init__()
25
+
26
+ @torch.no_grad()
27
+ def sampling_fn(
28
+ self,
29
+ input_noise: torch.Tensor,
30
+ autoencoder_model: nn.Module,
31
+ diffusion_model: nn.Module,
32
+ scheduler: nn.Module,
33
+ conditioning: torch.Tensor | None = None,
34
+ ) -> torch.Tensor:
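+ """Iteratively denoise input_noise with the diffusion model and scheduler, then decode the final latent with the autoencoder."""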
35
+ if has_tqdm:
36
+ progress_bar = tqdm(scheduler.timesteps)
37
+ else:
38
+ progress_bar = iter(scheduler.timesteps)
39
+
40
+ image = input_noise
41
+ if conditioning is not None:
42
+ cond_concat = conditioning.squeeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
43
+ cond_concat = cond_concat.expand(list(cond_concat.shape[0:2]) + list(input_noise.shape[2:]))
44
+
45
+ for t in progress_bar:
46
+ with torch.no_grad():
47
+ if conditioning is not None:
48
+ input_t = torch.cat((image, cond_concat), dim=1)
49
+ else:
50
+ input_t = image
51
+ model_output = diffusion_model(
52
+ input_t, timesteps=torch.Tensor((t,)).to(input_noise.device).long(), context=conditioning
53
+ )
54
+ image, _ = scheduler.step(model_output, t, image)
55
+
56
+ with torch.no_grad():
57
+ with autocast():
58
+ sample = autoencoder_model.decode_stage_2_outputs(image)
59
+
60
+ return sample
scripts/ldm_trainer.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
15
+
16
+ import torch
17
+ from monai.config import IgniteInfo
18
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
19
+ from monai.inferers import Inferer, SimpleInferer
20
+ from monai.transforms import Transform
21
+ from monai.utils import min_version, optional_import
22
+ from monai.utils.enums import CommonKeys, GanKeys
23
+ from torch.optim.optimizer import Optimizer
24
+ from torch.utils.data import DataLoader
25
+
26
+ if TYPE_CHECKING:
27
+ from ignite.engine import Engine, EventEnum
28
+ from ignite.metrics import Metric
29
+ else:
30
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
31
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
32
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
33
+ from monai.engines.trainer import SupervisedTrainer, Trainer
34
+
35
+
36
+ class VaeGanTrainer(Trainer):
37
+ """
38
+ Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.2661,
39
+ inherits from ``Trainer`` and ``Workflow``.
40
+ Training Loop: for each batch of data size `m`
41
+ 1. Generate `m` fakes from random latent codes.
42
+ 2. Update discriminator with these fakes and current batch reals, repeated d_train_steps times.
43
+ 3. If g_update_latents, generate `m` fakes from new random latent codes.
44
+ 4. Update generator with these fakes using discriminator feedback.
45
+ Args:
46
+ device: an object representing the device on which to run.
47
+ max_epochs: the total epoch number for engine to run.
48
+ train_data_loader: core Ignite engines use `DataLoader` for the training loop batch data.
49
+ g_network: generator (G) network architecture.
50
+ g_optimizer: G optimizer function.
51
+ g_loss_function: G loss function for optimizer.
52
+ d_network: discriminator (D) network architecture.
53
+ d_optimizer: D optimizer function.
54
+ d_loss_function: D loss function for optimizer.
55
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
56
+ g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
57
+ d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
58
+ d_train_steps: number of times to update D with real data minibatch. Defaults to ``1``.
59
+ latent_shape: size of G input latent code. Defaults to ``64``.
60
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
61
+ with respect to the host. For other cases, this argument has no effect.
62
+ d_prepare_batch: callback function to prepare batchdata for D inferer.
63
+ Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
64
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
65
+ g_prepare_batch: callback function to create batch of latent input for G inferer.
66
+ Defaults to return random latents. for more details please refer to:
67
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
68
+ g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
69
+ iteration_update: the callable function for every iteration, expect to accept `engine`
70
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
71
+ if not provided, use `self._iteration()` instead. for more details please refer to:
72
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
73
+ postprocessing: execute additional transformation for the model output data.
74
+ Typically, several Tensor based transforms composed by `Compose`.
75
+ key_train_metric: compute metric when every iteration completed, and save average value to
76
+ engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
77
+ checkpoint into files.
78
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
79
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
80
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
81
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
82
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
83
+ CheckpointHandler, StatsHandler, etc.
84
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
85
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
86
+ default to `True`.
87
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
88
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
89
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
90
+ `device`, `non_blocking`.
91
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
92
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ device: str | torch.device,
98
+ max_epochs: int,
99
+ train_data_loader: DataLoader,
100
+ g_network: torch.nn.Module,
101
+ g_optimizer: Optimizer,
102
+ g_loss_function: Callable,
103
+ d_network: torch.nn.Module,
104
+ d_optimizer: Optimizer,
105
+ d_loss_function: Callable,
106
+ epoch_length: int | None = None,
107
+ g_inferer: Inferer | None = None,
108
+ d_inferer: Inferer | None = None,
109
+ d_train_steps: int = 1,
110
+ latent_shape: int = 64,
111
+ non_blocking: bool = False,
112
+ d_prepare_batch: Callable = default_prepare_batch,
113
+ g_prepare_batch: Callable = default_prepare_batch,
114
+ g_update_latents: bool = True,
115
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
116
+ postprocessing: Transform | None = None,
117
+ key_train_metric: dict[str, Metric] | None = None,
118
+ additional_metrics: dict[str, Metric] | None = None,
119
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
120
+ train_handlers: Sequence | None = None,
121
+ decollate: bool = True,
122
+ optim_set_to_none: bool = False,
123
+ to_kwargs: dict | None = None,
124
+ amp_kwargs: dict | None = None,
125
+ ):
126
+ if not isinstance(train_data_loader, DataLoader):
127
+ raise ValueError("train_data_loader must be PyTorch DataLoader.")
128
+
129
+ # set up Ignite engine and environments
130
+ super().__init__(
131
+ device=device,
132
+ max_epochs=max_epochs,
133
+ data_loader=train_data_loader,
134
+ epoch_length=epoch_length,
135
+ non_blocking=non_blocking,
136
+ prepare_batch=d_prepare_batch,
137
+ iteration_update=iteration_update,
138
+ key_metric=key_train_metric,
139
+ additional_metrics=additional_metrics,
140
+ metric_cmp_fn=metric_cmp_fn,
141
+ handlers=train_handlers,
142
+ postprocessing=postprocessing,
143
+ decollate=decollate,
144
+ to_kwargs=to_kwargs,
145
+ amp_kwargs=amp_kwargs,
146
+ )
147
+ self.g_network = g_network
148
+ self.g_optimizer = g_optimizer
149
+ self.g_loss_function = g_loss_function
150
+ self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
151
+ self.d_network = d_network
152
+ self.d_optimizer = d_optimizer
153
+ self.d_loss_function = d_loss_function
154
+ self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer
155
+ self.d_train_steps = d_train_steps
156
+ self.latent_shape = latent_shape
157
+ self.g_prepare_batch = g_prepare_batch
158
+ self.g_update_latents = g_update_latents
159
+ self.optim_set_to_none = optim_set_to_none
160
+
161
+ def _iteration(
162
+ self, engine: VaeGanTrainer, batchdata: dict | Sequence
163
+ ) -> dict[str, torch.Tensor | int | float | bool]:
164
+ """
165
+ Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine.
166
+ Args:
167
+ engine: `VaeGanTrainer` to execute operation for an iteration.
168
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
169
+ Raises:
170
+ ValueError: must provide batch data for current iteration.
171
+ """
172
+ if batchdata is None:
173
+ raise ValueError("must provide batch data for current iteration.")
174
+
175
+ d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)[0]
176
+ g_input = d_input
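+ # In this VAE-GAN setup the "generator" is the KL autoencoder: it reconstructs the real batch (g_input == d_input) rather than sampling from random latent codes.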
177
+ g_output, z_mu, z_sigma = engine.g_inferer(g_input, engine.g_network)
178
+
179
+ # Train Discriminator
180
+ d_total_loss = torch.zeros(1)
181
+ for _ in range(engine.d_train_steps):
182
+ engine.d_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
183
+ dloss = engine.d_loss_function(g_output, d_input)
184
+ dloss.backward()
185
+ engine.d_optimizer.step()
186
+ d_total_loss += dloss.item()
187
+
188
+ # Train Generator
189
+ engine.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
190
+ g_loss = engine.g_loss_function(g_output, g_input, z_mu, z_sigma)
191
+ g_loss.backward()
192
+ engine.g_optimizer.step()
193
+
194
+ return {
195
+ GanKeys.REALS: d_input,
196
+ GanKeys.FAKES: g_output,
197
+ GanKeys.LATENTS: g_input,
198
+ GanKeys.GLOSS: g_loss.item(),
199
+ GanKeys.DLOSS: d_total_loss.item(),
200
+ }
201
+
202
+
203
+ class LDMTrainer(SupervisedTrainer):
204
+ """
205
+ Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.
206
+ Args:
207
+ device: an object representing the device on which to run.
208
+ max_epochs: the total epoch number for trainer to run.
209
+ train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
210
+ network: network to train in the trainer, should be regular PyTorch `torch.nn.Module`.
211
+ optimizer: the optimizer associated to the network, should be regular PyTorch optimizer from `torch.optim`
212
+ or its subclass.
213
+ loss_function: the loss function associated to the optimizer, should be regular PyTorch loss,
214
+ which inherit from `torch.nn.modules.loss`.
215
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
216
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
217
+ with respect to the host. For other cases, this argument has no effect.
218
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
219
+ from `engine.state.batch` for every iteration, for more details please refer to:
220
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
221
+ iteration_update: the callable function for every iteration, expect to accept `engine`
222
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
223
+ if not provided, use `self._iteration()` instead. for more details please refer to:
224
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
225
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
226
+ postprocessing: execute additional transformation for the model output data.
227
+ Typically, several Tensor based transforms composed by `Compose`.
228
+ key_train_metric: compute metric when every iteration completed, and save average value to
229
+ engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
230
+ checkpoint into files.
231
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
232
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
233
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
234
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
235
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
236
+ CheckpointHandler, StatsHandler, etc.
237
+ amp: whether to enable auto-mixed-precision training, default is False.
238
+ event_names: additional custom ignite events that will register to the engine.
239
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
240
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
241
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
242
+ #ignite.engine.engine.Engine.register_events.
243
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
244
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
245
+ default to `True`.
246
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
247
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
248
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
249
+ `device`, `non_blocking`.
250
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
251
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ device: str | torch.device,
257
+ max_epochs: int,
258
+ train_data_loader: Iterable | DataLoader,
259
+ network: torch.nn.Module,
260
+ autoencoder_model: torch.nn.Module,
261
+ optimizer: Optimizer,
262
+ loss_function: Callable,
263
+ latent_shape: Sequence,
264
+ inferer: Inferer,
265
+ epoch_length: int | None = None,
266
+ non_blocking: bool = False,
267
+ prepare_batch: Callable = default_prepare_batch,
268
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
269
+ postprocessing: Transform | None = None,
270
+ key_train_metric: dict[str, Metric] | None = None,
271
+ additional_metrics: dict[str, Metric] | None = None,
272
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
273
+ train_handlers: Sequence | None = None,
274
+ amp: bool = False,
275
+ event_names: list[str | EventEnum | type[EventEnum]] | None = None,
276
+ event_to_attr: dict | None = None,
277
+ decollate: bool = True,
278
+ optim_set_to_none: bool = False,
279
+ to_kwargs: dict | None = None,
280
+ amp_kwargs: dict | None = None,
281
+ ) -> None:
282
+ super().__init__(
283
+ device=device,
284
+ max_epochs=max_epochs,
285
+ train_data_loader=train_data_loader,
286
+ network=network,
287
+ optimizer=optimizer,
288
+ loss_function=loss_function,
289
+ inferer=inferer,
290
+ optim_set_to_none=optim_set_to_none,
291
+ epoch_length=epoch_length,
292
+ non_blocking=non_blocking,
293
+ prepare_batch=prepare_batch,
294
+ iteration_update=iteration_update,
295
+ postprocessing=postprocessing,
296
+ key_train_metric=key_train_metric,
297
+ additional_metrics=additional_metrics,
298
+ metric_cmp_fn=metric_cmp_fn,
299
+ train_handlers=train_handlers,
300
+ amp=amp,
301
+ event_names=event_names,
302
+ event_to_attr=event_to_attr,
303
+ decollate=decollate,
304
+ to_kwargs=to_kwargs,
305
+ amp_kwargs=amp_kwargs,
306
+ )
307
+
308
+ self.latent_shape = latent_shape
309
+ self.autoencoder_model = autoencoder_model
310
+
311
+ def _iteration(self, engine: LDMTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
312
+ """
313
+ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
314
+ Return below items in a dictionary:
315
+ - IMAGE: image Tensor data for model input, already moved to device.
316
+ - LABEL: label Tensor data corresponding to the image, already moved to device.
317
+ - PRED: prediction result of model.
318
+ - LOSS: loss value computed by loss function.
319
+ Args:
320
+ engine: `SupervisedTrainer` to execute operation for an iteration.
321
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
322
+ Raises:
323
+ ValueError: When ``batchdata`` is None.
324
+ """
325
+ if batchdata is None:
326
+ raise ValueError("Must provide batch data for current iteration.")
327
+ batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
328
+ if len(batch) == 2:
329
+ images, labels = batch
330
+ args: tuple = ()
331
+ kwargs: dict = {}
332
+ else:
333
+ images, labels, args, kwargs = batch
334
+ # put iteration outputs into engine.state
335
+ engine.state.output = {CommonKeys.IMAGE: images}
336
+
337
+ # generate noise
338
+ noise_shape = [images.shape[0]] + list(self.latent_shape)
339
+ noise = torch.randn(noise_shape, dtype=images.dtype).to(images.device)
340
+ engine.state.output = {"noise": noise}
341
+
342
+ # Create timesteps
343
+ timesteps = torch.randint(
344
+ 0, engine.inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
345
+ ).long()
346
+
347
+ def _compute_pred_loss():
348
+ # predicted noise
349
+ engine.state.output[CommonKeys.PRED] = engine.inferer(
350
+ inputs=images,
351
+ autoencoder_model=self.autoencoder_model,
352
+ diffusion_model=engine.network,
353
+ noise=noise,
354
+ timesteps=timesteps,
355
+ )
356
+ engine.fire_event(IterationEvents.FORWARD_COMPLETED)
357
+ # compute loss
358
+ engine.state.output[CommonKeys.LOSS] = engine.loss_function(
359
+ engine.state.output[CommonKeys.PRED], noise
360
+ ).mean()
361
+ engine.fire_event(IterationEvents.LOSS_COMPLETED)
362
+
363
+ engine.network.train()
364
+ engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
365
+
366
+ if engine.amp and engine.scaler is not None:
367
+ with torch.cuda.amp.autocast(**engine.amp_kwargs):
368
+ _compute_pred_loss()
369
+ engine.scaler.scale(engine.state.output[CommonKeys.LOSS]).backward()
370
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
371
+ engine.scaler.step(engine.optimizer)
372
+ engine.scaler.update()
373
+ else:
374
+ _compute_pred_loss()
375
+ engine.state.output[CommonKeys.LOSS].backward()
376
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
377
+ engine.optimizer.step()
378
+ engine.fire_event(IterationEvents.MODEL_COMPLETED)
379
+
380
+ return engine.state.output
scripts/losses.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
+
11
+ import torch
12
+ from generative.losses import PatchAdversarialLoss
13
+
14
+ intensity_loss = torch.nn.L1Loss()
15
+ adv_loss = PatchAdversarialLoss(criterion="least_squares")
16
+
17
+ adv_weight = 0.1
18
+ perceptual_weight = 0.1
19
+ # kl_weight: important hyper-parameter.
20
+ # If too large, decoder cannot recon good results from latent space.
21
+ # If too small, latent space will not be regularized enough for the diffusion model
22
+ kl_weight = 1e-7
23
+
24
+
25
+ def compute_kl_loss(z_mu, z_sigma):
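+ # KL divergence between the encoder posterior N(z_mu, z_sigma^2) and a standard normal prior, summed over the 3D latent dimensions and averaged over the batch.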
26
+ kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
27
+ return torch.sum(kl_loss) / kl_loss.shape[0]
28
+
29
+
30
+ def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual):
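+ # Total generator loss: L1 reconstruction + kl_weight * KL + perceptual_weight * perceptual + adv_weight * least-squares adversarial terms.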
31
+ recons_loss = intensity_loss(gen_images, real_images)
32
+ kl_loss = compute_kl_loss(z_mu, z_sigma)
33
+ p_loss = loss_perceptual(gen_images.float(), real_images.float())
34
+ loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss
35
+
36
+ logits_fake = disc_net(gen_images)[-1]
37
+ generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
38
+ loss_g = loss_g + adv_weight * generator_loss
39
+
40
+ return loss_g
41
+
42
+
43
+ def discriminator_loss(gen_images, real_images, disc_net):
44
+ logits_fake = disc_net(gen_images.contiguous().detach())[-1]
45
+ loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
46
+ logits_real = disc_net(real_images.contiguous().detach())[-1]
47
+ loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
48
+ discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
49
+ loss_d = adv_weight * discriminator_loss
50
+ return loss_d
scripts/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
+
11
+ import monai
12
+ import torch
13
+
14
+
15
+ def compute_scale_factor(autoencoder, train_loader, device):
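+ # Encode one training batch and return 1 / std of the latents, so the diffusion model operates on a roughly unit-variance latent space (Rombach et al. 2022).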
16
+ with torch.no_grad():
17
+ check_data = monai.utils.first(train_loader)
18
+ z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device))
19
+ scale_factor = 1 / torch.std(z)
20
+ return scale_factor.item()