xf3227 committed
Commit c4c51f0 · 1 Parent(s): 9ee1544
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. adrd/__init__.py +0 -22
  2. adrd/__pycache__/__init__.cpython-311.pyc +0 -0
  3. adrd/_ds/__init__.py +0 -0
  4. adrd/_ds/lddl.py +0 -71
  5. adrd/model/__init__.py +0 -6
  6. adrd/model/__pycache__/__init__.cpython-311.pyc +0 -0
  7. adrd/model/__pycache__/adrd_model.cpython-311.pyc +0 -0
  8. adrd/model/__pycache__/calibration.cpython-311.pyc +0 -0
  9. adrd/model/__pycache__/imaging_model.cpython-311.pyc +0 -0
  10. adrd/model/__pycache__/train_resnet.cpython-311.pyc +0 -0
  11. adrd/model/adrd_model.py +0 -976
  12. adrd/model/calibration.py +0 -450
  13. adrd/model/cnn_resnet3d_with_linear_classifier.py +0 -533
  14. adrd/model/imaging_model.py +0 -843
  15. adrd/model/train_resnet.py +0 -484
  16. adrd/model/transformer.py +0 -600
  17. adrd/nn/__init__.py +0 -12
  18. adrd/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  19. adrd/nn/__pycache__/blocks.cpython-311.pyc +0 -0
  20. adrd/nn/__pycache__/c3d.cpython-311.pyc +0 -0
  21. adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc +0 -0
  22. adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc +0 -0
  23. adrd/nn/__pycache__/dense_net.cpython-311.pyc +0 -0
  24. adrd/nn/__pycache__/focal_loss.cpython-311.pyc +0 -0
  25. adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc +0 -0
  26. adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc +0 -0
  27. adrd/nn/__pycache__/resnet3d.cpython-311.pyc +0 -0
  28. adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc +0 -0
  29. adrd/nn/__pycache__/selfattention.cpython-311.pyc +0 -0
  30. adrd/nn/__pycache__/transformer.cpython-311.pyc +0 -0
  31. adrd/nn/__pycache__/unet.cpython-311.pyc +0 -0
  32. adrd/nn/__pycache__/unet_3d.cpython-311.pyc +0 -0
  33. adrd/nn/__pycache__/unet_img_model.cpython-311.pyc +0 -0
  34. adrd/nn/__pycache__/vitautoenc.cpython-311.pyc +0 -0
  35. adrd/nn/blocks.py +0 -57
  36. adrd/nn/c3d.py +0 -99
  37. adrd/nn/cnn_resnet3d.py +0 -81
  38. adrd/nn/cnn_resnet3d_with_linear_classifier.py +0 -56
  39. adrd/nn/dense_net.py +0 -211
  40. adrd/nn/focal_loss.py +0 -120
  41. adrd/nn/img_model_wrapper.py +0 -174
  42. adrd/nn/net_resnet3d.py +0 -338
  43. adrd/nn/resnet3d.py +0 -256
  44. adrd/nn/resnet_img_model.py +0 -81
  45. adrd/nn/selfattention.py +0 -62
  46. adrd/nn/transformer.py +0 -268
  47. adrd/nn/unet.py +0 -232
  48. adrd/nn/unet_3d.py +0 -63
  49. adrd/nn/unet_img_model.py +0 -211
  50. adrd/nn/vitautoenc.py +0 -163
adrd/__init__.py DELETED
@@ -1,22 +0,0 @@
- __version__ = '0.0.1'
-
- from . import nn
- from . import model
-
- # # load pretrained transformer
- # pretrained_transformer = model.Transformer.from_ckpt('{}/ckpt/ckpt.pt'.format(__path__[0]))
- # from . import shap_adrd
- # from .model import DynamicCalibratedClassifier
- # from .model import StaticCalibratedClassifier
-
- # load fitted transformer and calibrated wrapper
- # try:
- #     fitted_resnet3d = model.CNNResNet3DWithLinearClassifier.from_ckpt('{}/ckpt/ckpt_img_072523.pt'.format(__path__[0]))
- #     fitted_calibrated_classifier_nonimg = StaticCalibratedClassifier.from_ckpt(
- #         filepath_state_dict = '{}/ckpt/static_calibrated_classifier_073023.pkl'.format(__path__[0]),
- #         filepath_wrapped_model = '{}/ckpt/ckpt_080823.pt'.format(__path__[0]),
- #     )
- #     fitted_transformer_nonimg = fitted_calibrated_classifier_nonimg.model
- #     shap_explainer = shap_adrd.SamplingExplainer(fitted_transformer_nonimg)
- # except:
- #     print('Failed to load checkpoints.')
adrd/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (264 Bytes)
 
adrd/_ds/__init__.py DELETED
File without changes
adrd/_ds/lddl.py DELETED
@@ -1,71 +0,0 @@
- from typing import Any, Self, overload
-
-
- class lddl:
-     ''' ... '''
-     def __init__(self) -> None:
-         ''' ... '''
-         self.dat_ld: list[dict[str, Any]] = None
-         self.dat_dl: dict[str, list[Any]] = None
-
-     @overload
-     def __getitem__(self, idx: int) -> dict[str, Any]: ...
-
-     @overload
-     def __getitem__(self, idx: str) -> list[Any]: ...
-
-     def __getitem__(self, idx: str | int) -> list[Any] | dict[str, Any]:
-         ''' ... '''
-         if isinstance(idx, str):
-             return self.dat_dl[idx]
-         elif isinstance(idx, int):
-             return self.dat_ld[idx]
-         else:
-             raise TypeError('Unexpected key type: {}'.format(type(idx)))
-
-     @classmethod
-     def from_ld(cls, dat: list[dict[str, Any]]) -> Self:
-         ''' Construct from a list of dicts. '''
-         obj = cls()
-         obj.dat_ld = dat
-         obj.dat_dl = {k: [dat[i][k] for i in range(len(dat))] for k in dat[0]}
-         return obj
-
-     @classmethod
-     def from_dl(cls, dat: dict[str, list[Any]]) -> Self:
-         ''' Construct from a dict of lists. '''
-         obj = cls()
-         obj.dat_ld = [dict(zip(dat, v)) for v in zip(*dat.values())]
-         obj.dat_dl = dat
-         return obj
-
-
- if __name__ == '__main__':
-     ''' for testing purposes only '''
-     dl = {
-         'a': [0, 1, 2],
-         'b': [3, 4, 5],
-     }
-
-     ld = [
-         {'a': 0, 'b': 1, 'c': 2},
-         {'a': 3, 'b': 4, 'c': 5},
-     ]
-
-     # test constructing from ld
-     dat_ld = lddl.from_ld(ld)
-     print(dat_ld.dat_ld)
-     print(dat_ld.dat_dl)
-
-     # test constructing from dl
-     dat_dl = lddl.from_dl(dl)
-     print(dat_dl.dat_ld)
-     print(dat_dl.dat_dl)
-
-     # test __getitem__
-     print(dat_dl['a'])
-     print(dat_dl[0])
-
-     # mouse hover to check if type hints are correct
-     v = dat_dl['a']
-     v = dat_dl[0]
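Editor's note on the removed lddl helper above: from_ld derives the column set from the first record (dat[0]), and from_dl relies on zip(), which truncates to the shortest column. A minimal round-trip sketch, assuming the class as defined above:

# round-trip between list-of-dicts and dict-of-lists using the lddl class above
ld = [{'a': 0, 'b': 1}, {'a': 2, 'b': 3}]
dat = lddl.from_ld(ld)
assert dat.dat_dl == {'a': [0, 2], 'b': [1, 3]}
assert lddl.from_dl(dat.dat_dl).dat_ld == ld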
adrd/model/__init__.py DELETED
@@ -1,6 +0,0 @@
- from .adrd_model import ADRDModel
- from .imaging_model import ImagingModel
- from .train_resnet import TrainResNet
- # from .transformer import Transformer
- from .calibration import DynamicCalibratedClassifier
- from .calibration import StaticCalibratedClassifier
adrd/model/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (488 Bytes)
 
adrd/model/__pycache__/adrd_model.cpython-311.pyc DELETED
Binary file (56.5 kB)
 
adrd/model/__pycache__/calibration.cpython-311.pyc DELETED
Binary file (27.8 kB)
 
adrd/model/__pycache__/imaging_model.cpython-311.pyc DELETED
Binary file (46.4 kB)
 
adrd/model/__pycache__/train_resnet.cpython-311.pyc DELETED
Binary file (26.1 kB)
 
adrd/model/adrd_model.py DELETED
@@ -1,976 +0,0 @@
- __all__ = ['ADRDModel']
-
- import wandb
- import torch
- from torch.utils.data import DataLoader
- import numpy as np
- import tqdm
- import multiprocessing
- from sklearn.base import BaseEstimator
- from sklearn.utils.validation import check_is_fitted
- from sklearn.model_selection import train_test_split
- from scipy.special import expit
- from copy import deepcopy
- from contextlib import suppress
- from typing import Any, Self, Type
- from functools import wraps
- Tensor = Type[torch.Tensor]
- Module = Type[torch.nn.Module]
-
- # for DistributedDataParallel
- import torch.distributed as dist
- import torch.multiprocessing as mp
- from torch.nn.parallel import DistributedDataParallel as DDP
-
- from .. import nn
- from ..nn import Transformer
- from ..utils import TransformerTrainingDataset, TransformerBalancedTrainingDataset, TransformerValidationDataset, TransformerTestingDataset, Transformer2ndOrderBalancedTrainingDataset
- from ..utils.misc import ProgressBar
- from ..utils.misc import get_metrics_multitask, print_metrics_multitask
- from ..utils.misc import convert_args_kwargs_to_kwargs
-
-
- def _manage_ctx_fit(func):
-     ''' ... '''
-     @wraps(func)
-     def wrapper(*args, **kwargs):
-         # format arguments
-         kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
-
-         if kwargs['self']._device_ids is None:
-             return func(**kwargs)
-         else:
-             # change primary device
-             default_device = kwargs['self'].device
-             kwargs['self'].device = kwargs['self']._device_ids[0]
-             rtn = func(**kwargs)
-             kwargs['self'].to(default_device)
-             return rtn
-     return wrapper
-
-
- class ADRDModel(BaseEstimator):
-     """Primary model class of the ADRD framework.
-
-     The ADRDModel encapsulates the core pipeline of the ADRD framework,
-     permitting users to train and validate with the provided data. Designed
-     for user-friendly operation, the ADRDModel is derived from
-     ``sklearn.base.BaseEstimator``, ensuring compliance with the
-     well-established API design conventions of scikit-learn.
-     """
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]],
-         label_fractions: dict[str, float],
-         d_model: int = 32,
-         nhead: int = 1,
-         num_encoder_layers: int = 1,
-         num_decoder_layers: int = 1,
-         num_epochs: int = 32,
-         batch_size: int = 8,
-         batch_size_multiplier: int = 1,
-         lr: float = 1e-2,
-         weight_decay: float = 0.0,
-         beta: float = 0.9999,
-         gamma: float = 2.0,
-         criterion: str | None = None,
-         device: str = 'cpu',
-         cuda_devices: list = [1],
-         img_net: str | None = None,
-         imgnet_layers: int | None = 2,
-         img_size: int | None = 128,
-         fusion_stage: str = 'middle',
-         patch_size: int | None = 16,
-         imgnet_ckpt: str | None = None,
-         train_imgnet: bool = False,
-         ckpt_path: str = './adrd_tool/adrd/dev/ckpt/ckpt.pt',
-         load_from_ckpt: bool = False,
-         save_intermediate_ckpts: bool = False,
-         data_parallel: bool = False,
-         verbose: int = 0,
-         wandb_ = 0,
-         balanced_sampling: bool = False,
-         label_distribution: dict = {},
-         ranking_loss: bool = False,
-         _device_ids: list | None = None,
-         _dataloader_num_workers: int = 4,
-         _amp_enabled: bool = False,
-     ) -> None:
-         """Create a new ADRD model.
-
-         :param src_modalities: _description_
-         :type src_modalities: dict[str, dict[str, Any]]
-         :param tgt_modalities: _description_
-         :type tgt_modalities: dict[str, dict[str, Any]]
-         :param label_fractions: _description_
-         :type label_fractions: dict[str, float]
-         :param d_model: _description_, defaults to 32
-         :type d_model: int, optional
-         :param nhead: number of transformer heads, defaults to 1
-         :type nhead: int, optional
-         :param num_encoder_layers: number of encoder layers, defaults to 1
-         :type num_encoder_layers: int, optional
-         :param num_decoder_layers: number of decoder layers, defaults to 1
-         :type num_decoder_layers: int, optional
-         :param num_epochs: number of training epochs, defaults to 32
-         :type num_epochs: int, optional
-         :param batch_size: batch size, defaults to 8
-         :type batch_size: int, optional
-         :param batch_size_multiplier: _description_, defaults to 1
-         :type batch_size_multiplier: int, optional
-         :param lr: learning rate, defaults to 1e-2
-         :type lr: float, optional
-         :param weight_decay: _description_, defaults to 0.0
-         :type weight_decay: float, optional
-         :param beta: _description_, defaults to 0.9999
-         :type beta: float, optional
-         :param gamma: The focusing parameter of the focal loss. Higher values of gamma make easy-to-classify examples contribute less to the loss relative to hard-to-classify examples. Must be non-negative. Defaults to 2.0
-         :type gamma: float, optional
-         :param criterion: The criterion used to select the best model, defaults to None
-         :type criterion: str | None, optional
-         :param device: 'cuda' or 'cpu', defaults to 'cpu'
-         :type device: str, optional
-         :param cuda_devices: A list of GPU indices for data-parallel training. The device must be set to 'cuda' and data_parallel must be set to True. Defaults to [1]
-         :type cuda_devices: list, optional
-         :param img_net: _description_, defaults to None
-         :type img_net: str | None, optional
-         :param imgnet_layers: _description_, defaults to 2
-         :type imgnet_layers: int | None, optional
-         :param img_size: _description_, defaults to 128
-         :type img_size: int | None, optional
-         :param fusion_stage: _description_, defaults to 'middle'
-         :type fusion_stage: str, optional
-         :param patch_size: _description_, defaults to 16
-         :type patch_size: int | None, optional
-         :param imgnet_ckpt: _description_, defaults to None
-         :type imgnet_ckpt: str | None, optional
-         :param train_imgnet: Set to True to finetune the img_net backbone, defaults to False
-         :type train_imgnet: bool, optional
-         :param ckpt_path: The model checkpoint path, defaults to './adrd_tool/adrd/dev/ckpt/ckpt.pt'
-         :type ckpt_path: str, optional
-         :param load_from_ckpt: Set to True to load the model weights from the checkpoint at ckpt_path, defaults to False
-         :type load_from_ckpt: bool, optional
-         :param save_intermediate_ckpts: Set to True to save intermediate model checkpoints, defaults to False
-         :type save_intermediate_ckpts: bool, optional
-         :param data_parallel: Set to True to enable data-parallel training, defaults to False
-         :type data_parallel: bool, optional
-         :param verbose: _description_, defaults to 0
-         :type verbose: int, optional
-         :param wandb_: Set to 1 to track the loss and accuracy curves on wandb, defaults to 0
-         :type wandb_: int, optional
-         :param balanced_sampling: _description_, defaults to False
-         :type balanced_sampling: bool, optional
-         :param label_distribution: _description_, defaults to {}
-         :type label_distribution: dict, optional
-         :param ranking_loss: _description_, defaults to False
-         :type ranking_loss: bool, optional
-         :param _device_ids: _description_, defaults to None
-         :type _device_ids: list | None, optional
-         :param _dataloader_num_workers: _description_, defaults to 4
-         :type _dataloader_num_workers: int, optional
-         :param _amp_enabled: _description_, defaults to False
-         :type _amp_enabled: bool, optional
-         """
-         # for multiprocessing
-         self._rank = 0
-         self._lock = None
-
-         # positional parameters
-         self.src_modalities = src_modalities
-         self.tgt_modalities = tgt_modalities
-
-         # training parameters
-         self.label_fractions = label_fractions
-         self.d_model = d_model
-         self.nhead = nhead
-         self.num_encoder_layers = num_encoder_layers
-         self.num_decoder_layers = num_decoder_layers
-         self.num_epochs = num_epochs
-         self.batch_size = batch_size
-         self.batch_size_multiplier = batch_size_multiplier
-         self.lr = lr
-         self.weight_decay = weight_decay
-         self.beta = beta
-         self.gamma = gamma
-         self.criterion = criterion
-         self.device = device
-         self.cuda_devices = cuda_devices
-         self.img_net = img_net
-         self.patch_size = patch_size
-         self.img_size = img_size
-         self.fusion_stage = fusion_stage
-         self.imgnet_ckpt = imgnet_ckpt
-         self.imgnet_layers = imgnet_layers
-         self.train_imgnet = train_imgnet
-         self.ckpt_path = ckpt_path
-         self.load_from_ckpt = load_from_ckpt
-         self.save_intermediate_ckpts = save_intermediate_ckpts
-         self.data_parallel = data_parallel
-         self.verbose = verbose
-         self.label_distribution = label_distribution
-         self.wandb_ = wandb_
-         self.balanced_sampling = balanced_sampling
-         self.ranking_loss = ranking_loss
-         self._device_ids = _device_ids
-         self._dataloader_num_workers = _dataloader_num_workers
-         self._amp_enabled = _amp_enabled
-         self.scaler = torch.cuda.amp.GradScaler()
-         # self._init_net()
-
-     @_manage_ctx_fit
-     def fit(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None, img_mode=0) -> Self:
-         # def fit(self, x, y) -> Self:
-         ''' ... '''
-         # start a new wandb run to track this script
-         if self.wandb_ == 1:
-             wandb.init(
-                 # set the wandb project where this run will be logged
-                 project="ADRD_main",
-
-                 # track hyperparameters and run metadata
-                 config={
-                     "Loss": 'Focalloss',
-                     "ranking_loss": self.ranking_loss,
-                     "img architecture": self.img_net,
-                     "EMB": "ALL_SEQ",
-                     "epochs": self.num_epochs,
-                     "d_model": self.d_model,
-                     # 'positional encoding': 'Diff PE',
-                     'Balanced Sampling': self.balanced_sampling,
-                     'Shared CNN': 'Yes',
-                 }
-             )
-             wandb.run.log_code(".")
-         else:
-             wandb.init(mode="disabled")
-
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-         # print(img_train_trans)
-         # initialize neural network
-         print(self.criterion)
-         print(f"Ranking loss: {self.ranking_loss}")
-         print(f"Batch size: {self.batch_size}")
-         print(f"Batch size multiplier: {self.batch_size_multiplier}")
-
-         if img_mode in [0, 1, 2]:
-             for k, info in self.src_modalities.items():
-                 if info['type'] == 'imaging':
-                     if 'densenet' in self.img_net.lower() and 'emb' not in self.img_net.lower():
-                         info['shape'] = (1,) + self.img_size
-                         info['img_shape'] = (1,) + self.img_size
-                     elif 'emb' not in self.img_net.lower():
-                         info['shape'] = (1,) + (self.img_size,) * 3
-                         info['img_shape'] = (1,) + (self.img_size,) * 3
-                     elif 'swinunetr' in self.img_net.lower():
-                         info['shape'] = (1, 768, 4, 4, 4)
-                         info['img_shape'] = (1, 768, 4, 4, 4)
-
-         self._init_net()
-         ldr_trn, ldr_vld = self._init_dataloader(x_trn, x_vld, y_trn, y_vld, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
-
-         # initialize optimizer and scheduler
-         if not self.load_from_ckpt:
-             self.optimizer = self._init_optimizer()
-         self.scheduler = self._init_scheduler(self.optimizer)
-
-         # gradient scaler for AMP
-         if self._amp_enabled:
-             self.scaler = torch.cuda.amp.GradScaler()
-
-         # initialize the focal losses
-         self.loss_fn = {}
-
-         for k in self.tgt_modalities:
-             if self.label_fractions[k] >= 0.3:
-                 alpha = -1
-             else:
-                 alpha = pow((1 - self.label_fractions[k]), 2)
-                 # alpha = -1
-             self.loss_fn[k] = nn.SigmoidFocalLoss(
-                 alpha = alpha,
-                 gamma = self.gamma,
-                 reduction = 'none'
-             )
-
-         # to record the best validation performance criterion
-         if self.criterion is not None:
-             best_crit = None
-             best_crit_AUPR = None
-
-         # progress bar for epoch loops
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch = tqdm.tqdm(
-                     desc = 'Rank {:02d}'.format(self._rank),
-                     total = self.num_epochs,
-                     position = self._rank,
-                     ascii = True,
-                     leave = False,
-                     bar_format = '{l_bar}{r_bar}'
-                 )
-
-         self.skip_embedding = {}
-         for k, info in self.src_modalities.items():
-             # if info['type'] == 'imaging':
-             #     if not self.img_net:
-             #         self.skip_embedding[k] = True
-             #     else:
-             self.skip_embedding[k] = False
-
-         self.grad_list = []
-         # define a hook function to print and store the gradient of a layer
-         def print_and_store_grad(grad):
-             self.grad_list.append(grad)
-             # print(grad)
-
-         # initialize the ranking loss
-         self.lambda_coeff = 0.005
-         self.margin = 0.25
-         self.margin_loss = torch.nn.MarginRankingLoss(reduction='sum', margin=self.margin)
-
-         # training loop
-         for epoch in range(self.start_epoch, self.num_epochs):
-             met_trn = self.train_one_epoch(ldr_trn, epoch)
-             met_vld = self.validate_one_epoch(ldr_vld, epoch)
-
-             print(self.ckpt_path.split('/')[-1])
-
-             # save the model if it has the best validation performance criterion so far
-             if self.criterion is None: continue
-
-             # is the current criterion better than the previous best?
-             curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
-             curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
-             # AUROC
-             if best_crit is None or np.isnan(best_crit):
-                 is_better = True
-             elif self.criterion == 'Loss' and best_crit >= curr_crit:
-                 is_better = True
-             elif self.criterion != 'Loss' and best_crit <= curr_crit:
-                 is_better = True
-             else:
-                 is_better = False
-
-             # AUPR
-             if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
-                 is_better_AUPR = True
-             elif best_crit_AUPR <= curr_crit_AUPR:
-                 is_better_AUPR = True
-             else:
-                 is_better_AUPR = False
-
-             # update best criterion
-             if is_better_AUPR:
-                 best_crit_AUPR = curr_crit_AUPR
-                 if self.save_intermediate_ckpts:
-                     print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
-                     self.save(self.ckpt_path[:-3] + "_AUPR.pt", epoch)
-             if is_better:
-                 best_crit = curr_crit
-                 best_state_dict = deepcopy(self.net_.state_dict())
-                 if self.save_intermediate_ckpts:
-                     print(f"Saving the model to {self.ckpt_path}...")
-                     self.save(self.ckpt_path, epoch)
-
-             if self.verbose > 2:
-                 print('Best {}: {}'.format(self.criterion, best_crit))
-                 print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
-
-             if self.verbose == 1:
-                 with self._lock if self._lock is not None else suppress():
-                     pbr_epoch.update(1)
-                     pbr_epoch.refresh()
-
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch.close()
-
-         return self
-
-     def train_one_epoch(self, ldr_trn, epoch):
-         # progress bar for batch loops
-         if self.verbose > 1:
-             pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
-
-         # set model to train mode
-         torch.set_grad_enabled(True)
-         self.net_.train()
-
-         scores_trn, y_true_trn, y_mask_trn = [], [], []
-         losses_trn = [[] for _ in self.tgt_modalities]
-         iters = len(ldr_trn)
-         for n_iter, (x_batch, y_batch, mask, y_mask) in enumerate(ldr_trn):
-             # mount data to the proper device
-             x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
-             y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
-             mask = {k: mask[k].to(self.device) for k in mask}
-             y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
-
-             with torch.autocast(
-                 device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                 dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                 enabled = self._amp_enabled,
-             ):
-                 outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
-
-                 # calculate multitask loss
-                 loss = 0
-
-                 # for the initial 10 epochs, only the focal loss is used for stable training
-                 if self.ranking_loss:
-                     if epoch < 10:
-                         loss = 0
-                     else:
-                         for i, k in enumerate(self.tgt_modalities):
-                             for ii, kk in enumerate(self.tgt_modalities):
-                                 if ii > i:
-                                     pairs = (y_mask[k] == 1) & (y_mask[kk] == 1)
-                                     total_elements = (torch.abs(y_batch[k][pairs] - y_batch[kk][pairs])).sum()
-                                     if total_elements != 0:
-                                         loss += self.lambda_coeff * (self.margin_loss(torch.sigmoid(outputs[k])[pairs], torch.sigmoid(outputs[kk][pairs]), y_batch[k][pairs] - y_batch[kk][pairs])) / total_elements
-
-                 for i, k in enumerate(self.tgt_modalities):
-                     loss_task = self.loss_fn[k](outputs[k], y_batch[k])
-                     msk_loss_task = loss_task * y_mask[k]
-                     msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
-                     # msk_loss_mean = msk_loss_task.sum()
-                     loss += msk_loss_mean
-                     losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
-
-             # backward
-             loss = loss / self.batch_size_multiplier
-             if self._amp_enabled:
-                 self.scaler.scale(loss).backward()
-             else:
-                 loss.backward()
-
-             if len(self.grad_list) > 0:
-                 print(len(self.grad_list), len(self.grad_list[-1]))
-                 print(f"Gradient at {n_iter}: {self.grad_list[-1][0]}")
-
-             # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.img_model.features[0].weight)
-             # print("img_MRI_T1_1 ", self.net_.modules_emb_src.img_MRI_T1_1.downsample[0].weight)
-
-             # update parameters
-             if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
-                 if self._amp_enabled:
-                     self.scaler.step(self.optimizer)
-                     self.scaler.update()
-                     self.optimizer.zero_grad()
-                 else:
-                     self.optimizer.step()
-                     self.optimizer.zero_grad()
-
-             # step the scheduler
-             self.scheduler.step(epoch + n_iter / iters)
-
-             ''' TODO: change array to dictionary later '''
-             outputs = torch.stack(list(outputs.values()), dim=1)
-             y_batch = torch.stack(list(y_batch.values()), dim=1)
-             y_mask = torch.stack(list(y_mask.values()), dim=1)
-
-             # save outputs to evaluate performance later
-             scores_trn.append(outputs.detach().to(torch.float).cpu())
-             y_true_trn.append(y_batch.cpu())
-             y_mask_trn.append(y_mask.cpu())
-
-             # update progress bar
-             if self.verbose > 1:
-                 batch_size = len(next(iter(x_batch.values())))
-                 pbr_batch.update(batch_size, {})
-                 pbr_batch.refresh()
-
-             # clear cuda cache
-             if "cuda" in self.device:
-                 torch.cuda.empty_cache()
-
-         # for better tqdm progress bar display
-         if self.verbose > 1:
-             pbr_batch.close()
-
-         # calculate and print training performance metrics
-         scores_trn = torch.cat(scores_trn)
-         y_true_trn = torch.cat(y_true_trn)
-         y_mask_trn = torch.cat(y_mask_trn)
-         y_pred_trn = (scores_trn > 0).to(torch.int)
-         y_prob_trn = torch.sigmoid(scores_trn)
-         met_trn = get_metrics_multitask(
-             y_true_trn.numpy(),
-             y_pred_trn.numpy(),
-             y_prob_trn.numpy(),
-             y_mask_trn.numpy()
-         )
-
-         # add loss to metrics
-         for i in range(len(self.tgt_modalities)):
-             met_trn[i]['Loss'] = np.mean(losses_trn[i])
-
-         # log metrics to wandb
-         wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-
-         if self.verbose > 2:
-             print_metrics_multitask(met_trn)
-
-         return met_trn
-
-     def validate_one_epoch(self, ldr_vld, epoch):
-         # progress bar for validation
-         if self.verbose > 1:
-             pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
-
-         # set model to validation mode
-         torch.set_grad_enabled(False)
-         self.net_.eval()
-
-         scores_vld, y_true_vld, y_mask_vld = [], [], []
-         losses_vld = [[] for _ in self.tgt_modalities]
-         for x_batch, y_batch, mask, y_mask in ldr_vld:
-             # if len(next(iter(x_batch.values()))) < self.batch_size:
-             #     break
-             # mount data to the proper device
-             x_batch = {k: x_batch[k].to(self.device) for k in x_batch}  # if 'img' not in k}
-             # x_img_batch = {k: x_img_batch[k].to(self.device) for k in x_img_batch}
-             y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
-             mask = {k: mask[k].to(self.device) for k in mask}
-             y_mask = {k: y_mask[k].to(self.device) for k in y_mask}
-
-             # forward
-             with torch.autocast(
-                 device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                 dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                 enabled = self._amp_enabled
-             ):
-                 outputs = self.net_(x_batch, mask, skip_embedding=self.skip_embedding)
-
-                 # calculate multitask loss
-                 for i, k in enumerate(self.tgt_modalities):
-                     loss_task = self.loss_fn[k](outputs[k], y_batch[k])
-                     msk_loss_task = loss_task * y_mask[k]
-                     losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
-
-             ''' TODO: change array to dictionary later '''
-             outputs = torch.stack(list(outputs.values()), dim=1)
-             y_batch = torch.stack(list(y_batch.values()), dim=1)
-             y_mask = torch.stack(list(y_mask.values()), dim=1)
-
-             # save outputs to evaluate performance later
-             scores_vld.append(outputs.detach().to(torch.float).cpu())
-             y_true_vld.append(y_batch.cpu())
-             y_mask_vld.append(y_mask.cpu())
-
-             # update progress bar
-             if self.verbose > 1:
-                 batch_size = len(next(iter(x_batch.values())))
-                 pbr_batch.update(batch_size, {})
-                 pbr_batch.refresh()
-
-             # clear cuda cache
-             if "cuda" in self.device:
-                 torch.cuda.empty_cache()
-
-         # for better tqdm progress bar display
-         if self.verbose > 1:
-             pbr_batch.close()
-
-         # calculate and print validation performance metrics
-         scores_vld = torch.cat(scores_vld)
-         y_true_vld = torch.cat(y_true_vld)
-         y_mask_vld = torch.cat(y_mask_vld)
-         y_pred_vld = (scores_vld > 0).to(torch.int)
-         y_prob_vld = torch.sigmoid(scores_vld)
-         met_vld = get_metrics_multitask(
-             y_true_vld.numpy(),
-             y_pred_vld.numpy(),
-             y_prob_vld.numpy(),
-             y_mask_vld.numpy()
-         )
-
-         # add loss to metrics
-         for i in range(len(self.tgt_modalities)):
-             met_vld[i]['Loss'] = np.mean(losses_vld[i])
-
-         wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-         wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-
-         if self.verbose > 2:
-             print_metrics_multitask(met_vld)
-
-         return met_vld
-
-     def predict_logits(self,
-         x: list[dict[str, Any]],
-         _batch_size: int | None = None,
-         skip_embedding: dict | None = None,
-         img_transform: Any | None = None,
-     ) -> list[dict[str, float]]:
-         '''
-         The input x can be a single sample or a list of samples.
-         '''
-         # input validation
-         check_is_fitted(self)
-         print(self.device)
-
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # set model to eval mode
-         torch.set_grad_enabled(False)
-         self.net_.eval()
-
-         # initialize dataset and dataloader objects
-         dat = TransformerTestingDataset(x, self.src_modalities, img_transform=img_transform)
-         ldr = DataLoader(
-             dataset = dat,
-             batch_size = _batch_size if _batch_size is not None else len(x),
-             shuffle = False,
-             drop_last = False,
-             num_workers = 0,
-             collate_fn = TransformerTestingDataset.collate_fn,
-         )
-         # print("dataloader done")
-
-         # run model and collect results
-         logits: list[dict[str, float]] = []
-         for x_batch, mask in tqdm.tqdm(ldr):
-             # mount data to the proper device
-             # print(x_batch['his_SEX'])
-             x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
-             mask = {k: mask[k].to(self.device) for k in mask}
-
-             # forward
-             output: dict[str, Tensor] = self.net_(x_batch, mask, skip_embedding)
-
-             # convert output from dict-of-lists to list-of-dicts, then append
-             tmp = {k: output[k].tolist() for k in self.tgt_modalities}
-             tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
-             logits += tmp
-
-         return logits
-
-     def predict_proba(self,
-         x: list[dict[str, Any]],
-         skip_embedding: dict | None = None,
-         temperature: float = 1.0,
-         _batch_size: int | None = None,
-         img_transform: Any | None = None,
-     ) -> list[dict[str, float]]:
-         ''' ... '''
-         logits = self.predict_logits(x=x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
-         print("got logits")
-         return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
-
-     def predict(self,
-         x: list[dict[str, Any]],
-         skip_embedding: dict | None = None,
-         fpr: dict[str, Any] | None = None,
-         tpr: dict[str, Any] | None = None,
-         thresholds: dict[str, Any] | None = None,
-         _batch_size: int | None = None,
-         img_transform: Any | None = None,
-     ) -> list[dict[str, int]]:
-         ''' ... '''
-         if fpr is None or tpr is None or thresholds is None:
-             logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
-             print("got proba")
-             return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
-         else:
-             logits, proba = self.predict_proba(x, _batch_size=_batch_size, img_transform=img_transform, skip_embedding=skip_embedding)
-             print("got proba")
-             youden_index = {}
-             thr = {}
-             for i, k in enumerate(self.tgt_modalities):
-                 youden_index[k] = tpr[i] - fpr[i]
-                 thr[k] = thresholds[i][np.argmax(youden_index[k])]
-                 # print(thr[k])
-             # print(thr)
-             return logits, proba, [{k: int(smp[k] > thr[k]) for k in self.tgt_modalities} for smp in proba]
-
-     def save(self, filepath: str, epoch: int) -> None:
-         """Save the model to the given file stream.
-
-         :param filepath: _description_
-         :type filepath: str
-         :param epoch: _description_
-         :type epoch: int
-         """
-         check_is_fitted(self)
-         if self.data_parallel:
-             state_dict = self.net_.module.state_dict()
-         else:
-             state_dict = self.net_.state_dict()
-
-         # attach model hyperparameters
-         state_dict['src_modalities'] = self.src_modalities
-         state_dict['tgt_modalities'] = self.tgt_modalities
-         state_dict['d_model'] = self.d_model
-         state_dict['nhead'] = self.nhead
-         state_dict['num_encoder_layers'] = self.num_encoder_layers
-         state_dict['num_decoder_layers'] = self.num_decoder_layers
-         state_dict['optimizer'] = self.optimizer
-         state_dict['img_net'] = self.img_net
-         state_dict['imgnet_layers'] = self.imgnet_layers
-         state_dict['img_size'] = self.img_size
-         state_dict['patch_size'] = self.patch_size
-         state_dict['imgnet_ckpt'] = self.imgnet_ckpt
-         state_dict['train_imgnet'] = self.train_imgnet
-         state_dict['epoch'] = epoch
-
-         if self.scaler is not None:
-             state_dict['scaler'] = self.scaler.state_dict()
-         if self.label_distribution:
-             state_dict['label_distribution'] = self.label_distribution
-
-         torch.save(state_dict, filepath)
-
-     def load(self, filepath: str, map_location: str = 'cpu', img_dict=None) -> None:
-         """Load a model from the given file stream.
-
-         :param filepath: _description_
-         :type filepath: str
-         :param map_location: _description_, defaults to 'cpu'
-         :type map_location: str, optional
-         :param img_dict: _description_, defaults to None
-         :type img_dict: _type_, optional
-         """
-         # load state_dict
-         state_dict = torch.load(filepath, map_location=map_location)
-
-         # load data modalities
-         self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
-         self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
-         if 'label_distribution' in state_dict:
-             self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
-         if 'optimizer' in state_dict:
-             self.optimizer = state_dict.pop('optimizer')
-
-         # initialize model
-         self.d_model = state_dict.pop('d_model')
-         self.nhead = state_dict.pop('nhead')
-         self.num_encoder_layers = state_dict.pop('num_encoder_layers')
-         self.num_decoder_layers = state_dict.pop('num_decoder_layers')
-         if 'epoch' in state_dict.keys():
-             self.start_epoch = state_dict.pop('epoch')
-         if img_dict is None:
-             self.img_net = state_dict.pop('img_net')
-             self.imgnet_layers = state_dict.pop('imgnet_layers')
-             self.img_size = state_dict.pop('img_size')
-             self.patch_size = state_dict.pop('patch_size')
-             self.imgnet_ckpt = state_dict.pop('imgnet_ckpt')
-             self.train_imgnet = state_dict.pop('train_imgnet')
-         else:
-             self.img_net = img_dict['img_net']
-             self.imgnet_layers = img_dict['imgnet_layers']
-             self.img_size = img_dict['img_size']
-             self.patch_size = img_dict['patch_size']
-             self.imgnet_ckpt = img_dict['imgnet_ckpt']
-             self.train_imgnet = img_dict['train_imgnet']
-             state_dict.pop('img_net')
-             state_dict.pop('imgnet_layers')
-             state_dict.pop('img_size')
-             state_dict.pop('patch_size')
-             state_dict.pop('imgnet_ckpt')
-             state_dict.pop('train_imgnet')
-
-         for k, info in self.src_modalities.items():
-             if info['type'] == 'imaging':
-                 if 'emb' not in self.img_net.lower():
-                     info['shape'] = (1,) + (self.img_size,) * 3
-                     info['img_shape'] = (1,) + (self.img_size,) * 3
-                 elif 'swinunetr' in self.img_net.lower():
-                     info['shape'] = (1, 768, 4, 4, 4)
-                     info['img_shape'] = (1, 768, 4, 4, 4)
-                 # print(info['shape'])
-
-         self.net_ = Transformer(self.src_modalities, self.tgt_modalities, self.d_model, self.nhead, self.num_encoder_layers, self.num_decoder_layers, self.device, self.cuda_devices, self.img_net, self.imgnet_layers, self.img_size, self.patch_size, self.imgnet_ckpt, self.train_imgnet, self.fusion_stage)
-
-         if 'scaler' in state_dict and state_dict['scaler']:
-             self.scaler.load_state_dict(state_dict.pop('scaler'))
-         self.net_.load_state_dict(state_dict)
-         check_is_fitted(self)
-         self.net_.to(self.device)
-
-     def to(self, device: str) -> Self:
-         """Mount the model to the given device.
-
-         :param device: _description_
-         :type device: str
-         :return: _description_
-         :rtype: Self
-         """
-         self.device = device
-         if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
-         if hasattr(self, 'img_model'): self.img_model = self.img_model.to(device)
-         return self
-
-     @classmethod
-     def from_ckpt(cls, filepath: str, device='cpu', img_dict=None) -> Self:
-         """Create a new ADRD model and load parameters from the checkpoint.
-
-         This is an alternative constructor.
-
-         :param filepath: _description_
-         :type filepath: str
-         :param device: _description_, defaults to 'cpu'
-         :type device: str, optional
-         :param img_dict: _description_, defaults to None
-         :type img_dict: _type_, optional
-         :return: _description_
-         :rtype: Self
-         """
-         obj = cls(None, None, None, device=device)
-         if device == 'cuda':
-             obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
-         print(obj.device)
-         obj.load(filepath, map_location=obj.device, img_dict=img_dict)
-         return obj
-
-     def _init_net(self):
-         """ ... """
-         # set the device for use
-         if self.device == 'cuda':
-             self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
-         print("Device: " + self.device)
-
-         self.start_epoch = 0
-         if self.load_from_ckpt:
-             try:
-                 print("Loading model from checkpoint...")
-                 self.load(self.ckpt_path, map_location=self.device)
-             except:
-                 print("Cannot load from checkpoint. Initializing new model...")
-                 self.load_from_ckpt = False
-
-         if not self.load_from_ckpt:
-             self.net_ = nn.Transformer(
-                 src_modalities = self.src_modalities,
-                 tgt_modalities = self.tgt_modalities,
-                 d_model = self.d_model,
-                 nhead = self.nhead,
-                 num_encoder_layers = self.num_encoder_layers,
-                 num_decoder_layers = self.num_decoder_layers,
-                 device = self.device,
-                 cuda_devices = self.cuda_devices,
-                 img_net = self.img_net,
-                 layers = self.imgnet_layers,
-                 img_size = self.img_size,
-                 patch_size = self.patch_size,
-                 imgnet_ckpt = self.imgnet_ckpt,
-                 train_imgnet = self.train_imgnet,
-                 fusion_stage = self.fusion_stage,
-             )
-
-             # initialize model parameters using xavier_uniform
-             for name, p in self.net_.named_parameters():
-                 if p.dim() > 1:
-                     torch.nn.init.xavier_uniform_(p)
-
-         self.net_.to(self.device)
-
-         # initialize the number of GPUs
-         if self.data_parallel and torch.cuda.device_count() > 1:
-             print("Available", torch.cuda.device_count(), "GPUs!")
-             self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
-
-         # return net
-
-     def _init_dataloader(self, x_trn, x_vld, y_trn, y_vld, img_train_trans=None, img_vld_trans=None):
-         # initialize dataset and dataloader
-         if self.balanced_sampling:
-             dat_trn = Transformer2ndOrderBalancedTrainingDataset(
-                 x_trn, y_trn,
-                 self.src_modalities,
-                 self.tgt_modalities,
-                 dropout_rate = .5,
-                 dropout_strategy = 'permutation',
-                 img_transform = img_train_trans,
-             )
-         else:
-             dat_trn = TransformerTrainingDataset(
-                 x_trn, y_trn,
-                 self.src_modalities,
-                 self.tgt_modalities,
-                 dropout_rate = .5,
-                 dropout_strategy = 'permutation',
-                 img_transform = img_train_trans,
-             )
-
-         dat_vld = TransformerValidationDataset(
-             x_vld, y_vld,
-             self.src_modalities,
-             self.tgt_modalities,
-             img_transform = img_vld_trans,
-         )
-
-         ldr_trn = DataLoader(
-             dataset = dat_trn,
-             batch_size = self.batch_size,
-             shuffle = True,
-             drop_last = False,
-             num_workers = self._dataloader_num_workers,
-             collate_fn = TransformerTrainingDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         ldr_vld = DataLoader(
-             dataset = dat_vld,
-             batch_size = self.batch_size,
-             shuffle = False,
-             drop_last = False,
-             num_workers = self._dataloader_num_workers,
-             collate_fn = TransformerValidationDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         return ldr_trn, ldr_vld
-
-     def _init_optimizer(self):
-         """ ... """
-         params = list(self.net_.parameters())
-         return torch.optim.AdamW(
-             params,
-             lr = self.lr,
-             betas = (0.9, 0.98),
-             weight_decay = self.weight_decay
-         )
-
-     def _init_scheduler(self, optimizer):
-         """ ... """
-         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-             optimizer = optimizer,
-             T_0 = 64,
-             T_mult = 2,
-             eta_min = 0,
-             verbose = (self.verbose > 2)
-         )
-
-     def _init_loss_func(self,
-         num_per_cls: dict[str, tuple[int, int]],
-     ) -> dict[str, Module]:
-         """ ... """
-         return {k: nn.SigmoidFocalLossBeta(
-             beta = self.beta,
-             gamma = self.gamma,
-             num_per_cls = num_per_cls[k],
-             reduction = 'none',
-         ) for k in self.tgt_modalities}
-
-     def _proc_fit(self):
-         """ ... """
adrd/model/calibration.py DELETED
@@ -1,450 +0,0 @@
1
- import numpy as np
2
- from sklearn.base import BaseEstimator
3
- from sklearn.utils.validation import check_is_fitted
4
- from sklearn.linear_model import LogisticRegression
5
- from sklearn.isotonic import IsotonicRegression
6
- from functools import lru_cache
7
- from functools import cached_property
8
- from typing import Self, Any
9
- from pickle import dump
10
- from pickle import load
11
- from abc import ABC, abstractmethod
12
-
13
- from . import ADRDModel
14
- from ..utils import Formatter
15
- from ..utils import MissingMasker
16
-
17
-
18
- def calibration_curve(
19
- y_true: list[int],
20
- y_pred: list[float],
21
- n_bins: int = 10,
22
- ratio: float = 1.0,
23
- ) -> tuple[list[float], list[float]]:
24
- """
25
- Compute true and predicted probabilities for a calibration curve. The method
26
- assumes the inputs come from a binary classifier, and discretize the [0, 1]
27
- interval into bins.
28
-
29
- Note that this function is an alternative to
30
- sklearn.calibration.calibration_curve() which can only estimate the absolute
31
- proportion of positive cases in each bin.
32
-
33
- Parameters
34
- ----------
35
- y_true : list[int]
36
- True targets.
37
- y_pred : list[float]
38
- Probabilities of the positive class.
39
- n_bins : int, default=10
40
- Number of bins to discretize the [0, 1] interval. A bigger number
41
- requires more data. Bins with no samples (i.e. without corresponding
42
- values in y_prob) will not be returned, thus the returned arrays may
43
- have less than n_bins values.
44
- ratio : float, default=1.0
45
- Used to adjust the class balance.
46
-
47
- Returns
48
- -------
49
- prob_true : list[float]
50
- The proportion of positive samples in each bin.
51
- prob_pred : list[float]
52
- The mean predicted probability in each bin.
53
- """
54
- # generate "n_bin" intervals
55
- tmp = np.around(np.linspace(0, 1, n_bins + 1), decimals=6)
56
- intvs = [(tmp[i - 1], tmp[i]) for i in range(1, len(tmp))]
57
-
58
- # pair up (pred, true) and group them by intervals
59
- tmp = list(zip(y_pred, y_true))
60
- intv_pairs = {(l, r): [p for p in tmp if l <= p[0] < r] for l, r in intvs}
61
-
62
- # calculate balanced proportion of POSITIVE cases for each intervel
63
- # along with the balanced averaged predictions
64
- intv_prob_true: dict[tuple, float] = dict()
65
- intv_prob_pred: dict[tuple, float] = dict()
66
- for intv, pairs in intv_pairs.items():
67
- # number of cases that fall into the interval
68
- n_pairs = len(pairs)
69
-
70
- # it's likely that no predictions fall into the interval
71
- if n_pairs == 0: continue
72
-
73
- # count number of positives and negatives in the interval
74
- n_pos = sum([p[1] for p in pairs])
75
- n_neg = n_pairs - n_pos
76
-
77
- # calculate adjusted proportion of positives
78
- intv_prob_true[intv] = n_pos / (n_pos + n_neg * ratio)
79
-
80
- # calculate adjusted avg. predictions
81
- sum_pred_pos = sum([p[0] for p in pairs if p[1] == 1])
82
- sum_pred_neg = sum([p[0] for p in pairs if p[1] == 0])
83
- intv_prob_pred[intv] = (sum_pred_pos + sum_pred_neg * ratio)
84
- intv_prob_pred[intv] /= (n_pos + n_neg * ratio)
85
-
86
- prob_true = list(intv_prob_true.values())
87
- prob_pred = list(intv_prob_pred.values())
88
- return prob_true, prob_pred
89
-
90
-
91
- class CalibrationCore(BaseEstimator):
92
- """
93
- A wrapper class of multiple regressors to predict the proportions of
94
- positive samples from the predicted probabilities. The method for
95
- calibration can be 'sigmoid' which corresponds to Platt's method (i.e. a
96
- logistic regression model) or 'isotonic' which is a non-parametric approach.
97
- It is not advised to use isotonic calibration with too few calibration
98
- samples (<<1000) since it tends to overfit.
99
-
100
- TODO
101
- ----
102
- - 'sigmoid' method is not trivial to implement.
103
- """
104
- def __init__(self,
105
- method: str = 'isotonic',
106
- ) -> None:
107
- """
108
- Initialization function of CalibrationCore class.
109
-
110
- Parameters
111
- ----------
112
- method : {'sigmoid', 'isotonic'}, default='isotonic'
113
- The method to use for calibration. can be 'sigmoid' which
114
- corresponds to Platt's method (i.e. a logistic regression model) or
115
- 'isotonic' which is a non-parametric approach. It is not advised to
116
- use isotonic calibration with too few calibration samples (<<1000)
117
- since it tends to overfit.
118
-
119
- Raises
120
- ------
121
- ValueError
122
- Sigmoid approach has not been implemented.
123
- """
124
- assert method in ('sigmoid', 'isotonic')
125
- if method == 'sigmoid':
126
- raise ValueError('Sigmoid approach has not been implemented.')
127
- self.method = method
128
-
129
- def fit(self,
130
- prob_pred: list[float],
131
- prob_true: list[float],
132
- ) -> Self:
133
- """
134
- Fit the underlying regressor using prob_pred, prob_true as training
135
- data.
136
-
137
- Parameters
138
- ----------
139
- prob_pred : list[float]
140
- Probabilities predicted directly by a model.
141
- prob_true : list[float]
142
- Target probabilities to calibrate to.
143
-
144
- Returns
145
- -------
146
- Self
147
- CalibrationCore object.
148
- """
149
- # using Platt's method for calibration
150
- if self.method == 'sigmoid':
151
- self.model_ = LogisticRegression()
152
- self.model_.fit(prob_pred, prob_true)
153
-
154
- # using isotonic calibration
155
- elif self.method == 'isotonic':
156
- self.model_ = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip')
157
- self.model_.fit(prob_pred, prob_true)
158
-
159
- return self
160
-
161
- def predict(self,
162
- prob_pred: list[float],
163
- ) -> list[float]:
164
- """
165
- Calibrate the input probabilities using the fitted regressor.
166
-
167
- Parameters
168
- ----------
169
- prob_pred : list[float]
170
- Probabilities predicted directly by a model.
171
-
172
- Returns
173
- -------
174
- prob_cali : list[float]
175
- Calibrated probabilities.
176
- """
177
- # as usual, the core needs to be fitted
178
- check_is_fitted(self)
179
-
180
- # note that logistic regression is classification model, we need to call
181
- # 'predict_proba' instead of 'predict' to get the calibrated results
182
- if self.method == 'sigmoid':
183
- prob_cali = self.model_.predict_proba(prob_pred)
184
- elif self.method == 'isotonic':
185
- prob_cali = self.model_.predict(prob_pred)
186
-
187
- return prob_cali
188
-
189
-
190
- class CalibratedClassifier(ABC):
191
- """
192
- Abstract class of calibrated classifier.
193
- """
194
- def __init__(self,
195
- model: ADRDModel,
196
- background_src: list[dict[str, Any]],
197
- background_tgt: list[dict[str, Any]],
198
- background_is_embedding: dict[str, bool] | None = None,
199
- method: str = 'isotonic',
200
- ) -> None:
201
- """
202
- Constructor of Calibrator class.
203
-
204
- Parameters
205
- ----------
206
- model : ADRDModel
207
- Fitted model to calibrate.
208
- background_src : list[dict[str, Any]]
209
- Features of the background dataset.
210
- background_tgt : list[dict[str, Any]]
211
- Labels of the background dataset.
212
- method : {'sigmoid', 'isotonic'}, default='isotonic'
213
- Method used by the underlying regressor.
214
- """
215
- self.method = method
216
- self.model = model
217
- self.src_modalities = model.src_modalities
218
- self.tgt_modalities = model.tgt_modalities
219
- self.background_is_embedding = background_is_embedding
220
-
221
- # format background data
222
- fmt_src = Formatter(self.src_modalities)
223
- fmt_tgt = Formatter(self.tgt_modalities)
224
- self.background_src = [fmt_src(smp) for smp in background_src]
225
- self.background_tgt = [fmt_tgt(smp) for smp in background_tgt]
226
-
227
- @abstractmethod
228
- def predict_proba(self,
229
- src: list[dict[str, Any]],
230
- is_embedding: dict[str, bool] | None = None,
231
- ) -> list[dict[str, float]]:
232
- """
233
- This method returns calibrated probabilities of classification.
234
-
235
- Parameters
236
- ----------
237
- src : list[dict[str, Any]]
238
- Features of the input samples.
239
-
240
- Returns
241
- -------
242
- list[dict[str, float]]
243
- Calibrated probabilities.
244
- """
245
- pass
246
-
247
- def predict(self,
248
- src: list[dict[str, Any]],
249
- is_embedding: dict[str, bool] | None = None,
250
- ) -> list[dict[str, int]]:
251
- """
252
- Make predictions based on the results of predict_proba().
253
-
254
- Parameters
255
- ----------
256
- x : list[dict[str, Any]]
257
- Input features.
258
-
259
- Returns
260
- -------
261
- list[dict[str, int]]
262
- Calibrated predictions.
263
- """
264
- proba = self.predict_proba(src, is_embedding)
265
- return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
266
-
267
- def save(self,
268
- filepath_state_dict: str,
269
- ) -> None:
270
- """
271
- Save the state dict and the underlying model to the given paths.
272
-
273
- Parameters
274
- ----------
275
- filepath_state_dict : str
276
- File path to save the state_dict which includes the background
277
- dataset and the regressor information.
278
- filepath_wrapped_model : str | None, default=None
279
- File path to save the wrapped model. If None, the model won't be
280
- saved.
281
- """
282
- # save state dict
283
- state_dict = {
284
- 'background_src': self.background_src,
285
- 'background_tgt': self.background_tgt,
286
- 'background_is_embedding': self.background_is_embedding,
287
- 'method': self.method,
288
- }
289
- with open(filepath_state_dict, 'wb') as f:
290
- dump(state_dict, f)
291
-
292
- @classmethod
293
- def from_ckpt(cls,
294
- filepath_state_dict: str,
295
- filepath_wrapped_model: str,
296
- ) -> Self:
297
- """
298
- Alternative constructor which loads from checkpoint.
299
-
300
- Parameters
301
- ----------
302
- filepath_state_dict : str
303
- File path to load the state_dict which includes the background
304
- dataset and the regressor information.
305
- filepath_wrapped_model : str
306
- File path of the wrapped model.
307
-
308
- Returns
309
- -------
310
- Self
311
- CalibratedClassifier class object.
312
- """
313
- with open(filepath_state_dict, 'rb') as f:
314
- kwargs = load(f)
315
- kwargs['model'] = ADRDModel.from_ckpt(filepath_wrapped_model)
316
- return cls(**kwargs)
317
-
318
-
319
- class DynamicCalibratedClassifier(CalibratedClassifier):
-    """
-    The dynamic approach generates background predictions based on the
-    missingness pattern of each input. Because the number of possible
-    missingness patterns is astronomical, calibrating each sample may require
-    running the ADRDModel on most of the background data and fitting a
-    corresponding regressor, which makes this approach computationally
-    intensive.
-    """
-    def predict_proba(self,
-        src: list[dict[str, Any]],
-        is_embedding: dict[str, bool] | None = None,
-    ) -> list[dict[str, float]]:
-
-        # initialize mask generator and format inputs
-        msk_gen = MissingMasker(self.src_modalities)
-        fmt_src = Formatter(self.src_modalities)
-        src = [fmt_src(smp) for smp in src]
-
-        # calculate calibrated probabilities
-        calibrated_prob: list[dict[str, float]] = []
-        for smp in src:
-            # model output and missingness pattern
-            prob = self.model.predict_proba([smp], is_embedding)[0]
-            mask = tuple(msk_gen(smp).values())
-
-            # get/fit core and calculate calibrated probabilities
-            core = self._fit_core(mask)
-            calibrated_prob.append({k: core[k].predict([prob[k]])[0] for k in self.tgt_modalities})
-
-        return calibrated_prob
-
-    # @lru_cache(maxsize = None)
-    def _fit_core(self,
-        missingness_pattern: tuple[bool, ...],
-    ) -> dict[str, CalibrationCore]:
-        ''' Fit one calibration core per target for the given pattern. '''
-        # remove features from all background samples accordingly
-        background_src, background_tgt = [], []
-        for src, tgt in zip(self.background_src, self.background_tgt):
-            src = {k: v for j, (k, v) in enumerate(src.items()) if not missingness_pattern[j]}
-
-            # make sure there is at least one feature available
-            if not any(v is not None for v in src.values()): continue
-            background_src.append(src)
-            background_tgt.append(tgt)
-
-        # run model on background samples and collect predictions
-        background_prob = self.model.predict_proba(background_src, self.background_is_embedding, _batch_size=1024)
-
-        # list[dict] -> dict[list]
-        N = len(background_src)
-        background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
-        background_true = {k: [background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
-
-        # now, fit cores
-        core: dict[str, CalibrationCore] = dict()
-        for k in self.tgt_modalities:
-            prob_true, prob_pred = calibration_curve(
-                background_true[k], background_prob[k],
-                ratio = self.background_ratio[k],
-            )
-            core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
-
-        return core
-
-    @cached_property
-    def background_ratio(self) -> dict[str, float]:
-        ''' The ratio of positives over negatives in the background dataset. '''
-        return {k: self.background_n_pos[k] / self.background_n_neg[k] for k in self.tgt_modalities}
-
-    @cached_property
-    def background_n_pos(self) -> dict[str, int]:
-        ''' Number of positives w.r.t. each target in the background dataset. '''
-        return {k: sum([d[k] for d in self.background_tgt]) for k in self.tgt_modalities}
-
-    @cached_property
-    def background_n_neg(self) -> dict[str, int]:
-        ''' Number of negatives w.r.t. each target in the background dataset. '''
-        return {k: len(self.background_tgt) - self.background_n_pos[k] for k in self.tgt_modalities}
-
-
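To get a feel for why the dynamic variant is costly, note that every distinct missingness pattern triggers its own background pass and regressor fit (and note that the lru_cache on _fit_core is left commented out in the source). A back-of-the-envelope illustration with a hypothetical feature count:

    # with d maskable features there are up to 2**d distinct missingness patterns
    d = 60  # hypothetical; the actual count depends on len(src_modalities)
    print(f'{2**d:.3e} possible patterns')  # ~1.153e+18
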
- class StaticCalibratedClassifier(CalibratedClassifier):
-    """
-    The static approach generates background predictions without considering
-    the missingness patterns.
-    """
-    def predict_proba(self,
-        src: list[dict[str, Any]],
-        is_embedding: dict[str, bool] | None = None,
-    ) -> list[dict[str, float]]:
-
-        # number of input samples
-        N = len(src)
-
-        # format inputs, run ADRDModel, and convert to dict[list]
-        fmt_src = Formatter(self.src_modalities)
-        src = [fmt_src(smp) for smp in src]
-        prob = self.model.predict_proba(src, is_embedding)
-        prob = {k: [prob[i][k] for i in range(N)] for k in self.tgt_modalities}
-
-        # calibrate probabilities
-        core = self._fit_core()
-        calibrated_prob = {k: core[k].predict(prob[k]) for k in self.tgt_modalities}
-
-        # convert back to list[dict]
-        calibrated_prob: list[dict[str, float]] = [
-            {k: calibrated_prob[k][i] for k in self.tgt_modalities} for i in range(N)
-        ]
-        return calibrated_prob
-
-    @lru_cache(maxsize = None)
-    def _fit_core(self) -> dict[str, CalibrationCore]:
-        ''' Fit one calibration core per target on the full background set. '''
-        # run model on background samples and collect predictions
-        background_prob = self.model.predict_proba(self.background_src, self.background_is_embedding, _batch_size=1024)
-
-        # list[dict] -> dict[list]
-        N = len(self.background_src)
-        background_prob = {k: [background_prob[i][k] for i in range(N)] for k in self.tgt_modalities}
-        background_true = {k: [self.background_tgt[i][k] for i in range(N)] for k in self.tgt_modalities}
-
-        # now, fit cores
-        core: dict[str, CalibrationCore] = dict()
-        for k in self.tgt_modalities:
-            prob_true, prob_pred = calibration_curve(
-                background_true[k], background_prob[k],
-                ratio = 1.0,
-            )
-            core[k] = CalibrationCore(self.method).fit(prob_pred, prob_true)
-
-        return core
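
A minimal usage sketch of the calibrated wrapper; the file paths and the feature name are placeholders, not actual artifacts shipped with the package:

    clf = StaticCalibratedClassifier.from_ckpt(
        filepath_state_dict='/path/to/calib_state.pkl',    # hypothetical
        filepath_wrapped_model='/path/to/adrd_ckpt.pt',    # hypothetical
    )
    proba = clf.predict_proba([{'his_NACCAGE': 72.0}])     # one sparse sample
    preds = clf.predict([{'his_NACCAGE': 72.0}])           # thresholds proba at 0.5

Note that the first predict_proba() call pays the one-time cost of _fit_core(), which runs the wrapped model over the entire background set; lru_cache makes subsequent calls cheap.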
adrd/model/cnn_resnet3d_with_linear_classifier.py DELETED
@@ -1,533 +0,0 @@
- __all__ = ['CNNResNet3DWithLinearClassifier']
-
- import torch
- from torch.utils.data import DataLoader
- import numpy as np
- import tqdm
- from sklearn.base import BaseEstimator
- from sklearn.utils.validation import check_is_fitted
- from sklearn.model_selection import train_test_split
- from scipy.special import expit
- from copy import deepcopy
- from contextlib import suppress
- from typing import Any, Self, Type
- from functools import wraps
- Tensor = Type[torch.Tensor]
- Module = Type[torch.nn.Module]
-
- from .. import nn
- from ..utils import TransformerTrainingDataset
- from ..utils import Transformer2ndOrderBalancedTrainingDataset
- from ..utils import TransformerValidationDataset
- from ..utils import TransformerTestingDataset
- from ..utils.misc import ProgressBar
- from ..utils.misc import get_metrics_multitask, print_metrics_multitask
- from ..utils.misc import convert_args_kwargs_to_kwargs
-
-
- def _manage_ctx_fit(func):
-    ''' Swap the primary device in and out around fit() for multi-GPU runs. '''
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        # format arguments
-        kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
-
-        if kwargs['self']._device_ids is None:
-            return func(**kwargs)
-        else:
-            # change primary device
-            default_device = kwargs['self'].device
-            kwargs['self'].device = kwargs['self']._device_ids[0]
-            rtn = func(**kwargs)
-
-            # the actual module is wrapped by DataParallel; unwrap it
-            kwargs['self'].net_ = kwargs['self'].net_.module
-            kwargs['self'].to(default_device)
-            return rtn
-    return wrapper
-
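The decorator above leans on convert_args_kwargs_to_kwargs from ..utils.misc, whose implementation is not shown in this diff. A plausible minimal sketch of such a helper (an assumption, not the package's actual code), matching how wrapper() looks up kwargs['self']:

    import inspect

    def convert_args_kwargs_to_kwargs(func, args, kwargs):
        # bind positional and keyword arguments against func's signature,
        # apply defaults, and return everything as one flat kwargs dict
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        return dict(bound.arguments)
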
- class CNNResNet3DWithLinearClassifier(BaseEstimator):
-
-    def __init__(self,
-        src_modalities: dict[str, dict[str, Any]],
-        tgt_modalities: dict[str, dict[str, Any]],
-        num_epochs: int = 32,
-        batch_size: int = 8,
-        batch_size_multiplier: int = 1,
-        lr: float = 1e-2,
-        weight_decay: float = 0.0,
-        beta: float = 0.9999,
-        gamma: float = 2.0,
-        scale: float = 1.0,
-        criterion: str | None = None,
-        device: str = 'cpu',
-        verbose: int = 0,
-        _device_ids: list | None = None,
-        _dataloader_num_workers: int = 0,
-        _amp_enabled: bool = False,
-        _tmp_ckpt_filepath: str | None = None,
-    ) -> None:
-        ''' ... '''
-        # for multiprocessing
-        self._rank = 0
-        self._lock = None
-
-        # positional parameters
-        self.src_modalities = src_modalities
-        self.tgt_modalities = tgt_modalities
-
-        # training parameters
-        self.num_epochs = num_epochs
-        self.batch_size = batch_size
-        self.batch_size_multiplier = batch_size_multiplier
-        self.lr = lr
-        self.weight_decay = weight_decay
-        self.beta = beta
-        self.gamma = gamma
-        self.scale = scale
-        self.criterion = criterion
-        self.device = device
-        self.verbose = verbose
-        self._device_ids = _device_ids
-        self._dataloader_num_workers = _dataloader_num_workers
-        self._amp_enabled = _amp_enabled
-        self._tmp_ckpt_filepath = _tmp_ckpt_filepath
-
-    @_manage_ctx_fit
-    def fit(self, x, y) -> Self:
-        ''' ... '''
-        # for PyTorch computational efficiency
-        torch.set_num_threads(1)
-
-        # initialize neural network
-        self.net_ = self._init_net()
-
-        # initialize dataloaders
-        ldr_trn, ldr_vld = self._init_dataloader(x, y)
-
-        # initialize optimizer and scheduler
-        optimizer = self._init_optimizer()
-        scheduler = self._init_scheduler(optimizer)
-
-        # gradient scaler for AMP
-        if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
-
-        # initialize loss function (binary cross entropy)
-        loss_func = self._init_loss_func({
-            k: (
-                sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
-                sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
-            ) for k in self.tgt_modalities
-        })
-
-        # to record the best validation performance criterion
-        if self.criterion is not None: best_crit = None
-
-        # progress bar for epoch loops
-        if self.verbose == 1:
-            with self._lock if self._lock is not None else suppress():
-                pbr_epoch = tqdm.tqdm(
-                    desc = 'Rank {:02d}'.format(self._rank),
-                    total = self.num_epochs,
-                    position = self._rank,
-                    ascii = True,
-                    leave = False,
-                    bar_format='{l_bar}{r_bar}'
-                )
-
-        # training loop
-        for epoch in range(self.num_epochs):
-            # progress bar for batch loops
-            if self.verbose > 1:
-                pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
-
-            # set model to train mode
-            torch.set_grad_enabled(True)
-            self.net_.train()
-
-            scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-            losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            for n_iter, (x_batch, y_batch, _, mask_y) in enumerate(ldr_trn):
-                # mount data to the proper device
-                x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-                y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
-                # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
-                mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
-
-                # forward
-                with torch.autocast(
-                    device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                    dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                    enabled = self._amp_enabled,
-                ):
-                    outputs = self.net_(x_batch)
-
-                    # calculate multitask loss
-                    loss = 0
-                    for i, tgt_k in enumerate(self.tgt_modalities):
-                        loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
-                        loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
-                        loss += loss_k.mean()
-                        losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
-
-                # backward
-                if self._amp_enabled:
-                    scaler.scale(loss).backward()
-                else:
-                    loss.backward()
-
-                # update parameters
-                if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
-                    if self._amp_enabled:
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                    else:
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-                # save outputs to evaluate performance later
-                for tgt_k in self.tgt_modalities:
-                    tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                    scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
-                    tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                    y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
-
-                # update progress bar
-                if self.verbose > 1:
-                    batch_size = len(next(iter(x_batch.values())))
-                    pbr_batch.update(batch_size, {})
-                    pbr_batch.refresh()
-
-            # for better tqdm progress bar display
-            if self.verbose > 1:
-                pbr_batch.close()
-
-            # step the scheduler once per epoch
-            scheduler.step()
-
-            # calculate and print training performance metrics
-            y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-            y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            for tgt_k in self.tgt_modalities:
-                for i in range(len(scores_trn[tgt_k])):
-                    y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
-                    y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
-            met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
-
-            # add loss to metrics
-            for tgt_k in self.tgt_modalities:
-                met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
-
-            if self.verbose > 2:
-                print_metrics_multitask(met_trn)
-
-            # progress bar for validation
-            if self.verbose > 1:
-                pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
-
-            # set model to validation mode
-            torch.set_grad_enabled(False)
-            self.net_.eval()
-
-            scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-            losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            for x_batch, y_batch, _, mask_y in ldr_vld:
-                # mount data to the proper device
-                x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-                y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
-                # mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
-                mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
-
-                # forward
-                with torch.autocast(
-                    device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                    dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                    enabled = self._amp_enabled
-                ):
-                    outputs = self.net_(x_batch)
-
-                    # calculate multitask loss
-                    for i, tgt_k in enumerate(self.tgt_modalities):
-                        loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
-                        loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
-                        losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
-
-                # save outputs to evaluate performance later
-                for tgt_k in self.tgt_modalities:
-                    tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                    scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
-                    tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                    y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
-
-                # update progress bar
-                if self.verbose > 1:
-                    batch_size = len(next(iter(x_batch.values())))
-                    pbr_batch.update(batch_size, {})
-                    pbr_batch.refresh()
-
-            # for better tqdm progress bar display
-            if self.verbose > 1:
-                pbr_batch.close()
-
-            # calculate and print validation performance metrics
-            y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-            y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-            for tgt_k in self.tgt_modalities:
-                for i in range(len(scores_vld[tgt_k])):
-                    y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
-                    y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
-            met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
-
-            # add loss to metrics
-            for tgt_k in self.tgt_modalities:
-                met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
-
-            if self.verbose > 2:
-                print_metrics_multitask(met_vld)
-
-            # save the model if it has the best validation performance criterion so far
-            if self.criterion is None: continue
-
-            # is the current criterion better than the previous best?
-            curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
-            if best_crit is None or np.isnan(best_crit):
-                is_better = True
-            elif self.criterion == 'Loss' and best_crit >= curr_crit:
-                is_better = True
-            elif self.criterion != 'Loss' and best_crit <= curr_crit:
-                is_better = True
-            else:
-                is_better = False
-
-            # update best criterion
-            if is_better:
-                best_crit = curr_crit
-                best_state_dict = deepcopy(self.net_.state_dict())
-
-                if self._tmp_ckpt_filepath is not None:
-                    self.save(self._tmp_ckpt_filepath)
-
-            if self.verbose > 2:
-                print('Best {}: {}'.format(self.criterion, best_crit))
-
-            if self.verbose == 1:
-                with self._lock if self._lock is not None else suppress():
-                    pbr_epoch.update(1)
-                    pbr_epoch.refresh()
-
-        if self.verbose == 1:
-            with self._lock if self._lock is not None else suppress():
-                pbr_epoch.close()
-
-        # restore the model with the best validation performance across all epochs
-        if ldr_vld is not None and self.criterion is not None:
-            self.net_.load_state_dict(best_state_dict)
-
-        return self
-
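The batch_size_multiplier logic in fit() is plain gradient accumulation: gradients from several consecutive batches pile up before a single optimizer step, emulating a larger effective batch size. A condensed sketch of the pattern (note that fit() accumulates unscaled losses, which sums rather than averages gradients across the accumulated batches; the sketch below averages):

    import torch

    def train_accumulated(net, loader, optimizer, accum_steps=4):
        optimizer.zero_grad()
        for n_iter, (x, y) in enumerate(loader):
            logits = net(x)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)
            (loss / accum_steps).backward()   # scale so gradients average
            if (n_iter + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
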
-    def predict_logits(self,
-        x: list[dict[str, Any]],
-        _batch_size: int | None = None,
-    ) -> list[dict[str, float]]:
-        """
-        The input x can be a single sample or a list of samples.
-        """
-        # input validation
-        check_is_fitted(self)
-
-        # for PyTorch computational efficiency
-        torch.set_num_threads(1)
-
-        # set model to eval mode
-        torch.set_grad_enabled(False)
-        self.net_.eval()
-
-        # initialize dataset and dataloader object
-        dat = TransformerTestingDataset(x, self.src_modalities)
-        ldr = DataLoader(
-            dataset = dat,
-            batch_size = _batch_size if _batch_size is not None else len(x),
-            shuffle = False,
-            drop_last = False,
-            num_workers = 0,
-            collate_fn = TransformerTestingDataset.collate_fn,
-        )
-
-        # run model and collect results
-        logits: list[dict[str, float]] = []
-        for x_batch, _ in ldr:
-            # mount data to the proper device
-            x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-
-            # forward
-            output: dict[str, Tensor] = self.net_(x_batch)
-
-            # convert output from dict-of-list to list of dict, then append
-            tmp = {k: output[k].tolist() for k in self.tgt_modalities}
-            tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
-            logits += tmp
-
-        return logits
-
-    def predict_proba(self,
-        x: list[dict[str, Any]],
-        temperature: float = 1.0,
-        _batch_size: int | None = None,
-    ) -> list[dict[str, float]]:
-        ''' ... '''
-        logits = self.predict_logits(x, _batch_size)
-        return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
-
-    def predict(self,
-        x: list[dict[str, Any]],
-        _batch_size: int | None = None,
-    ) -> list[dict[str, int]]:
-        ''' ... '''
-        logits = self.predict_logits(x, _batch_size)
-        return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
-
-    def save(self, filepath: str) -> None:
-        ''' ... '''
-        check_is_fitted(self)
-        state_dict = self.net_.state_dict()
-
-        # attach model hyperparameters
-        state_dict['src_modalities'] = self.src_modalities
-        state_dict['tgt_modalities'] = self.tgt_modalities
-        print('Saving model checkpoint to {} ... '.format(filepath), end='')
-        torch.save(state_dict, filepath)
-        print('Done.')
-
-    def load(self, filepath: str) -> None:
-        ''' ... '''
-        # load state_dict
-        state_dict = torch.load(filepath, map_location='cpu')
-
-        # load essential parameters
-        self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
-        self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
-
-        # initialize model
-        self.net_ = nn.CNNResNet3DWithLinearClassifier(
-            self.src_modalities,
-            self.tgt_modalities,
-        )
-
-        # load model parameters
-        self.net_.load_state_dict(state_dict)
-        self.to(self.device)
-
-    def to(self, device: str) -> Self:
-        ''' Mount model to the given device. '''
-        self.device = device
-        if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
-        return self
-
-    @classmethod
-    def from_ckpt(cls, filepath: str) -> Self:
-        ''' ... '''
-        obj = cls(None, None)
-        obj.load(filepath)
-        return obj
-
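A hedged inference sketch; the checkpoint path and input shape are placeholders:

    import numpy as np

    clf = CNNResNet3DWithLinearClassifier.from_ckpt('/path/to/ckpt_img.pt')  # hypothetical path
    x = [{'img_MRI_T1': np.zeros((1, 128, 128, 128), dtype='float32')}]      # illustrative shape
    proba = clf.predict_proba(x)   # sigmoid(logit / temperature), temperature defaults to 1.0
    preds = clf.predict(x)         # thresholds raw logits at 0
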
-    def _init_net(self):
-        """ ... """
-        net = nn.CNNResNet3DWithLinearClassifier(
-            self.src_modalities,
-            self.tgt_modalities,
-        ).to(self.device)
-
-        # train on multiple GPUs using torch.nn.DataParallel
-        if self._device_ids is not None:
-            net = torch.nn.DataParallel(net, device_ids=self._device_ids)
-
-        # initialize model parameters using xavier_uniform
-        for p in net.parameters():
-            if p.dim() > 1:
-                torch.nn.init.xavier_uniform_(p)
-
-        return net
-
-    def _init_dataloader(self, x, y):
-        """ ... """
-        # split dataset
-        x_trn, x_vld, y_trn, y_vld = train_test_split(
-            x, y, test_size = 0.2, random_state = 0,
-        )
-
-        # initialize dataset and dataloader
-        # dat_trn = TransformerTrainingDataset(
-        dat_trn = Transformer2ndOrderBalancedTrainingDataset(
-            x_trn, y_trn,
-            self.src_modalities,
-            self.tgt_modalities,
-            dropout_rate = .5,
-            # dropout_strategy = 'compensated',
-            dropout_strategy = 'permutation',
-        )
-
-        dat_vld = TransformerValidationDataset(
-            x_vld, y_vld,
-            self.src_modalities,
-            self.tgt_modalities,
-        )
-
-        ldr_trn = DataLoader(
-            dataset = dat_trn,
-            batch_size = self.batch_size,
-            shuffle = True,
-            drop_last = False,
-            num_workers = self._dataloader_num_workers,
-            collate_fn = TransformerTrainingDataset.collate_fn,
-            # pin_memory = True
-        )
-
-        ldr_vld = DataLoader(
-            dataset = dat_vld,
-            batch_size = self.batch_size,
-            shuffle = False,
-            drop_last = False,
-            num_workers = self._dataloader_num_workers,
-            collate_fn = TransformerValidationDataset.collate_fn,
-            # pin_memory = True
-        )
-
-        return ldr_trn, ldr_vld
-
-    def _init_optimizer(self):
-        """ ... """
-        return torch.optim.AdamW(
-            self.net_.parameters(),
-            lr = self.lr,
-            betas = (0.9, 0.98),
-            weight_decay = self.weight_decay
-        )
-
-    def _init_scheduler(self, optimizer):
-        """ ... """
-        return torch.optim.lr_scheduler.OneCycleLR(
-            optimizer = optimizer,
-            max_lr = self.lr,
-            total_steps = self.num_epochs,
-            verbose = (self.verbose > 2)
-        )
-
-    def _init_loss_func(self,
-        num_per_cls: dict[str, tuple[int, int]],
-    ) -> dict[str, Module]:
-        """ ... """
-        return {k: nn.SigmoidFocalLoss(
-            beta = self.beta,
-            gamma = self.gamma,
-            scale = self.scale,
-            num_per_cls = num_per_cls[k],
-            reduction = 'none',
-        ) for k in self.tgt_modalities}
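
SigmoidFocalLoss itself lives in adrd.nn and is not shown in this diff; its beta and num_per_cls arguments suggest class-balanced reweighting by the "effective number of samples" (Cui et al., 2019). A sketch of that weighting scheme, offered as an interpretation rather than the package's exact math:

    import numpy as np

    def class_balanced_weights(num_per_cls, beta=0.9999):
        # effective number of samples per class: (1 - beta**n) / (1 - beta)
        eff_num = (1.0 - np.power(beta, num_per_cls)) / (1.0 - beta)
        weights = 1.0 / eff_num
        return weights / weights.sum() * len(num_per_cls)   # normalize

    print(class_balanced_weights(np.array([9000, 1000])))   # minority class upweighted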
adrd/model/imaging_model.py DELETED
@@ -1,843 +0,0 @@
- __all__ = ['ImagingModel']
-
- import wandb
- import torch
- import numpy as np
- import functools
- import inspect
- import monai
- import random
-
- from tqdm import tqdm
- from functools import wraps
- from icecream import ic  # used in collate_handle_corrupted below
- from sklearn.base import BaseEstimator
- from sklearn.utils.validation import check_is_fitted
- from sklearn.model_selection import train_test_split
- from scipy.special import expit
- from copy import deepcopy
- from contextlib import suppress
- from typing import Any, Self, Type
- Tensor = Type[torch.Tensor]
- Module = Type[torch.nn.Module]
- from torch.utils.data import DataLoader
- from monai.utils.type_conversion import convert_to_tensor
- from monai.transforms import (
-     LoadImaged,
-     Compose,
-     CropForegroundd,
-     CopyItemsd,
-     SpatialPadd,
-     EnsureChannelFirstd,
-     Spacingd,
-     OneOf,
-     ScaleIntensityRanged,
-     HistogramNormalized,
-     RandSpatialCropSamplesd,
-     RandSpatialCropd,
-     CenterSpatialCropd,
-     RandCoarseDropoutd,
-     RandCoarseShuffled,
-     Resized,
- )
-
- # for DistributedDataParallel
- import torch.distributed as dist
- import torch.multiprocessing as mp
- from torch.nn.parallel import DistributedDataParallel as DDP
-
- from .. import nn
- from ..utils.misc import ProgressBar
- from ..utils.misc import get_metrics_multitask, print_metrics_multitask
- from ..utils.misc import convert_args_kwargs_to_kwargs
-
- import warnings
- warnings.filterwarnings("ignore")
-
-
- def _manage_ctx_fit(func):
-    ''' Swap the primary device in and out around fit() for multi-GPU runs. '''
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        # format arguments
-        kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
-
-        if kwargs['self']._device_ids is None:
-            return func(**kwargs)
-        else:
-            # change primary device
-            default_device = kwargs['self'].device
-            kwargs['self'].device = kwargs['self']._device_ids[0]
-            rtn = func(**kwargs)
-            kwargs['self'].to(default_device)
-            return rtn
-    return wrapper
-
- def collate_handle_corrupted(samples_list, dataset, labels, dtype=torch.half):
-    ''' Collate function that drops corrupted samples and resamples as needed. '''
-    orig_len = len(samples_list)
-    # log which samples failed to load, then drop them; for the loss to be
-    # consistent, samples with NaN values in any crop are also dropped below
-    for s in samples_list:
-        ic(s is None)
-    samples_list = list(filter(lambda x: x is not None, samples_list))
-
-    # if every sample was corrupted, resample a fresh batch from the dataset
-    if len(samples_list) == 0:
-        ic('recursive call')
-        return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
-
-    try:
-        if "image" in samples_list[0]:
-            samples_list = [s for s in samples_list if not torch.isnan(s["image"]).any()]
-            collated_images = torch.stack([convert_to_tensor(s["image"]) for s in samples_list])
-            collated_labels = {k: torch.Tensor([s["label"][k] if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
-            collated_mask = {k: torch.Tensor([1 if s["label"][k] is not None else 0 for s in samples_list]) for k in labels}
-            return {"image": collated_images,
-                    "label": collated_labels,
-                    "mask": collated_mask}
-    except Exception:
-        # any failure (e.g. a shape mismatch) falls through to resampling
-        return collate_handle_corrupted([dataset[random.randint(0, len(dataset)-1)] for _ in range(orig_len)], dataset, labels)
-
-
- def get_backend(img_backend):
-    if img_backend == 'C3D':
-        return nn.C3D
-    elif img_backend == 'DenseNet':
-        return nn.DenseNet
-    raise ValueError('Unrecognized imaging backend: {}'.format(img_backend))
-
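A toy illustration of the collate behaviour, using a minimal list-backed dataset (all names and shapes here are made up):

    import torch

    toy_labels = ['AD']
    toy_dataset = [
        {'image': torch.zeros(1, 4, 4, 4), 'label': {'AD': 1}},
        {'image': torch.zeros(1, 4, 4, 4), 'label': {'AD': None}},  # unlabeled
        None,                                                        # corrupted
    ]
    batch = collate_handle_corrupted(list(toy_dataset), toy_dataset, toy_labels)
    # batch['image'] stacks the two usable samples; batch['mask']['AD'] is
    # [1., 0.]: it flags which entries carry a real label versus a zero filler.
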
- class ImagingModel(BaseEstimator):
-    ''' ... '''
-    def __init__(self,
-        tgt_modalities: list[str],
-        label_fractions: dict[str, float],
-        num_epochs: int = 32,
-        batch_size: int = 8,
-        batch_size_multiplier: int = 1,
-        lr: float = 1e-2,
-        weight_decay: float = 0.0,
-        beta: float = 0.9999,
-        gamma: float = 2.0,
-        bn_size: int = 4,
-        growth_rate: int = 12,
-        block_config: tuple = (3, 3, 3),
-        compression: float = 0.5,
-        num_init_features: int = 16,
-        drop_rate: float = 0.2,
-        criterion: str | None = None,
-        device: str = 'cpu',
-        cuda_devices: list = [1],
-        ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/dev/ckpt/ckpt.pt',
-        load_from_ckpt: bool = True,
-        save_intermediate_ckpts: bool = False,
-        data_parallel: bool = False,
-        verbose: int = 0,
-        img_backend: str | None = None,
-        label_distribution: dict = {},
-        wandb_ = 1,
-        _device_ids: list | None = None,
-        _dataloader_num_workers: int = 4,
-        _amp_enabled: bool = False,
-    ) -> None:
-        ''' ... '''
-        # for multiprocessing
-        self._rank = 0
-        self._lock = None
-
-        # positional parameters
-        self.tgt_modalities = tgt_modalities
-
-        # training parameters
-        self.label_fractions = label_fractions
-        self.num_epochs = num_epochs
-        self.batch_size = batch_size
-        self.batch_size_multiplier = batch_size_multiplier
-        self.lr = lr
-        self.weight_decay = weight_decay
-        self.beta = beta
-        self.gamma = gamma
-        self.bn_size = bn_size
-        self.growth_rate = growth_rate
-        self.block_config = block_config
-        self.compression = compression
-        self.num_init_features = num_init_features
-        self.drop_rate = drop_rate
-        self.criterion = criterion
-        self.device = device
-        self.cuda_devices = cuda_devices
-        self.ckpt_path = ckpt_path
-        self.load_from_ckpt = load_from_ckpt
-        self.save_intermediate_ckpts = save_intermediate_ckpts
-        self.data_parallel = data_parallel
-        self.verbose = verbose
-        self.img_backend = img_backend
-        self.label_distribution = label_distribution
-        self.wandb_ = wandb_
-        self._device_ids = _device_ids
-        self._dataloader_num_workers = _dataloader_num_workers
-        self._amp_enabled = _amp_enabled
-        self.scaler = torch.cuda.amp.GradScaler()
-
-    @_manage_ctx_fit
-    def fit(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None) -> Self:
-        ''' ... '''
-        # start a new wandb run to track this script
-        if self.wandb_ == 1:
-            wandb.init(
-                # set the wandb project where this run will be logged
-                project="ADRD_main",
-
-                # track hyperparameters and run metadata
-                config={
-                    "Model": "DenseNet",
-                    "Loss": 'Focalloss',
-                    "EMB": "ALL_EMB",
-                    "epochs": 256,
-                }
-            )
-            wandb.run.log_code("/home/skowshik/ADRD_repo/pipeline_v1_main/adrd_tool")
-        else:
-            wandb.init(mode="disabled")
-
-        # for PyTorch computational efficiency
-        torch.set_num_threads(1)
-        print(self.criterion)
-
-        # initialize neural network
-        self._init_net()
-
-        # initialize dataloaders
-        ldr_trn, ldr_vld = self._init_dataloader(trn_list, vld_list, img_train_trans=img_train_trans, img_vld_trans=img_vld_trans)
-
-        # initialize the optimizer (unless it was restored from a checkpoint)
-        if not self.load_from_ckpt:
-            self.optimizer = self._init_optimizer()
-        self.scheduler = self._init_scheduler(self.optimizer)
-
-        # gradient scaler for AMP
-        if self._amp_enabled:
-            self.scaler = torch.cuda.amp.GradScaler()
-
-        # initialize focal loss functions, one per target; rare labels get a
-        # nonzero alpha so that the minority class is upweighted
-        self.loss_fn = {}
-        for k in self.tgt_modalities:
-            if self.label_fractions[k] >= 0.3:
-                alpha = -1
-            else:
-                alpha = pow((1 - self.label_fractions[k]), 2)
-            self.loss_fn[k] = nn.SigmoidFocalLoss(
-                alpha = alpha,
-                gamma = self.gamma,
-                reduction = 'none'
-            )
-
-        # to record the best validation performance criteria
-        if self.criterion is not None:
-            best_crit = None
-            best_crit_AUPR = None
-
-        # progress bar for epoch loops
-        if self.verbose == 1:
-            with self._lock if self._lock is not None else suppress():
-                pbr_epoch = tqdm(
-                    desc = 'Rank {:02d}'.format(self._rank),
-                    total = self.num_epochs,
-                    position = self._rank,
-                    ascii = True,
-                    leave = False,
-                    bar_format='{l_bar}{r_bar}'
-                )
-
-        # a hook function to print and store the gradient of a layer (debugging aid)
-        def print_and_store_grad(grad, grad_list):
-            grad_list.append(grad)
-
-        # training loop
-        for epoch in range(self.start_epoch, self.num_epochs):
-            met_trn = self.train_one_epoch(ldr_trn, epoch)
-            met_vld = self.validate_one_epoch(ldr_vld, epoch)
-
-            print(self.ckpt_path.split('/')[-1])
-
-            # save the model if it has the best validation performance criterion so far
-            if self.criterion is None: continue
-
-            # is the current criterion better than the previous best?
-            curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
-            curr_crit_AUPR = np.mean([met_vld[i]["AUC (PR)"] for i in range(len(self.tgt_modalities))])
-            # AUROC
-            if best_crit is None or np.isnan(best_crit):
-                is_better = True
-            elif self.criterion == 'Loss' and best_crit >= curr_crit:
-                is_better = True
-            elif self.criterion != 'Loss' and best_crit <= curr_crit:
-                is_better = True
-            else:
-                is_better = False
-
-            # AUPR
-            if best_crit_AUPR is None or np.isnan(best_crit_AUPR):
-                is_better_AUPR = True
-            elif best_crit_AUPR <= curr_crit_AUPR:
-                is_better_AUPR = True
-            else:
-                is_better_AUPR = False
-
-            # update best criteria
-            if is_better_AUPR:
-                best_crit_AUPR = curr_crit_AUPR
-                if self.save_intermediate_ckpts:
-                    print(f"Saving the model to {self.ckpt_path[:-3]}_AUPR.pt...")
-                    self.save(self.ckpt_path[:-3]+"_AUPR.pt", epoch)
-
-            if is_better:
-                best_crit = curr_crit
-                best_state_dict = deepcopy(self.net_.state_dict())
-                if self.save_intermediate_ckpts:
-                    print(f"Saving the model to {self.ckpt_path}...")
-                    self.save(self.ckpt_path, epoch)
-
-            if self.verbose > 2:
-                print('Best {}: {}'.format(self.criterion, best_crit))
-                print('Best {}: {}'.format('AUC (PR)', best_crit_AUPR))
-
-            if self.verbose == 1:
-                with self._lock if self._lock is not None else suppress():
-                    pbr_epoch.update(1)
-                    pbr_epoch.refresh()
-
-        return self
-
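An end-to-end training sketch; paths, transform choices, and label names are illustrative, not the authors' exact configuration:

    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Resized

    trn_trans = Compose([
        LoadImaged(keys=['image']),
        EnsureChannelFirstd(keys=['image']),
        Resized(keys=['image'], spatial_size=(128, 128, 128)),
    ])
    trn_list = [{'image': '/path/to/scan.nii.gz', 'label': {'AD': 1}}]  # hypothetical
    model = ImagingModel(
        tgt_modalities=['AD'],
        label_fractions={'AD': 0.5},
        img_backend='DenseNet',
        load_from_ckpt=False,
        ckpt_path='/tmp/imaging_ckpt.pt',   # hypothetical
        wandb_=0,                           # disable wandb logging
        device='cpu',
    )
    model.fit(trn_list, trn_list, img_train_trans=trn_trans, img_vld_trans=trn_trans)
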
-    def train_one_epoch(self, ldr_trn, epoch):
-        # progress bar for batch loops
-        if self.verbose > 1:
-            pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
-
-        # set model to train mode
-        torch.set_grad_enabled(True)
-        self.net_.train()
-
-        scores_trn, y_true_trn, y_mask_trn = [], [], []
-        losses_trn = [[] for _ in self.tgt_modalities]
-        iters = len(ldr_trn)
-        for n_iter, batch_data in enumerate(ldr_trn):
-            x_batch = batch_data["image"].to(self.device, non_blocking=True)
-            y_batch = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["label"].items()}
-            y_mask = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["mask"].items()}
-
-            # forward
-            with torch.autocast(
-                device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                enabled = self._amp_enabled,
-            ):
-                outputs = self.net_(x_batch, shap=False)
-
-                # calculate multitask loss; missing labels are masked out
-                loss = 0
-                for i, k in enumerate(self.tgt_modalities):
-                    loss_task = self.loss_fn[k](outputs[k], y_batch[k])
-                    msk_loss_task = loss_task * y_mask[k]
-                    msk_loss_mean = msk_loss_task.sum() / y_mask[k].sum()
-                    loss += msk_loss_mean
-                    losses_trn[i] += msk_loss_task.detach().cpu().numpy().tolist()
-
-            # backward
-            if self._amp_enabled:
-                self.scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            # update parameters
-            if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
-                if self._amp_enabled:
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-                    self.optimizer.zero_grad()
-                else:
-                    self.optimizer.step()
-                    self.optimizer.zero_grad()
-                # step the scheduler (fractional epochs for warm restarts)
-                self.scheduler.step(epoch + n_iter / iters)
-
-            ''' TODO: change array to dictionary later '''
-            outputs = torch.stack(list(outputs.values()), dim=1)
-            y_batch = torch.stack(list(y_batch.values()), dim=1)
-            y_mask = torch.stack(list(y_mask.values()), dim=1)
-
-            # save outputs to evaluate performance later
-            scores_trn.append(outputs.detach().to(torch.float).cpu())
-            y_true_trn.append(y_batch.cpu())
-            y_mask_trn.append(y_mask.cpu())
-
-            # update progress bar
-            if self.verbose > 1:
-                batch_size = len(x_batch)
-                pbr_batch.update(batch_size, {})
-                pbr_batch.refresh()
-
-            # clear cuda cache
-            if "cuda" in self.device:
-                torch.cuda.empty_cache()
-
-        # for better tqdm progress bar display
-        if self.verbose > 1:
-            pbr_batch.close()
-
-        # calculate and print training performance metrics
-        scores_trn = torch.cat(scores_trn)
-        y_true_trn = torch.cat(y_true_trn)
-        y_mask_trn = torch.cat(y_mask_trn)
-        y_pred_trn = (scores_trn > 0).to(torch.int)
-        y_prob_trn = torch.sigmoid(scores_trn)
-        met_trn = get_metrics_multitask(
-            y_true_trn.numpy(),
-            y_pred_trn.numpy(),
-            y_prob_trn.numpy(),
-            y_mask_trn.numpy()
-        )
-
-        # add loss to metrics
-        for i in range(len(self.tgt_modalities)):
-            met_trn[i]['Loss'] = np.mean(losses_trn[i])
-
-        # log metrics to wandb
-        wandb.log({f"Train loss {list(self.tgt_modalities)[i]}": met_trn[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Train Balanced Accuracy {list(self.tgt_modalities)[i]}": met_trn[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Train AUC (ROC) {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Train AUPR {list(self.tgt_modalities)[i]}": met_trn[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-
-        if self.verbose > 2:
-            print_metrics_multitask(met_trn)
-
-        return met_trn
-
-    # @torch.no_grad()
-    def validate_one_epoch(self, ldr_vld, epoch):
-        # progress bar for validation
-        if self.verbose > 1:
-            pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
-
-        # set model to validation mode
-        torch.set_grad_enabled(False)
-        self.net_.eval()
-
-        scores_vld, y_true_vld, y_mask_vld = [], [], []
-        losses_vld = [[] for _ in self.tgt_modalities]
-        for batch_data in ldr_vld:
-            x_batch = batch_data["image"].to(self.device, non_blocking=True)
-            y_batch = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["label"].items()}
-            y_mask = {k: v.to(self.device, non_blocking=True) for k, v in batch_data["mask"].items()}
-
-            # forward
-            with torch.autocast(
-                device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                enabled = self._amp_enabled
-            ):
-                outputs = self.net_(x_batch, shap=False)
-
-                # calculate multitask loss; missing labels are masked out
-                for i, k in enumerate(self.tgt_modalities):
-                    loss_task = self.loss_fn[k](outputs[k], y_batch[k])
-                    msk_loss_task = loss_task * y_mask[k]
-                    losses_vld[i] += msk_loss_task.detach().cpu().numpy().tolist()
-
-            ''' TODO: change array to dictionary later '''
-            outputs = torch.stack(list(outputs.values()), dim=1)
-            y_batch = torch.stack(list(y_batch.values()), dim=1)
-            y_mask = torch.stack(list(y_mask.values()), dim=1)
-
-            # save outputs to evaluate performance later
-            scores_vld.append(outputs.detach().to(torch.float).cpu())
-            y_true_vld.append(y_batch.cpu())
-            y_mask_vld.append(y_mask.cpu())
-
-            # update progress bar
-            if self.verbose > 1:
-                batch_size = len(x_batch)
-                pbr_batch.update(batch_size, {})
-                pbr_batch.refresh()
-
-            # clear cuda cache
-            if "cuda" in self.device:
-                torch.cuda.empty_cache()
-
-        # for better tqdm progress bar display
-        if self.verbose > 1:
-            pbr_batch.close()
-
-        # calculate and print validation performance metrics
-        scores_vld = torch.cat(scores_vld)
-        y_true_vld = torch.cat(y_true_vld)
-        y_mask_vld = torch.cat(y_mask_vld)
-        y_pred_vld = (scores_vld > 0).to(torch.int)
-        y_prob_vld = torch.sigmoid(scores_vld)
-        met_vld = get_metrics_multitask(
-            y_true_vld.numpy(),
-            y_pred_vld.numpy(),
-            y_prob_vld.numpy(),
-            y_mask_vld.numpy()
-        )
-
-        # add loss to metrics
-        for i in range(len(self.tgt_modalities)):
-            met_vld[i]['Loss'] = np.mean(losses_vld[i])
-
-        # log metrics to wandb
-        wandb.log({f"Validation loss {list(self.tgt_modalities)[i]}": met_vld[i]['Loss'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Validation Balanced Accuracy {list(self.tgt_modalities)[i]}": met_vld[i]['Balanced Accuracy'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Validation AUC (ROC) {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (ROC)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-        wandb.log({f"Validation AUPR {list(self.tgt_modalities)[i]}": met_vld[i]['AUC (PR)'] for i in range(len(self.tgt_modalities))}, step=epoch)
-
-        if self.verbose > 2:
-            print_metrics_multitask(met_vld)
-
-        return met_vld
-
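The per-target masked loss used in both loops reduces to the following standalone computation, with plain BCE standing in for the focal loss (toy tensors, for illustration only):

    import torch

    logits = torch.tensor([0.2, -1.3, 0.8])   # outputs for one target
    labels = torch.tensor([1.0, 0.0, 1.0])
    mask   = torch.tensor([1.0, 1.0, 0.0])    # third label is missing

    per_sample = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, labels, reduction='none')
    masked_mean = (per_sample * mask).sum() / mask.sum()  # missing labels contribute nothing
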
-    def save(self, filepath: str, epoch: int = 0) -> None:
-        ''' ... '''
-        check_is_fitted(self)
-        if self.data_parallel:
-            state_dict = self.net_.module.state_dict()
-        else:
-            state_dict = self.net_.state_dict()
-
-        # attach model hyperparameters
-        state_dict['tgt_modalities'] = self.tgt_modalities
-        state_dict['optimizer'] = self.optimizer
-        state_dict['bn_size'] = self.bn_size
-        state_dict['growth_rate'] = self.growth_rate
-        state_dict['block_config'] = self.block_config
-        state_dict['compression'] = self.compression
-        state_dict['num_init_features'] = self.num_init_features
-        state_dict['drop_rate'] = self.drop_rate
-        state_dict['epoch'] = epoch
-
-        if self.scaler is not None:
-            state_dict['scaler'] = self.scaler.state_dict()
-        if self.label_distribution:
-            state_dict['label_distribution'] = self.label_distribution
-
-        torch.save(state_dict, filepath)
-
-    def load(self, filepath: str, map_location: str = 'cpu', how='latest') -> None:
-        ''' ... '''
-        # load state_dict; 'latest' picks whichever of the AUROC / AUPR
-        # checkpoints was saved at the later epoch
-        if how == 'latest':
-            if torch.load(filepath, map_location=map_location)['epoch'] > torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location)['epoch']:
-                print("Loading model saved using AUROC")
-                state_dict = torch.load(filepath, map_location=map_location)
-            else:
-                print("Loading model saved using AUPR")
-                state_dict = torch.load(f'{filepath[:-3]}_AUPR.pt', map_location=map_location)
-        else:
-            state_dict = torch.load(filepath, map_location=map_location)
-
-        # load data modalities and hyperparameters
-        self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
-        if 'label_distribution' in state_dict:
-            self.label_distribution: dict[str, dict[int, int]] = state_dict.pop('label_distribution')
-        if 'optimizer' in state_dict:
-            self.optimizer = state_dict.pop('optimizer')
-        if 'bn_size' in state_dict:
-            self.bn_size = state_dict.pop('bn_size')
-        if 'growth_rate' in state_dict:
-            self.growth_rate = state_dict.pop('growth_rate')
-        if 'block_config' in state_dict:
-            self.block_config = state_dict.pop('block_config')
-        if 'compression' in state_dict:
-            self.compression = state_dict.pop('compression')
-        if 'num_init_features' in state_dict:
-            self.num_init_features = state_dict.pop('num_init_features')
-        if 'drop_rate' in state_dict:
-            self.drop_rate = state_dict.pop('drop_rate')
-        if 'epoch' in state_dict:
-            self.start_epoch = state_dict.pop('epoch')
-            print(f'Epoch: {self.start_epoch}')
-
-        # initialize model
-        self.net_ = get_backend(self.img_backend)(
-            tgt_modalities = self.tgt_modalities,
-            bn_size = self.bn_size,
-            growth_rate=self.growth_rate,
-            block_config=self.block_config,
-            compression=self.compression,
-            num_init_features=self.num_init_features,
-            drop_rate=self.drop_rate,
-            load_from_ckpt=self.load_from_ckpt
-        )
-        print(self.net_)
-
-        if 'scaler' in state_dict and state_dict['scaler']:
-            self.scaler.load_state_dict(state_dict.pop('scaler'))
-        self.net_.load_state_dict(state_dict)
-        check_is_fitted(self)
-        self.net_.to(self.device)
-
-    def to(self, device: str) -> Self:
-        ''' Mount model to the given device. '''
-        self.device = device
-        if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
-        return self
-
-    @classmethod
-    def from_ckpt(cls, filepath: str, device='cpu', img_backend=None, load_from_ckpt=True, how='latest') -> Self:
-        ''' ... '''
-        obj = cls(None, None, None, device=device)
-        if device == 'cuda':
-            obj.device = "{}:{}".format(obj.device, str(obj.cuda_devices[0]))
-        print(obj.device)
-        obj.img_backend = img_backend
-        obj.load_from_ckpt = load_from_ckpt
-        obj.load(filepath, map_location=obj.device, how=how)
-        return obj
-
-    def _init_net(self):
-        """ ... """
-        self.start_epoch = 0
-        # set the device for use
-        if self.device == 'cuda':
-            self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
-
-        if self.load_from_ckpt:
-            try:
-                print("Loading model from checkpoint...")
-                self.load(self.ckpt_path, map_location=self.device)
-            except Exception:
-                print("Cannot load from checkpoint. Initializing new model...")
-                self.load_from_ckpt = False
-
-        if not self.load_from_ckpt:
-            self.net_ = get_backend(self.img_backend)(
-                tgt_modalities = self.tgt_modalities,
-                bn_size = self.bn_size,
-                growth_rate=self.growth_rate,
-                block_config=self.block_config,
-                compression=self.compression,
-                num_init_features=self.num_init_features,
-                drop_rate=self.drop_rate,
-                load_from_ckpt=self.load_from_ckpt
-            )
-
-            # # initialize model parameters using xavier_uniform
-            # for p in self.net_.parameters():
-            #     if p.dim() > 1:
-            #         torch.nn.init.xavier_uniform_(p)
-
-        self.net_.to(self.device)
-
-        # initialize the GPUs for data-parallel training
-        if self.data_parallel and torch.cuda.device_count() > 1:
-            print("Available", torch.cuda.device_count(), "GPUs!")
-            self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
-
-    def _init_dataloader(self, trn_list, vld_list, img_train_trans=None, img_vld_trans=None):
-        """ ... """
-        dat_trn = monai.data.Dataset(data=trn_list, transform=img_train_trans)
-        dat_vld = monai.data.Dataset(data=vld_list, transform=img_vld_trans)
-        collate_fn_trn = functools.partial(collate_handle_corrupted, dataset=dat_trn, dtype=torch.FloatTensor, labels=self.tgt_modalities)
-        collate_fn_vld = functools.partial(collate_handle_corrupted, dataset=dat_vld, dtype=torch.FloatTensor, labels=self.tgt_modalities)
-
-        ldr_trn = DataLoader(
-            dataset = dat_trn,
-            batch_size = self.batch_size,
-            shuffle = True,
-            drop_last = False,
-            num_workers = self._dataloader_num_workers,
-            collate_fn = collate_fn_trn,
-            # pin_memory = True
-        )
-
-        ldr_vld = DataLoader(
-            dataset = dat_vld,
-            batch_size = self.batch_size,
-            shuffle = False,
-            drop_last = False,
-            num_workers = self._dataloader_num_workers,
-            collate_fn = collate_fn_vld,
-            # pin_memory = True
-        )
-
-        return ldr_trn, ldr_vld
-
-    def _init_optimizer(self):
-        """ ... """
-        params = list(self.net_.parameters())
-        return torch.optim.AdamW(
-            params,
-            lr = self.lr,
-            betas = (0.9, 0.98),
-            weight_decay = self.weight_decay
-        )
-
-    def _init_scheduler(self, optimizer):
-        """ ... """
-        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer=optimizer,
-            T_0=64,
-            T_mult=2,
-            eta_min = 0,
-            verbose=(self.verbose > 2)
-        )
-
-    def _init_loss_func(self,
-        num_per_cls: dict[str, tuple[int, int]],
-    ) -> dict[str, Module]:
-        """ ... """
-        return {k: nn.SigmoidFocalLossBeta(
-            beta = self.beta,
-            gamma = self.gamma,
-            num_per_cls = num_per_cls[k],
-            reduction = 'none',
-        ) for k in self.tgt_modalities}
-
-    def _proc_fit(self):
-        """ ... """
-
-    def _init_test_dataloader(self, batch_size, tst_list, img_tst_trans=None):
-        # input validation
-        check_is_fitted(self)
-        print(self.device)
-
-        # for PyTorch computational efficiency
-        torch.set_num_threads(1)
-
-        # set model to eval mode
-        torch.set_grad_enabled(False)
-        self.net_.eval()
-
-        dat_tst = monai.data.Dataset(data=tst_list, transform=img_tst_trans)
-        collate_fn_tst = functools.partial(collate_handle_corrupted, dataset=dat_tst, dtype=torch.FloatTensor, labels=self.tgt_modalities)
-
-        ldr_tst = DataLoader(
-            dataset = dat_tst,
-            batch_size = batch_size,
-            shuffle = False,
-            drop_last = False,
-            num_workers = self._dataloader_num_workers,
-            collate_fn = collate_fn_tst,
-            # pin_memory = True
-        )
-        return ldr_tst
-
-    def predict_logits(self,
-        ldr_tst: Any | None = None,
-    ) -> list[dict[str, float]]:
-        # run model and collect results
-        logits: list[dict[str, float]] = []
-        for batch_data in tqdm(ldr_tst):
-            if len(batch_data) == 0:
-                continue
-            x_batch = batch_data["image"].to(self.device, non_blocking=True)
-            outputs = self.net_(x_batch, shap=False)
-
-            # convert output from dict-of-list to list of dict, then append
-            tmp = {k: outputs[k].tolist() for k in self.tgt_modalities}
-            tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
-            logits += tmp
-
-        return logits
-
-    def predict_proba(self,
-        ldr_tst: Any | None = None,
-        temperature: float = 1.0,
-    ) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
-        ''' Return the raw logits along with temperature-scaled probabilities. '''
-        logits = self.predict_logits(ldr_tst)
-        return logits, [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
-
-    def predict(self,
-        ldr_tst: Any | None = None,
-    ) -> tuple[list[dict[str, float]], list[dict[str, float]], list[dict[str, int]]]:
-        ''' Return logits, probabilities, and binary predictions. '''
-        logits, proba = self.predict_proba(ldr_tst)
-        return logits, proba, [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
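
A hedged test-time sketch; the checkpoint path, test list, and transform are placeholders (tst_list and tst_trans follow the same shape as the training sketch above, and how='latest' assumes both the base checkpoint and its _AUPR.pt sibling exist on disk):

    model = ImagingModel.from_ckpt(
        '/path/to/imaging_ckpt.pt',   # hypothetical
        device='cpu',
        img_backend='DenseNet',
        how='direct',                 # any value other than 'latest' loads the file as-is
    )
    ldr_tst = model._init_test_dataloader(batch_size=2, tst_list=tst_list,
                                          img_tst_trans=tst_trans)
    logits, proba, preds = model.predict(ldr_tst)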
adrd/model/train_resnet.py DELETED
@@ -1,484 +0,0 @@
- import torch
- import numpy as np
- import tqdm
- from sklearn.base import BaseEstimator
- from sklearn.utils.validation import check_is_fitted
- from sklearn.model_selection import train_test_split
- from scipy.special import expit
- from copy import deepcopy
- from contextlib import suppress
- from typing import Any, Self
- from icecream import ic
-
- from .. import nn
- from ..utils import TransformerTrainingDataset
- from ..utils import TransformerValidationDataset
- from ..utils import MissingMasker
- from ..utils import ConstantImputer
- from ..utils import Formatter
- from ..utils.misc import ProgressBar
- from ..utils.misc import get_metrics_multitask, print_metrics_multitask
-
-
- class TrainResNet(BaseEstimator):
-     ''' ... '''
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]],
-         label_fractions: dict[str, float],
-         num_epochs: int = 32,
-         batch_size: int = 8,
-         lr: float = 1e-2,
-         weight_decay: float = 0.0,
-         gamma: float = 0.0,
-         criterion: str | None = None,
-         device: str = 'cpu',
-         cuda_devices: list = [1, 2],
-         mri_feature: str = 'img_MRI_T1',
-         ckpt_path: str = '/home/skowshik/ADRD_repo/adrd_tool/adrd/dev/ckpt/ckpt.pt',
-         load_from_ckpt: bool = True,
-         save_intermediate_ckpts: bool = False,
-         data_parallel: bool = False,
-         verbose: int = 0,
-     ):
-         ''' ... '''
-         # for multiprocessing
-         self._rank = 0
-         self._lock = None
-
-         # positional parameters
-         self.src_modalities = src_modalities
-         self.tgt_modalities = tgt_modalities
-
-         # training parameters
-         self.label_fractions = label_fractions
-         self.num_epochs = num_epochs
-         self.batch_size = batch_size
-         self.lr = lr
-         self.weight_decay = weight_decay
-         self.gamma = gamma
-         self.criterion = criterion
-         self.device = device
-         self.cuda_devices = cuda_devices
-         self.mri_feature = mri_feature
-         self.ckpt_path = ckpt_path
-         self.load_from_ckpt = load_from_ckpt
-         self.save_intermediate_ckpts = save_intermediate_ckpts
-         self.data_parallel = data_parallel
-         self.verbose = verbose
-
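For orientation, a minimal construction sketch follows; the modality dicts and label fractions here are hypothetical stand-ins for what the ADRD data pipeline normally supplies.

trainer = TrainResNet(
    src_modalities={'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}},  # hypothetical
    tgt_modalities={'AD': {'type': 'categorical', 'num_categories': 2}},                   # hypothetical
    label_fractions={'AD': 0.3},   # fraction of positive labels per task (assumed meaning)
    device='cpu',
    load_from_ckpt=False,          # skip the default checkpoint path above
)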
-     def fit(self, x, y):
-         ''' ... '''
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # set the device for use
-         if self.device == 'cuda':
-             self.device = "{}:{}".format(self.device, str(self.cuda_devices[0]))
-
-         # initialize model from a checkpoint if requested
-         if self.load_from_ckpt:
-             try:
-                 print("Loading model from checkpoint...")
-                 self.load(self.ckpt_path, map_location=self.device)
-             except Exception:
-                 print("Cannot load from checkpoint. Initializing new model...")
-                 self.load_from_ckpt = False
-
-         # otherwise initialize a fresh model
-         if not self.load_from_ckpt:
-             self.net_ = nn.ResNetModel(
-                 self.tgt_modalities,
-                 mri_feature = self.mri_feature
-             )
-             # initialize model parameters using xavier_uniform
-             for p in self.net_.parameters():
-                 if p.dim() > 1:
-                     torch.nn.init.xavier_uniform_(p)
-
-         self.net_.to(self.device)
-
-         # wrap the model for multi-GPU training if requested
-         if self.data_parallel and torch.cuda.device_count() > 1:
-             print("Available", torch.cuda.device_count(), "GPUs!")
-             self.net_ = torch.nn.DataParallel(self.net_, device_ids=self.cuda_devices)
-
-         # split dataset
-         x_trn, x_vld, y_trn, y_vld = train_test_split(
-             x, y, test_size = 0.2, random_state = 0,
-         )
-
-         # initialize dataset and dataloader
-         dat_trn = TransformerTrainingDataset(
-             x_trn, y_trn,
-             self.src_modalities,
-             self.tgt_modalities,
-             dropout_rate = .5,
-             dropout_strategy = 'compensated',
-             mri_feature = self.mri_feature,
-         )
-
-         dat_vld = TransformerValidationDataset(
-             x_vld, y_vld,
-             self.src_modalities,
-             self.tgt_modalities,
-             mri_feature = self.mri_feature,
-         )
-
-         # ic(dat_trn[0])
-
-         ldr_trn = torch.utils.data.DataLoader(
-             dat_trn,
-             batch_size = self.batch_size,
-             shuffle = True,
-             drop_last = False,
-             num_workers = 0,
-             collate_fn = TransformerTrainingDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         ldr_vld = torch.utils.data.DataLoader(
-             dat_vld,
-             batch_size = self.batch_size,
-             shuffle = False,
-             drop_last = False,
-             num_workers = 0,
-             collate_fn = TransformerTrainingDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         # initialize optimizer and scheduler
-         optimizer = torch.optim.AdamW(
-             self.net_.parameters(),
-             lr = self.lr,
-             betas = (0.9, 0.98),
-             weight_decay = self.weight_decay
-         )
-         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64, verbose=(self.verbose > 2))
-
-         # initialize loss function (sigmoid focal loss, one per task)
-         loss_fn = {}
-
-         for k in self.tgt_modalities:
-             alpha = pow((1 - self.label_fractions[k]), self.gamma)
-             # if alpha < 0.5:
-             #     alpha = -1
-             loss_fn[k] = nn.SigmoidFocalLoss(
-                 alpha = alpha,
-                 gamma = self.gamma,
-                 reduction = 'none'
-             )
-
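A quick numeric check of the alpha rule above with gamma = 2: the scarcer the positive label, the larger the positive-class weight.

for frac in (0.1, 0.3, 0.5):
    print(frac, (1 - frac) ** 2)   # 0.81, 0.49, 0.25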
-         # to record the best validation performance criterion
-         if self.criterion is not None:
-             best_crit = None
-
-         # progress bar for epoch loops
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch = tqdm.tqdm(
-                     desc = 'Rank {:02d}'.format(self._rank),
-                     total = self.num_epochs,
-                     position = self._rank,
-                     ascii = True,
-                     leave = False,
-                     bar_format = '{l_bar}{r_bar}'
-                 )
-
-         # Define a hook function to print and store the gradient of a layer
-         def print_and_store_grad(grad, grad_list):
-             grad_list.append(grad)
-             # print(grad)
-
-         # grad_list = []
-         # self.net_.module.img_net_.featurizer.down_tr64.ops[0].conv1.weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
-         # self.net_.module.modules_emb_src['gender'].weight.register_hook(lambda grad: print_and_store_grad(grad, grad_list))
-
-         # training loop
-         for epoch in range(self.num_epochs):
-             # progress bar for batch loops
-             if self.verbose > 1:
-                 pbr_batch = ProgressBar(len(dat_trn), 'Epoch {:03d} (TRN)'.format(epoch))
-
-             # set model to train mode
-             torch.set_grad_enabled(True)
-             self.net_.train()
-
-             scores_trn, y_true_trn = [], []
-             losses_trn = [[] for _ in self.tgt_modalities]
-             for x_batch, y_batch, mask in ldr_trn:
-
-                 # mount data to the proper device
-                 x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
-                 y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
-
-                 # forward
-                 outputs = self.net_(x_batch)
-
-                 # calculate multitask loss
-                 loss = 0
-                 for i, k in enumerate(self.tgt_modalities):
-                     loss_task = loss_fn[k](outputs[k], y_batch[k])
-                     loss += loss_task.mean()
-                     losses_trn[i] += loss_task.detach().cpu().numpy().tolist()
-
-                 # backward
-                 optimizer.zero_grad(set_to_none=True)
-                 loss.backward()
-                 optimizer.step()
-
-                 ''' TODO: change array to dictionary later '''
-                 outputs = torch.stack(list(outputs.values()), dim=1)
-                 y_batch = torch.stack(list(y_batch.values()), dim=1)
-
-                 # save outputs to evaluate performance later
-                 scores_trn.append(outputs.detach().to(torch.float).cpu())
-                 y_true_trn.append(y_batch.cpu())
-
-                 # update progress bar
-                 if self.verbose > 1:
-                     batch_size = len(next(iter(x_batch.values())))
-                     pbr_batch.update(batch_size, {})
-                     pbr_batch.refresh()
-
-                 # clear cuda cache
-                 if "cuda" in self.device:
-                     torch.cuda.empty_cache()
-
-             # for better tqdm progress bar display
-             if self.verbose > 1:
-                 pbr_batch.close()
-
-             # step the learning-rate scheduler
-             scheduler.step()
-
-             # calculate and print training performance metrics
-             scores_trn = torch.cat(scores_trn)
-             y_true_trn = torch.cat(y_true_trn)
-             y_pred_trn = (scores_trn > 0).to(torch.int)
-             y_prob_trn = torch.sigmoid(scores_trn)
-             met_trn = get_metrics_multitask(
-                 y_true_trn.numpy(),
-                 y_pred_trn.numpy(),
-                 y_prob_trn.numpy()
-             )
-
-             # add loss to metrics
-             for i in range(len(self.tgt_modalities)):
-                 met_trn[i]['Loss'] = np.mean(losses_trn[i])
-
-             if self.verbose > 2:
-                 print_metrics_multitask(met_trn)
-
-             # progress bar for validation
-             if self.verbose > 1:
-                 pbr_batch = ProgressBar(len(dat_vld), 'Epoch {:03d} (VLD)'.format(epoch))
-
-             # set model to validation mode
-             torch.set_grad_enabled(False)
-             self.net_.eval()
-
-             scores_vld, y_true_vld = [], []
-             losses_vld = [[] for _ in self.tgt_modalities]
-             for x_batch, y_batch, mask in ldr_vld:
-                 # mount data to the proper device
-                 x_batch = {k: x_batch[k].to(self.device) for k in x_batch}
-                 y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in y_batch}
-
-                 # forward
-                 outputs = self.net_(x_batch)
-
-                 # calculate multitask loss
-                 for i, k in enumerate(self.tgt_modalities):
-                     loss_task = loss_fn[k](outputs[k], y_batch[k])
-                     losses_vld[i] += loss_task.detach().cpu().numpy().tolist()
-
-                 ''' TODO: change array to dictionary later '''
-                 outputs = torch.stack(list(outputs.values()), dim=1)
-                 y_batch = torch.stack(list(y_batch.values()), dim=1)
-
-                 # save outputs to evaluate performance later
-                 scores_vld.append(outputs.detach().to(torch.float).cpu())
-                 y_true_vld.append(y_batch.cpu())
-
-                 # update progress bar
-                 if self.verbose > 1:
-                     batch_size = len(next(iter(x_batch.values())))
-                     pbr_batch.update(batch_size, {})
-                     pbr_batch.refresh()
-
-                 # clear cuda cache
-                 if "cuda" in self.device:
-                     torch.cuda.empty_cache()
-
-             # for better tqdm progress bar display
-             if self.verbose > 1:
-                 pbr_batch.close()
-
-             # calculate and print validation performance metrics
-             scores_vld = torch.cat(scores_vld)
-             y_true_vld = torch.cat(y_true_vld)
-             y_pred_vld = (scores_vld > 0).to(torch.int)
-             y_prob_vld = torch.sigmoid(scores_vld)
-             met_vld = get_metrics_multitask(
-                 y_true_vld.numpy(),
-                 y_pred_vld.numpy(),
-                 y_prob_vld.numpy()
-             )
-
-             # add loss to metrics
-             for i in range(len(self.tgt_modalities)):
-                 met_vld[i]['Loss'] = np.mean(losses_vld[i])
-
-             if self.verbose > 2:
-                 print_metrics_multitask(met_vld)
-
-             # save the model if it has the best validation performance criterion so far
-             if self.criterion is None: continue
-
-             # is the current criterion better than the previous best?
-             curr_crit = np.mean([met_vld[i][self.criterion] for i in range(len(self.tgt_modalities))])
-             if best_crit is None or np.isnan(best_crit):
-                 is_better = True
-             elif self.criterion == 'Loss' and best_crit >= curr_crit:
-                 is_better = True
-             elif self.criterion != 'Loss' and best_crit <= curr_crit:
-                 is_better = True
-             else:
-                 is_better = False
-
-             # update best criterion
-             if is_better:
-                 best_crit = curr_crit
-                 best_state_dict = deepcopy(self.net_.state_dict())
-                 if self.save_intermediate_ckpts:
-                     print("Saving the model...")
-                     self.save(self.ckpt_path)
-
-             if self.verbose > 2:
-                 print('Best {}: {}'.format(self.criterion, best_crit))
-
-             if self.verbose == 1:
-                 with self._lock if self._lock is not None else suppress():
-                     pbr_epoch.update(1)
-                     pbr_epoch.refresh()
-
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch.close()
-
-         # restore the model with the best validation performance across all epochs
-         if ldr_vld is not None and self.criterion is not None:
-             self.net_.load_state_dict(best_state_dict)
-
-         return self
-
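The best-model bookkeeping in the loop above follows one rule worth stating explicitly: 'Loss' improves downward, every other criterion improves upward, and a missing or NaN best always loses. The same logic factored into a helper (a sketch; the function name is hypothetical):

import math

def is_improvement(criterion, best, curr):
    # NaN or unset best: any value counts as an improvement
    if best is None or math.isnan(best):
        return True
    # 'Loss' is minimized; all other metrics are maximized
    return best >= curr if criterion == 'Loss' else best <= curr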
-     def predict_logits(self,
-         x: list[dict[str, Any]],
-     ) -> list[dict[str, float]]:
-         '''
-         The input x can be a single sample or a list of samples.
-         '''
-         # input validation
-         check_is_fitted(self)
-
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # set model to eval mode
-         torch.set_grad_enabled(False)
-         self.net_.eval()
-
-         # number of samples to evaluate
-         n_samples = len(x)
-
-         # format x
-         fmt = Formatter(self.src_modalities)
-         x = [fmt(smp) for smp in x]
-
-         # generate missing mask (BEFORE IMPUTATION)
-         msk = MissingMasker(self.src_modalities)
-         mask = [msk(smp) for smp in x]
-
-         # reformat x and then impute with 0s
-         imp = ConstantImputer(self.src_modalities)
-         x = [imp(smp) for smp in x]
-
-         # convert list-of-dicts to dict-of-lists
-         x = {k: [smp[k] for smp in x] for k in self.src_modalities}
-         mask = {k: [smp[k] for smp in mask] for k in self.src_modalities}
-
-         # to tensor
-         x = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in x.items()}
-         mask = {k: torch.as_tensor(np.array(v)).to(self.device) for k, v in mask.items()}
-
-         # calculate logits
-         logits = self.net_(x)
-
-         # convert dict-of-lists to list-of-dicts
-         logits = {k: logits[k].tolist() for k in self.tgt_modalities}
-         logits = [{k: logits[k][i] for k in self.tgt_modalities} for i in range(n_samples)]
-
-         return logits
-
-     def predict_proba(self,
-         x: list[dict[str, Any]],
-         temperature: float = 1.0
-     ) -> list[dict[str, float]]:
-         ''' ... '''
-         # calculate logits
-         logits = self.predict_logits(x)
-
-         # convert logits to probabilities
-         proba = [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
-         return proba
-
-     def predict(self,
-         x: list[dict[str, Any]],
-     ) -> list[dict[str, int]]:
-         ''' ... '''
-         proba = self.predict_proba(x)
-         return [{k: int(smp[k] > 0.5) for k in self.tgt_modalities} for smp in proba]
-
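One detail in predict_logits above deserves emphasis: the missing-value mask is computed before zero-imputation, otherwise every feature would look observed. A toy illustration (hypothetical features; it assumes the masker flags missing entries as True, as the code comment suggests):

sample = {'age': 71.0, 'mmse': None}
mask = {k: v is None for k, v in sample.items()}                 # {'age': False, 'mmse': True}
imputed = {k: 0 if v is None else v for k, v in sample.items()}  # zeros replace the gaps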
-     def save(self, filepath: str) -> None:
-         ''' ... '''
-         check_is_fitted(self)
-         if self.data_parallel:
-             state_dict = self.net_.module.state_dict()
-         else:
-             state_dict = self.net_.state_dict()
-
-         # attach model hyperparameters
-         state_dict['src_modalities'] = self.src_modalities
-         state_dict['tgt_modalities'] = self.tgt_modalities
-         state_dict['mri_feature'] = self.mri_feature
-
-         torch.save(state_dict, filepath)
-
-     def load(self, filepath: str, map_location: str = 'cpu') -> None:
-         ''' ... '''
-         # load state_dict
-         state_dict = torch.load(filepath, map_location=map_location)
-
-         # load data modalities
-         self.src_modalities = state_dict.pop('src_modalities')
-         self.tgt_modalities = state_dict.pop('tgt_modalities')
-
-         # initialize model
-         self.net_ = nn.ResNetModel(
-             self.tgt_modalities,
-             mri_feature = state_dict.pop('mri_feature')
-         )
-
-         # load model parameters
-         self.net_.load_state_dict(state_dict)
-         self.net_.to(self.device)
-
-     @classmethod
-     def from_ckpt(cls, filepath: str, device='cpu') -> Self:
-         ''' ... '''
-         obj = cls(None, None, None, device=device)
-         obj.load(filepath)
-         return obj
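Round-trip sketch for the pair above (the path is hypothetical): save() tucks a few hyperparameters into the state dict, and load() pops them back out before load_state_dict() sees the remaining tensors.

trainer.save('/tmp/resnet_ckpt.pt')
restored = TrainResNet.from_ckpt('/tmp/resnet_ckpt.pt')   # rebuilds net_ from the file

Storing plain dicts inside a state dict is unconventional, but it works because torch.save pickles arbitrary objects; the pops in load() are what keep load_state_dict from choking on them.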
 
adrd/model/transformer.py DELETED
@@ -1,600 +0,0 @@
- __all__ = ['Transformer']
-
- import torch
- from torch.utils.data import DataLoader
- import numpy as np
- import tqdm
- from sklearn.base import BaseEstimator
- from sklearn.utils.validation import check_is_fitted
- from sklearn.model_selection import train_test_split
- from scipy.special import expit
- from copy import deepcopy
- from contextlib import suppress
- from typing import Any, Self, Type
- from functools import wraps
- Tensor = Type[torch.Tensor]
- Module = Type[torch.nn.Module]
-
- from .. import nn
- from ..utils import TransformerTrainingDataset
- from ..utils import TransformerBalancedTrainingDataset
- from ..utils import Transformer2ndOrderBalancedTrainingDataset
- from ..utils import TransformerValidationDataset
- from ..utils import TransformerTestingDataset
- from ..utils.misc import ProgressBar
- from ..utils.misc import get_metrics_multitask, print_metrics_multitask
- from ..utils.misc import convert_args_kwargs_to_kwargs
-
-
- def _manage_ctx_fit(func):
-     ''' ... '''
-     @wraps(func)
-     def wrapper(*args, **kwargs):
-         # format arguments
-         kwargs = convert_args_kwargs_to_kwargs(func, args, kwargs)
-
-         if kwargs['self']._device_ids is None:
-             return func(**kwargs)
-         else:
-             # change primary device
-             default_device = kwargs['self'].device
-             kwargs['self'].device = kwargs['self']._device_ids[0]
-             rtn = func(**kwargs)
-             kwargs['self'].to(default_device)
-             return rtn
-     return wrapper
-
-
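A sketch of what convert_args_kwargs_to_kwargs presumably does, so the decorator above can look up kwargs['self'] uniformly; the real helper lives in ..utils.misc and may differ in detail:

import inspect

def convert_args_kwargs_to_kwargs_sketch(func, args, kwargs):
    # bind positional and keyword arguments against the signature, then
    # return everything as one name -> value mapping (including 'self')
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)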
- class Transformer(BaseEstimator):
-     ''' ... '''
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]],
-         d_model: int = 32,
-         nhead: int = 1,
-         num_layers: int = 1,
-         num_epochs: int = 32,
-         batch_size: int = 8,
-         batch_size_multiplier: int = 1,
-         lr: float = 1e-2,
-         weight_decay: float = 0.0,
-         beta: float = 0.9999,
-         gamma: float = 2.0,
-         scale: float = 1.0,
-         lambd: float = 0.0,
-         criterion: str | None = None,
-         device: str = 'cpu',
-         verbose: int = 0,
-         _device_ids: list | None = None,
-         _dataloader_num_workers: int = 0,
-         _amp_enabled: bool = False,
-     ) -> None:
-         ''' ... '''
-         # for multiprocessing
-         self._rank = 0
-         self._lock = None
-
-         # positional parameters
-         self.src_modalities = src_modalities
-         self.tgt_modalities = tgt_modalities
-
-         # training parameters
-         self.d_model = d_model
-         self.nhead = nhead
-         self.num_layers = num_layers
-         self.num_epochs = num_epochs
-         self.batch_size = batch_size
-         self.batch_size_multiplier = batch_size_multiplier
-         self.lr = lr
-         self.weight_decay = weight_decay
-         self.beta = beta
-         self.gamma = gamma
-         self.scale = scale
-         self.lambd = lambd
-         self.criterion = criterion
-         self.device = device
-         self.verbose = verbose
-         self._device_ids = _device_ids
-         self._dataloader_num_workers = _dataloader_num_workers
-         self._amp_enabled = _amp_enabled
-
-     @_manage_ctx_fit
-     def fit(self,
-         x, y,
-         is_embedding: dict[str, bool] | None = None,
-     ) -> Self:
-         ''' ... '''
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # initialize neural network
-         self.net_ = self._init_net()
-
-         # initialize dataloaders
-         ldr_trn, ldr_vld = self._init_dataloader(x, y, is_embedding)
-
-         # initialize optimizer and scheduler
-         optimizer = self._init_optimizer()
-         scheduler = self._init_scheduler(optimizer)
-
-         # gradient scaler for AMP
-         if self._amp_enabled: scaler = torch.cuda.amp.GradScaler()
-
-         # initialize loss function (sigmoid focal loss with per-class counts)
-         loss_func = self._init_loss_func({
-             k: (
-                 sum([_[k] == 0 for _ in ldr_trn.dataset.tgt]),
-                 sum([_[k] == 1 for _ in ldr_trn.dataset.tgt]),
-             ) for k in self.tgt_modalities
-         })
-
-         # to record the best validation performance criterion
-         if self.criterion is not None: best_crit = None
-
-         # progress bar for epoch loops
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch = tqdm.tqdm(
-                     desc = 'Rank {:02d}'.format(self._rank),
-                     total = self.num_epochs,
-                     position = self._rank,
-                     ascii = True,
-                     leave = False,
-                     bar_format = '{l_bar}{r_bar}'
-                 )
-
-         # training loop
-         for epoch in range(self.num_epochs):
-             # progress bar for batch loops
-             if self.verbose > 1:
-                 pbr_batch = ProgressBar(len(ldr_trn.dataset), 'Epoch {:03d} (TRN)'.format(epoch))
-
-             # set model to train mode
-             torch.set_grad_enabled(True)
-             self.net_.train()
-
-             scores_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             y_true_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-             losses_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             for n_iter, (x_batch, y_batch, mask_x, mask_y) in enumerate(ldr_trn):
-                 # mount data to the proper device
-                 x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-                 y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
-                 mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
-                 mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
-
-                 # forward
-                 with torch.autocast(
-                     device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                     dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                     enabled = self._amp_enabled,
-                 ):
-                     outputs = self.net_(x_batch, mask_x, is_embedding)
-
-                     # calculate multitask loss
-                     loss = 0
-                     for i, tgt_k in enumerate(self.tgt_modalities):
-                         loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
-                         loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
-                         loss += loss_k.mean()
-                         losses_trn[tgt_k] += loss_k.detach().cpu().numpy().tolist()
-
-                 # if self.lambd != 0:
-
-                 # backward
-                 if self._amp_enabled:
-                     scaler.scale(loss).backward()
-                 else:
-                     loss.backward()
-
-                 # update parameters
-                 if n_iter != 0 and n_iter % self.batch_size_multiplier == 0:
-                     if self._amp_enabled:
-                         scaler.step(optimizer)
-                         scaler.update()
-                         optimizer.zero_grad()
-                     else:
-                         optimizer.step()
-                         optimizer.zero_grad()
-
-                 # save outputs to evaluate performance later
-                 for tgt_k in self.tgt_modalities:
-                     tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                     scores_trn[tgt_k] += tmp.detach().cpu().numpy().tolist()
-                     tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                     y_true_trn[tgt_k] += tmp.cpu().numpy().tolist()
-
-                 # update progress bar
-                 if self.verbose > 1:
-                     batch_size = len(next(iter(x_batch.values())))
-                     pbr_batch.update(batch_size, {})
-                     pbr_batch.refresh()
-
-             # for better tqdm progress bar display
-             if self.verbose > 1:
-                 pbr_batch.close()
-
-             # step the learning-rate scheduler
-             scheduler.step()
-
-             # calculate and print training performance metrics
-             y_pred_trn: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-             y_prob_trn: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             for tgt_k in self.tgt_modalities:
-                 for i in range(len(scores_trn[tgt_k])):
-                     y_pred_trn[tgt_k].append(1 if scores_trn[tgt_k][i] > 0 else 0)
-                     y_prob_trn[tgt_k].append(expit(scores_trn[tgt_k][i]))
-             met_trn = get_metrics_multitask(y_true_trn, y_pred_trn, y_prob_trn)
-
-             # add loss to metrics
-             for tgt_k in self.tgt_modalities:
-                 met_trn[tgt_k]['Loss'] = np.mean(losses_trn[tgt_k])
-
-             if self.verbose > 2:
-                 print_metrics_multitask(met_trn)
-
-             # progress bar for validation
-             if self.verbose > 1:
-                 pbr_batch = ProgressBar(len(ldr_vld.dataset), 'Epoch {:03d} (VLD)'.format(epoch))
-
-             # set model to validation mode
-             torch.set_grad_enabled(False)
-             self.net_.eval()
-
-             scores_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             y_true_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-             losses_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             for x_batch, y_batch, mask_x, mask_y in ldr_vld:
-                 # mount data to the proper device
-                 x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-                 y_batch = {k: y_batch[k].to(torch.float).to(self.device) for k in self.tgt_modalities}
-                 mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
-                 mask_y = {k: mask_y[k].to(self.device) for k in self.tgt_modalities}
-
-                 # forward
-                 with torch.autocast(
-                     device_type = 'cpu' if self.device == 'cpu' else 'cuda',
-                     dtype = torch.bfloat16 if self.device == 'cpu' else torch.float16,
-                     enabled = self._amp_enabled
-                 ):
-                     outputs = self.net_(x_batch, mask_x, is_embedding)
-
-                     # calculate multitask loss
-                     for i, tgt_k in enumerate(self.tgt_modalities):
-                         loss_k = loss_func[tgt_k](outputs[tgt_k], y_batch[tgt_k])
-                         loss_k = torch.masked_select(loss_k, torch.logical_not(mask_y[tgt_k].squeeze()))
-                         losses_vld[tgt_k] += loss_k.detach().cpu().numpy().tolist()
-
-                 # save outputs to evaluate performance later
-                 for tgt_k in self.tgt_modalities:
-                     tmp = torch.masked_select(outputs[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                     scores_vld[tgt_k] += tmp.detach().cpu().numpy().tolist()
-                     tmp = torch.masked_select(y_batch[tgt_k], torch.logical_not(mask_y[tgt_k].squeeze()))
-                     y_true_vld[tgt_k] += tmp.cpu().numpy().tolist()
-
-                 # update progress bar
-                 if self.verbose > 1:
-                     batch_size = len(next(iter(x_batch.values())))
-                     pbr_batch.update(batch_size, {})
-                     pbr_batch.refresh()
-
-             # for better tqdm progress bar display
-             if self.verbose > 1:
-                 pbr_batch.close()
-
-             # calculate and print validation performance metrics
-             y_pred_vld: dict[str, list[int]] = {k: [] for k in self.tgt_modalities}
-             y_prob_vld: dict[str, list[float]] = {k: [] for k in self.tgt_modalities}
-             for tgt_k in self.tgt_modalities:
-                 for i in range(len(scores_vld[tgt_k])):
-                     y_pred_vld[tgt_k].append(1 if scores_vld[tgt_k][i] > 0 else 0)
-                     y_prob_vld[tgt_k].append(expit(scores_vld[tgt_k][i]))
-             met_vld = get_metrics_multitask(y_true_vld, y_pred_vld, y_prob_vld)
-
-             # add loss to metrics
-             for tgt_k in self.tgt_modalities:
-                 met_vld[tgt_k]['Loss'] = np.mean(losses_vld[tgt_k])
-
-             if self.verbose > 2:
-                 print_metrics_multitask(met_vld)
-
-             # save the model if it has the best validation performance criterion so far
-             if self.criterion is None: continue
-
-             # is the current criterion better than the previous best?
-             curr_crit = np.mean([met_vld[k][self.criterion] for k in self.tgt_modalities])
-             if best_crit is None or np.isnan(best_crit):
-                 is_better = True
-             elif self.criterion == 'Loss' and best_crit >= curr_crit:
-                 is_better = True
-             elif self.criterion != 'Loss' and best_crit <= curr_crit:
-                 is_better = True
-             else:
-                 is_better = False
-
-             # update best criterion
-             if is_better:
-                 best_crit = curr_crit
-                 best_state_dict = deepcopy(self.net_.state_dict())
-
-             if self.verbose > 2:
-                 print('Best {}: {}'.format(self.criterion, best_crit))
-
-             if self.verbose == 1:
-                 with self._lock if self._lock is not None else suppress():
-                     pbr_epoch.update(1)
-                     pbr_epoch.refresh()
-
-         if self.verbose == 1:
-             with self._lock if self._lock is not None else suppress():
-                 pbr_epoch.close()
-
-         # restore the model with the best validation performance across all epochs
-         if ldr_vld is not None and self.criterion is not None:
-             self.net_.load_state_dict(best_state_dict)
-
-         return self
-
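The masking idiom used throughout fit() above, in isolation: mask_y is True where a label is missing, so masked_select keeps losses only for the observed targets.

import torch

loss_k = torch.tensor([0.3, 1.2, 0.7])
mask_y = torch.tensor([False, True, False])                  # sample 2 has no label
kept = torch.masked_select(loss_k, torch.logical_not(mask_y))
print(kept.mean())                                           # tensor(0.5000), averaged over observed labels only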
-     def predict_logits(self,
-         x: list[dict[str, Any]],
-         is_embedding: dict[str, bool] | None = None,
-         _batch_size: int | None = None,
-     ) -> list[dict[str, float]]:
-         '''
-         The input x can be a single sample or a list of samples.
-         '''
-         # input validation
-         check_is_fitted(self)
-
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # set model to eval mode
-         torch.set_grad_enabled(False)
-         self.net_.eval()
-
-         # initialize dataset and dataloader object
-         dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
-         ldr = DataLoader(
-             dataset = dat,
-             batch_size = _batch_size if _batch_size is not None else len(x),
-             shuffle = False,
-             drop_last = False,
-             num_workers = 0,
-             collate_fn = TransformerTestingDataset.collate_fn,
-         )
-
-         # run model and collect results
-         logits: list[dict[str, float]] = []
-         for x_batch, mask_x in ldr:
-             # mount data to the proper device
-             x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-             mask_x = {k: mask_x[k].to(self.device) for k in self.src_modalities}
-
-             # forward
-             output: dict[str, Tensor] = self.net_(x_batch, mask_x, is_embedding)
-
-             # convert output from dict-of-lists to a list of dicts, then append
-             tmp = {k: output[k].tolist() for k in self.tgt_modalities}
-             tmp = [{k: tmp[k][i] for k in self.tgt_modalities} for i in range(len(next(iter(tmp.values()))))]
-             logits += tmp
-
-         return logits
-
-     def predict_proba(self,
-         x: list[dict[str, Any]],
-         is_embedding: dict[str, bool] | None = None,
-         temperature: float = 1.0,
-         _batch_size: int | None = None,
-     ) -> list[dict[str, float]]:
-         ''' ... '''
-         logits = self.predict_logits(x, is_embedding, _batch_size)
-         return [{k: expit(smp[k] / temperature) for k in self.tgt_modalities} for smp in logits]
-
-     def predict(self,
-         x: list[dict[str, Any]],
-         is_embedding: dict[str, bool] | None = None,
-         _batch_size: int | None = None,
-     ) -> list[dict[str, int]]:
-         ''' ... '''
-         logits = self.predict_logits(x, is_embedding, _batch_size)
-         return [{k: int(smp[k] > 0.0) for k in self.tgt_modalities} for smp in logits]
-
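Note that predict() above thresholds raw logits at zero, which is exactly the probability-0.5 decision at temperature 1, so the sigmoid can be skipped:

from scipy.special import expit

for logit in (1.3, -0.2):
    assert (logit > 0) == (expit(logit) > 0.5)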
-     def save(self, filepath: str) -> None:
-         ''' ... '''
-         check_is_fitted(self)
-         state_dict = self.net_.state_dict()
-
-         # attach model hyperparameters
-         state_dict['src_modalities'] = self.src_modalities
-         state_dict['tgt_modalities'] = self.tgt_modalities
-         state_dict['d_model'] = self.d_model
-         state_dict['nhead'] = self.nhead
-         state_dict['num_layers'] = self.num_layers
-         torch.save(state_dict, filepath)
-
-     def load(self, filepath: str) -> None:
-         ''' ... '''
-         # load state_dict
-         state_dict = torch.load(filepath, map_location='cpu')
-
-         # load essential parameters
-         self.src_modalities: dict[str, dict[str, Any]] = state_dict.pop('src_modalities')
-         self.tgt_modalities: dict[str, dict[str, Any]] = state_dict.pop('tgt_modalities')
-         self.d_model = state_dict.pop('d_model')
-         self.nhead = state_dict.pop('nhead')
-         self.num_layers = state_dict.pop('num_layers')
-
-         # initialize model
-         self.net_ = nn.Transformer(
-             self.src_modalities,
-             self.tgt_modalities,
-             self.d_model,
-             self.nhead,
-             self.num_layers,
-         )
-
-         # load model parameters
-         self.net_.load_state_dict(state_dict)
-         self.to(self.device)
-
-     def to(self, device: str) -> Self:
-         ''' Mount model to the given device. '''
-         self.device = device
-         if hasattr(self, 'net_'): self.net_ = self.net_.to(device)
-         return self
-
-     @classmethod
-     def from_ckpt(cls, filepath: str) -> Self:
-         ''' ... '''
-         obj = cls(None, None)
-         obj.load(filepath)
-         return obj
-
-     def _init_net(self):
-         """ ... """
-         net = nn.Transformer(
-             self.src_modalities,
-             self.tgt_modalities,
-             self.d_model,
-             self.nhead,
-             self.num_layers,
-         ).to(self.device)
-
-         # train on multiple GPUs using torch.nn.DataParallel
-         if self._device_ids is not None:
-             net = torch.nn.DataParallel(net, device_ids=self._device_ids)
-
-         # initialize model parameters using xavier_uniform
-         for p in net.parameters():
-             if p.dim() > 1:
-                 torch.nn.init.xavier_uniform_(p)
-
-         return net
-
-     def _init_dataloader(self, x, y, is_embedding):
-         """ ... """
-         # split dataset
-         x_trn, x_vld, y_trn, y_vld = train_test_split(
-             x, y, test_size = 0.2, random_state = 0,
-         )
-
-         # initialize dataset and dataloader
-         # dat_trn = TransformerTrainingDataset(
-         # dat_trn = TransformerBalancedTrainingDataset(
-         dat_trn = Transformer2ndOrderBalancedTrainingDataset(
-             x_trn, y_trn,
-             self.src_modalities,
-             self.tgt_modalities,
-             dropout_rate = .5,
-             # dropout_strategy = 'compensated',
-             dropout_strategy = 'permutation',
-         )
-
-         dat_vld = TransformerValidationDataset(
-             x_vld, y_vld,
-             self.src_modalities,
-             self.tgt_modalities,
-             is_embedding,
-         )
-
-         ldr_trn = DataLoader(
-             dataset = dat_trn,
-             batch_size = self.batch_size,
-             shuffle = True,
-             drop_last = False,
-             num_workers = self._dataloader_num_workers,
-             collate_fn = TransformerTrainingDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         ldr_vld = DataLoader(
-             dataset = dat_vld,
-             batch_size = self.batch_size,
-             shuffle = False,
-             drop_last = False,
-             num_workers = self._dataloader_num_workers,
-             collate_fn = TransformerValidationDataset.collate_fn,
-             # pin_memory = True
-         )
-
-         return ldr_trn, ldr_vld
-
-     def _init_optimizer(self):
-         """ ... """
-         return torch.optim.AdamW(
-             self.net_.parameters(),
-             lr = self.lr,
-             betas = (0.9, 0.98),
-             weight_decay = self.weight_decay
-         )
-
-     def _init_scheduler(self, optimizer):
-         """ ... """
-         return torch.optim.lr_scheduler.OneCycleLR(
-             optimizer = optimizer,
-             max_lr = self.lr,
-             total_steps = self.num_epochs,
-             verbose = (self.verbose > 2)
-         )
-
-     def _init_loss_func(self,
-         num_per_cls: dict[str, tuple[int, int]],
-     ) -> dict[str, Module]:
-         """ ... """
-         return {k: nn.SigmoidFocalLoss(
-             beta = self.beta,
-             gamma = self.gamma,
-             scale = self.scale,
-             num_per_cls = num_per_cls[k],
-             reduction = 'none',
-         ) for k in self.tgt_modalities}
-
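The per-class counts fed into _init_loss_func above implement class-balanced weighting by the "effective number of samples" rule of Cui et al. (2019); note that the beta/num_per_cls keywords match the SigmoidFocalLossBeta variant in adrd/nn/focal_loss.py rather than the plain SigmoidFocalLoss shown there. Numerically:

beta = 0.9999
for n in (100, 10_000):
    print(n, (1 - beta) / (1 - beta ** n))   # ~0.01005 vs ~0.000158: the rarer class weighs more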
-     def _extract_embedding(self,
-         x: list[dict[str, Any]],
-         is_embedding: dict[str, bool] | None = None,
-         _batch_size: int | None = None,
-     ) -> list[dict[str, Any]]:
-         """ ... """
-         # input validation
-         check_is_fitted(self)
-
-         # for PyTorch computational efficiency
-         torch.set_num_threads(1)
-
-         # set model to eval mode
-         torch.set_grad_enabled(False)
-         self.net_.eval()
-
-         # initialize dataset and dataloader object
-         dat = TransformerTestingDataset(x, self.src_modalities, is_embedding)
-         ldr = DataLoader(
-             dataset = dat,
-             batch_size = _batch_size if _batch_size is not None else len(x),
-             shuffle = False,
-             drop_last = False,
-             num_workers = 0,
-             collate_fn = TransformerTestingDataset.collate_fn,
-         )
-
-         # run model and extract embeddings
-         embeddings: list[dict[str, Any]] = []
-         for x_batch, _ in ldr:
-             # mount data to the proper device
-             x_batch = {k: x_batch[k].to(self.device) for k in self.src_modalities}
-
-             # forward
-             out: dict[str, Tensor] = self.net_.forward_emb(x_batch, is_embedding)
-
-             # convert output from dict-of-lists to a list of dicts, then append
-             tmp = {k: out[k].detach().cpu().numpy() for k in self.src_modalities}
-             tmp = [{k: tmp[k][i] for k in self.src_modalities} for i in range(len(next(iter(tmp.values()))))]
-             embeddings += tmp
-
-         # remove imputed embeddings
-         for i in range(len(x)):
-             avail = [k for k, v in x[i].items() if v is not None]
-             embeddings[i] = {k: embeddings[i][k] for k in avail}
-
-         return embeddings
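A usage sketch for _extract_embedding above (x_test is a hypothetical list of sample dicts): features that were None in the input are dropped from each returned dict, so downstream code only sees embeddings of genuinely observed features.

embs = model._extract_embedding(x_test, _batch_size=8)
print({k: v.shape for k, v in embs[0].items()})   # only the features present in x_test[0]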
adrd/nn/__init__.py DELETED
@@ -1,12 +0,0 @@
- from .transformer import Transformer
- from .vitautoenc import ViTAutoEnc
- from .unet import UNet3D
- from .unet_3d import UNet3DBase
- from .focal_loss import SigmoidFocalLoss
- from .unet_img_model import ImageModel
- from .img_model_wrapper import ImagingModelWrapper
- from .resnet_img_model import ResNetModel
- from .c3d import C3D
- from .dense_net import DenseNet
- from .cnn_resnet3d import CNNResNet3D
- from .cnn_resnet3d_with_linear_classifier import CNNResNet3DWithLinearClassifier
 
adrd/nn/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (937 Bytes)
 
adrd/nn/__pycache__/blocks.cpython-311.pyc DELETED
Binary file (2.85 kB)
 
adrd/nn/__pycache__/c3d.cpython-311.pyc DELETED
Binary file (5.15 kB)
 
adrd/nn/__pycache__/cnn_resnet3d.cpython-311.pyc DELETED
Binary file (4.35 kB)
 
adrd/nn/__pycache__/cnn_resnet3d_with_linear_classifier.cpython-311.pyc DELETED
Binary file (4.03 kB)
 
adrd/nn/__pycache__/dense_net.cpython-311.pyc DELETED
Binary file (13.8 kB)
 
adrd/nn/__pycache__/focal_loss.cpython-311.pyc DELETED
Binary file (6.2 kB)
 
adrd/nn/__pycache__/img_model_wrapper.cpython-311.pyc DELETED
Binary file (8.67 kB)
 
adrd/nn/__pycache__/net_resnet3d.cpython-311.pyc DELETED
Binary file (17.1 kB)
 
adrd/nn/__pycache__/resnet3d.cpython-311.pyc DELETED
Binary file (13.2 kB)
 
adrd/nn/__pycache__/resnet_img_model.cpython-311.pyc DELETED
Binary file (2.83 kB)
 
adrd/nn/__pycache__/selfattention.cpython-311.pyc DELETED
Binary file (3.56 kB)
 
adrd/nn/__pycache__/transformer.cpython-311.pyc DELETED
Binary file (14 kB)
 
adrd/nn/__pycache__/unet.cpython-311.pyc DELETED
Binary file (15.8 kB)
 
adrd/nn/__pycache__/unet_3d.cpython-311.pyc DELETED
Binary file (3 kB)
 
adrd/nn/__pycache__/unet_img_model.cpython-311.pyc DELETED
Binary file (14.1 kB)
 
adrd/nn/__pycache__/vitautoenc.cpython-311.pyc DELETED
Binary file (8.59 kB)
 
adrd/nn/blocks.py DELETED
@@ -1,57 +0,0 @@
- # Copyright (c) MONAI Consortium
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #     http://www.apache.org/licenses/LICENSE-2.0
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from monai.networks.blocks.mlp import MLPBlock
- from typing import Sequence, Union
- import torch
- import torch.nn as nn
-
- from ..nn.selfattention import SABlock
-
- class TransformerBlock(nn.Module):
-     """
-     A transformer block, based on: "Dosovitskiy et al.,
-     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
-     """
-
-     def __init__(
-         self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
-     ) -> None:
-         """
-         Args:
-             hidden_size: dimension of hidden layer.
-             mlp_dim: dimension of feedforward layer.
-             num_heads: number of attention heads.
-             dropout_rate: fraction of the input units to drop.
-             qkv_bias: apply bias term to the qkv linear layer.
-         """
-         super().__init__()
-
-         if not (0 <= dropout_rate <= 1):
-             raise ValueError("dropout_rate should be between 0 and 1.")
-
-         if hidden_size % num_heads != 0:
-             raise ValueError("hidden_size should be divisible by num_heads.")
-
-         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
-         self.norm1 = nn.LayerNorm(hidden_size)
-         self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
-         self.norm2 = nn.LayerNorm(hidden_size)
-
-     def forward(self, x, return_attention=False):
-         y, attn = self.attn(self.norm1(x))
-         if return_attention:
-             return attn
-         x = x + y
-         x = x + self.mlp(self.norm2(x))
-         return x
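A shape check for the pre-norm block above (a sketch; it assumes MONAI is installed for MLPBlock and that the local SABlock returns an (output, attention) pair, as the forward code implies). Token layout is (batch, sequence, hidden):

import torch

blk = TransformerBlock(hidden_size=64, mlp_dim=256, num_heads=4)
tokens = torch.randn(2, 10, 64)
print(blk(tokens).shape)                    # torch.Size([2, 10, 64])
attn = blk(tokens, return_attention=True)   # attention weights instead of the block output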
 
adrd/nn/c3d.py DELETED
@@ -1,99 +0,0 @@
- # From https://github.com/xmuyzz/3D-CNN-PyTorch/blob/master/models/C3DNet.py
-
- import torch
- import torch.nn as nn
- import sys
- # from icecream import ic
- import math
-
- class C3D(torch.nn.Module):
-
-     def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None):
-         super(C3D, self).__init__()
-         self.conv_group1 = nn.Sequential(
-             nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
-             nn.BatchNorm3d(64),
-             nn.ReLU(),
-             nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
-         self.conv_group2 = nn.Sequential(
-             nn.Conv3d(64, 128, kernel_size=3, padding=1),
-             nn.BatchNorm3d(128),
-             nn.ReLU(),
-             nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
-         self.conv_group3 = nn.Sequential(
-             nn.Conv3d(128, 256, kernel_size=3, padding=1),
-             nn.BatchNorm3d(256),
-             nn.ReLU(),
-             nn.Conv3d(256, 256, kernel_size=3, padding=1),
-             nn.BatchNorm3d(256),
-             nn.ReLU(),
-             nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
-         )
-         self.conv_group4 = nn.Sequential(
-             nn.Conv3d(256, 512, kernel_size=3, padding=1),
-             nn.BatchNorm3d(512),
-             nn.ReLU(),
-             nn.Conv3d(512, 512, kernel_size=3, padding=1),
-             nn.BatchNorm3d(512),
-             nn.ReLU(),
-             nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
-         )
-
-         # last_duration = int(math.floor(128 / 16))
-         # last_size = int(math.ceil(128 / 32))
-         self.fc1 = nn.Sequential(
-             nn.Linear(512 * 15 * 9 * 9, 512),
-             nn.ReLU(),
-             nn.Dropout(0.5))
-         self.fc2 = nn.Sequential(
-             nn.Linear(512, 256),
-             nn.ReLU(),
-             nn.Dropout(0.5))
-         # self.fc = nn.Sequential(
-         #     nn.Linear(4096, num_classes))
-
-         self.fc = torch.nn.ModuleDict()
-         for k in tgt_modalities:
-             self.fc[k] = torch.nn.Linear(256, 1)
-
-     def forward(self, x):
-         # for k in x.keys():
-         #     x[k] = x[k].to(torch.float32)
-
-         # x = torch.stack([o for o in x.values()], dim=0)[0]
-         # print(x.shape)
-
-         out = self.conv_group1(x)
-         out = self.conv_group2(out)
-         out = self.conv_group3(out)
-         out = self.conv_group4(out)
-         out = out.view(out.size(0), -1)
-         # print(out.shape)
-         out = self.fc1(out)
-         out = self.fc2(out)
-         # out = self.fc(out)
-
-         tgt_iter = self.fc.keys()
-         out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter}
-         return out_tgt
-
-
- if __name__ == "__main__":
-     model = C3D(tgt_modalities=['NC', 'MCI', 'DE'])
-     print(model)
-     x = torch.rand((1, 1, 128, 128, 128))
-     # layers = list(model.features.named_children())
-     # features = nn.Sequential(*list(model.features.children()))(x)
-     # print(features.shape)
-     print(sum(p.numel() for p in model.parameters()))
-     # layer_found = False
-     # features = None
-     # desired_layer_name = 'transition3'
-
-     # for name, layer in layers:
-     #     if name == desired_layer_name:
-     #         x = layer(x)
-     #         print(x)
-     # model(x)
-     # print(features)
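The hard-coded 512 * 15 * 9 * 9 in fc1 above only matches one input geometry. A quick probe (a sketch) shows that a 128-cubed volume is what the layer expects: pool1 halves only height and width, pools 2 through 4 halve everything, and the last pool pads height and width.

import torch

net = C3D(tgt_modalities=['NC', 'MCI', 'DE'])
with torch.no_grad():
    f = net.conv_group1(torch.zeros(1, 1, 128, 128, 128))
    f = net.conv_group4(net.conv_group3(net.conv_group2(f)))
print(f.shape)   # torch.Size([1, 512, 15, 9, 9]), which flattens to 512 * 15 * 9 * 9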
 
adrd/nn/cnn_resnet3d.py DELETED
@@ -1,81 +0,0 @@
- import torch
- import torch.nn as nn
- from typing import Any, Type
- Tensor = Type[torch.Tensor]
-
- from .resnet3d import r3d_18
-
-
- class CNNResNet3D(nn.Module):
-
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]]
-     ) -> None:
-         """ ... """
-         super().__init__()
-
-         # resnet
-         # embedding modules for source
-         self.modules_emb_src = nn.ModuleDict()
-         for k, info in src_modalities.items():
-             if info['type'] == 'imaging' and len(info['img_shape']) == 4:
-                 self.modules_emb_src[k] = nn.Sequential(
-                     r3d_18(),
-                     nn.Dropout(0.5)
-                 )
-             else:
-                 # unrecognized
-                 raise ValueError('{} is an unrecognized data modality'.format(k))
-
-         # classifiers (binary only)
-         self.modules_cls = nn.ModuleDict()
-         for k, info in tgt_modalities.items():
-             if info['type'] == 'categorical' and info['num_categories'] == 2:
-                 # categorical
-                 self.modules_cls[k] = nn.Linear(256, 1)
-             else:
-                 # unrecognized
-                 raise ValueError
-
-     def forward(self,
-         x: dict[str, Tensor],
-     ) -> dict[str, Tensor]:
-         """ ... """
-         out_emb = self.forward_emb(x)
-         out_emb = out_emb[list(out_emb.keys())[0]]
-         out_cls = self.forward_cls(out_emb)
-         return out_cls
-
-     def forward_emb(self,
-         x: dict[str, Tensor],
-     ) -> dict[str, Tensor]:
-         """ ... """
-         out_emb = dict()
-         for k in self.modules_emb_src.keys():
-             out_emb[k] = self.modules_emb_src[k](x[k])
-         return out_emb
-
-     def forward_cls(self,
-         out_emb: dict[str, Tensor]
-     ) -> dict[str, Tensor]:
-         """ ... """
-         out_cls = dict()
-         for k in self.modules_cls.keys():
-             out_cls[k] = self.modules_cls[k](out_emb).squeeze(1)
-         return out_cls
-
-
- # for testing purposes only
- if __name__ == '__main__':
-     src_modalities = {
-         'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}
-     }
-     tgt_modalities = {
-         'AD': {'type': 'categorical', 'num_categories': 2},
-         'PD': {'type': 'categorical', 'num_categories': 2}
-     }
-     net = CNNResNet3D(src_modalities, tgt_modalities)
-     net.eval()
-     x = {'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)}
-     print(net(x))
 
adrd/nn/cnn_resnet3d_with_linear_classifier.py DELETED
@@ -1,56 +0,0 @@
- import torch
- import torch.nn as nn
- from typing import Any, Type
- Tensor = Type[torch.Tensor]
-
- from .resnet3d import r3d_18
-
- class CNNResNet3DWithLinearClassifier(nn.Module):
-
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]]
-     ) -> None:
-         """ ... """
-         super().__init__()
-         self.core = _CNNResNet3DWithLinearClassifier(len(tgt_modalities))
-         self.src_modalities = src_modalities
-         self.tgt_modalities = tgt_modalities
-
-     def forward(self,
-         x: dict[str, Tensor],
-     ) -> dict[str, Tensor]:
-         """ x is expected to be a singleton dictionary """
-         src_k = list(x.keys())[0]
-         x = x[src_k]
-         out = self.core(x)
-         out = {tgt_k: out[:, i] for i, tgt_k in enumerate(self.tgt_modalities)}
-         return out
-
-
- class _CNNResNet3DWithLinearClassifier(nn.Module):
-
-     def __init__(self,
-         len_tgt_modalities: int,
-     ) -> None:
-         """ ... """
-         super().__init__()
-         self.cnn = r3d_18()
-         self.cls = nn.Sequential(
-             nn.Dropout(0.5),
-             nn.Linear(256, len_tgt_modalities),
-         )
-
-     def forward(self, x: Tensor) -> Tensor:
-         """ ... """
-         out_emb = self.forward_emb(x)
-         out_cls = self.forward_cls(out_emb)
-         return out_cls
-
-     def forward_emb(self, x: Tensor) -> Tensor:
-         """ ... """
-         return self.cnn(x)
-
-     def forward_cls(self, out_emb: Tensor) -> Tensor:
-         """ ... """
-         return self.cls(out_emb)
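The wrapper above takes a singleton input dict and fans the (batch, num_targets) output of the core back into per-target columns; the unwrapping step in isolation:

import torch

tgt_modalities = ['AD', 'PD']             # illustrative target names
out = torch.randn(4, 2)                   # what self.core(x) returns for a batch of 4
out_dict = {k: out[:, i] for i, k in enumerate(tgt_modalities)}
print({k: v.shape for k, v in out_dict.items()})   # each target: torch.Size([4])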
 
adrd/nn/dense_net.py DELETED
@@ -1,211 +0,0 @@
1
- # This implementation is based on the DenseNet-BC implementation in torchvision
2
- # https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
3
- # https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py
4
-
5
-
6
- import math
7
- import torch
8
- import numpy as np
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- import torch.utils.checkpoint as cp
12
- from collections import OrderedDict
13
-
14
-
15
- def _bn_function_factory(norm, relu, conv):
16
- def bn_function(*inputs):
17
- concated_features = torch.cat(inputs, 1)
18
- bottleneck_output = conv(relu(norm(concated_features)))
19
- return bottleneck_output
20
-
21
- return bn_function
22
-
23
-
24
- class _DenseLayer(nn.Module):
25
- def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
26
- super(_DenseLayer, self).__init__()
27
- self.add_module('norm1', nn.BatchNorm3d(num_input_features)),
28
- self.add_module('relu1', nn.ReLU(inplace=True)),
29
- self.add_module('conv1', nn.Conv3d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
30
- self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate)),
31
- self.add_module('relu2', nn.ReLU(inplace=True)),
32
- self.add_module('conv2', nn.Conv3d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
33
- self.drop_rate = drop_rate
34
- self.efficient = efficient
35
-
36
- def forward(self, *prev_features):
37
- bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
38
- if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
39
- bottleneck_output = cp.checkpoint(bn_function, *prev_features)
40
- else:
41
- bottleneck_output = bn_function(*prev_features)
42
- new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
43
- if self.drop_rate > 0:
44
- new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
45
- return new_features
46
-
47
-
48
- class _Transition(nn.Sequential):
49
- def __init__(self, num_input_features, num_output_features):
50
- super(_Transition, self).__init__()
51
- self.add_module('norm', nn.BatchNorm3d(num_input_features))
52
- self.add_module('relu', nn.ReLU(inplace=True))
53
- self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
54
- self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))
55
-
56
-
57
- class _DenseBlock(nn.Module):
58
- def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
59
- super(_DenseBlock, self).__init__()
60
- for i in range(num_layers):
61
- layer = _DenseLayer(
62
- num_input_features + i * growth_rate,
63
- growth_rate=growth_rate,
64
- bn_size=bn_size,
65
- drop_rate=drop_rate,
66
- efficient=efficient,
67
- )
68
- self.add_module('denselayer%d' % (i + 1), layer)
69
-
70
- def forward(self, init_features):
71
- features = [init_features]
72
- for name, layer in self.named_children():
73
- new_features = layer(*features)
74
- features.append(new_features)
75
- return torch.cat(features, 1)
76
-
77
-
78
- class DenseNet(nn.Module):
79
- r"""Densenet-BC model class, based on
80
- `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
81
- Args:
82
- growth_rate (int) - how many filters to add each layer (`k` in paper)
83
- block_config (list of 3 or 4 ints) - how many layers in each pooling block
84
- num_init_features (int) - the number of filters to learn in the first convolution layer
85
- bn_size (int) - multiplicative factor for number of bottle neck layers
86
- (i.e. bn_size * k features in the bottleneck layer)
87
- drop_rate (float) - dropout rate after each dense layer
88
- tgt_modalities (list) - list of target modalities
89
- efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
90
- """
91
- # def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
92
- #                  num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 1
-
-     def __init__(self, tgt_modalities, growth_rate=12, block_config=(3, 3, 3), compression=0.5,
-                  num_init_features=16, bn_size=4, drop_rate=0, efficient=False, load_from_ckpt=False): # config 2
-
-         super(DenseNet, self).__init__()
-
-         # First convolution
-         self.features = nn.Sequential(OrderedDict([('conv0', nn.Conv3d(1, num_init_features, kernel_size=7, stride=2, padding=0, bias=False)),]))
-         self.features.add_module('norm0', nn.BatchNorm3d(num_init_features))
-         self.features.add_module('relu0', nn.ReLU(inplace=True))
-         self.features.add_module('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=False))
-         self.tgt_modalities = tgt_modalities
-
-         # Each denseblock
-         num_features = num_init_features
-         for i, num_layers in enumerate(block_config):
-             block = _DenseBlock(
-                 num_layers=num_layers,
-                 num_input_features=num_features,
-                 bn_size=bn_size,
-                 growth_rate=growth_rate,
-                 drop_rate=drop_rate,
-                 efficient=efficient,
-             )
-             self.features.add_module('denseblock%d' % (i + 1), block)
-             num_features = num_features + num_layers * growth_rate
-             # NOTE: `i` never equals len(block_config), so this condition is always
-             # true and a transition layer follows every dense block, including the
-             # last one; the canonical DenseNet uses `i != len(block_config) - 1`.
-             if i != len(block_config):
-                 trans = _Transition(num_input_features=num_features,
-                                     num_output_features=int(num_features * compression))
-                 self.features.add_module('transition%d' % (i + 1), trans)
-                 num_features = int(num_features * compression)
-
-         # Final batch norm
-         self.features.add_module('norm_final', nn.BatchNorm3d(num_features))
-
-         # Classification heads
-         self.tgt = torch.nn.ModuleDict()
-         for k in tgt_modalities:
-             # self.tgt[k] = torch.nn.Linear(621, 1) # config 2
-             self.tgt[k] = torch.nn.Sequential(
-                 torch.nn.Linear(self.test_size(), 256),
-                 torch.nn.ReLU(),
-                 torch.nn.Linear(256, 1)
-             )
-
-         print(f'load_from_ckpt: {load_from_ckpt}')
-         # Initialization
-         if not load_from_ckpt:
-             for name, param in self.named_parameters():
-                 if 'conv' in name and 'weight' in name:
-                     n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
-                     param.data.normal_().mul_(math.sqrt(2. / n))
-                 elif 'norm' in name and 'weight' in name:
-                     param.data.fill_(1)
-                 elif 'norm' in name and 'bias' in name:
-                     param.data.fill_(0)
-                 elif ('classifier' in name or 'tgt' in name) and 'bias' in name:
-                     param.data.fill_(0)
-
-         # self.size = self.test_size()
-
-     def forward(self, x, shap=True):
-         # print(x.shape)
-         features = self.features(x)
-         # print(features.shape)
-         out = F.relu(features, inplace=True)
-         # out = F.adaptive_avg_pool3d(out, (1, 1, 1))
-         out = torch.flatten(out, 1)
-
-         # print(out.shape)
-
-         # out_tgt = self.tgt(out).squeeze(1)
-         # print(out_tgt)
-         # return F.softmax(out_tgt)
-
-         tgt_iter = self.tgt.keys()
-         out_tgt = {k: self.tgt[k](out).squeeze(1) for k in tgt_iter}
-         if shap:
-             out_tgt = torch.stack(list(out_tgt.values()))
-             return out_tgt.T
-         else:
-             return out_tgt
-
-     def test_size(self):
-         # probe the backbone with a dummy volume to size the flattened features
-         case = torch.ones((1, 1, 182, 218, 182))
-         output = self.features(case).view(-1).size(0)
-         return output
-
-
- if __name__ == "__main__":
-     model = DenseNet(
-         tgt_modalities=['NC', 'MCI', 'DE'],
-         growth_rate=12,
-         block_config=(2, 3, 2),
-         compression=0.5,
-         num_init_features=16,
-         drop_rate=0.2)
-     print(model)
-     torch.manual_seed(42)
-     x = torch.rand((1, 1, 182, 218, 182))
-     # layers = list(model.features.named_children())
-     features = nn.Sequential(*list(model.features.children()))(x)
-     print(features.shape)
-     print(sum(p.numel() for p in model.parameters()))
-     # out = mdl.net_(x, shap=False)
-     # print(out)
-
-     out = model(x, shap=False)
-     print(out)
-     # layer_found = False
-     # features = None
-     # desired_layer_name = 'transition3'
-
-     # for name, layer in layers:
-     #     if name == desired_layer_name:
-     #         x = layer(x)
-     #         print(x)
-     # model(x)
-     # print(features)
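
A minimal usage sketch for the class above, assuming it is in scope; the label names and volume shape are illustrative, not fixed by this file. With shap=False the model returns a dict of per-task logits that are typically squashed independently:

    import torch

    model = DenseNet(tgt_modalities=['NC', 'MCI', 'DE'], block_config=(2, 3, 2))
    model.eval()
    with torch.no_grad():
        logits = model(torch.rand(1, 1, 182, 218, 182), shap=False)
    probs = {k: torch.sigmoid(v) for k, v in logits.items()}  # independent binary heads
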
adrd/nn/focal_loss.py DELETED
@@ -1,120 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import sys
-
- class SigmoidFocalLoss(nn.Module):
-     ''' ... '''
-     def __init__(
-         self,
-         alpha: float = -1,
-         gamma: float = 2.0,
-         reduction: str = 'mean',
-     ):
-         ''' ... '''
-         super().__init__()
-         self.alpha = alpha
-         self.gamma = gamma
-         self.reduction = reduction
-
-     def forward(self, input, target):
-         ''' ... '''
-         p = torch.sigmoid(input)
-         ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
-         p_t = p * target + (1 - p) * (1 - target)
-         loss = ce_loss * ((1 - p_t) ** self.gamma)
-
-         if self.alpha >= 0:
-             alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
-             loss = alpha_t * loss
-
-         if self.reduction == 'mean':
-             loss = loss.mean()
-         elif self.reduction == 'sum':
-             loss = loss.sum()
-
-         return loss
-
-
- class SigmoidFocalLossBeta(nn.Module):
-     ''' ... '''
-     def __init__(
-         self,
-         beta: float = 0.9999,
-         gamma: float = 2.0,
-         num_per_cls = (1, 1),
-         reduction: str = 'mean',
-     ):
-         ''' ... '''
-         super().__init__()
-         eps = sys.float_info.epsilon
-         self.gamma = gamma
-         self.reduction = reduction
-
-         # class-balanced weights ("effective number of samples" per class)
-         self.weight_neg = ((1 - beta) / (1 - beta ** num_per_cls[0] + eps))
-         self.weight_pos = ((1 - beta) / (1 - beta ** num_per_cls[1] + eps))
-         # weight_neg = (1 - beta) / (1 - beta ** num_per_cls[0])
-         # weight_pos = (1 - beta) / (1 - beta ** num_per_cls[1])
-         # self.weight_neg = weight_neg / (weight_neg + weight_pos)
-         # self.weight_pos = weight_pos / (weight_neg + weight_pos)
-
-     def forward(self, input, target):
-         ''' ... '''
-         p = torch.sigmoid(input)
-         p_t = p * target + (1 - p) * (1 - target)
-         ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
-         loss = ce_loss * ((1 - p_t) ** self.gamma)
-
-         alpha_t = self.weight_pos * target + self.weight_neg * (1 - target)
-         loss = alpha_t * loss
-
-         if self.reduction == 'mean':
-             loss = loss.mean()
-         elif self.reduction == 'sum':
-             loss = loss.sum()
-
-         return loss
-
- class AsymmetricLoss(nn.Module):
-     def __init__(self, gamma_neg=4, gamma_pos=1, alpha=0.5, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
-         super(AsymmetricLoss, self).__init__()
-         self.alpha = alpha
-         self.gamma_neg = gamma_neg
-         self.gamma_pos = gamma_pos
-         self.clip = clip
-         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
-         self.eps = eps
-
-
-     def forward(self, x, y):
-         """
-         Parameters
-         ----------
-         x: input logits
-         y: targets (multi-label binarized vector)
-         """
-         # Calculating Probabilities
-         x_sigmoid = torch.sigmoid(x)
-         xs_pos = x_sigmoid
-         xs_neg = 1 - x_sigmoid
-         # Asymmetric Clipping
-         if self.clip is not None and self.clip > 0:
-             xs_neg = (xs_neg + self.clip).clamp(max=1)
-         # Basic CE calculation
-         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
-         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
-         loss = self.alpha*los_pos + (1-self.alpha)*los_neg
-         # Asymmetric Focusing
-         if self.gamma_neg > 0 or self.gamma_pos > 0:
-             if self.disable_torch_grad_focal_loss:
-                 torch.set_grad_enabled(False)
-             pt0 = xs_pos * y
-             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
-             pt = pt0 + pt1
-             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
-             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
-             if self.disable_torch_grad_focal_loss:
-                 torch.set_grad_enabled(True)
-             loss *= one_sided_w
-         # returns the elementwise negative log-likelihood; reduction is left to the caller
-         return -loss  # .sum()
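
A minimal usage sketch for these losses; the values are illustrative, and `num_per_cls` is the per-class sample count the class-balanced weighting assumes:

    import torch

    # toy logits/targets for a heavily imbalanced binary task
    logits = torch.randn(16)
    target = torch.zeros(16)
    target[0] = 1.0

    loss_fn = SigmoidFocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
    print(loss_fn(logits, target))

    # the class-balanced variant weights by effective number of samples per class
    cb_fn = SigmoidFocalLossBeta(beta=0.9999, gamma=2.0, num_per_cls=(15, 1))
    print(cb_fn(logits, target))
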
adrd/nn/img_model_wrapper.py DELETED
@@ -1,174 +0,0 @@
- import torch
- from .. import nn
- from .. import model
- import numpy as np
- from icecream import ic
- from monai.networks.nets.swin_unetr import SwinUNETR
- from typing import Any
-
- class ImagingModelWrapper(torch.nn.Module):
-     def __init__(
-         self,
-         arch: str = 'ViTAutoEnc',
-         tgt_modalities: dict | None = {},
-         img_size: int | None = 128,
-         patch_size: int | None = 16,
-         ckpt_path: str | None = None,
-         train_backbone: bool = False,
-         out_dim: int = 128,
-         layers: int | None = 1,
-         device: str = 'cpu',
-         fusion_stage: str = 'middle',
-     ):
-         super(ImagingModelWrapper, self).__init__()
-
-         self.arch = arch
-         self.tgt_modalities = tgt_modalities
-         self.img_size = img_size
-         self.patch_size = patch_size
-         self.train_backbone = train_backbone
-         self.ckpt_path = ckpt_path
-         self.device = device
-         self.out_dim = out_dim
-         self.layers = layers
-         self.fusion_stage = fusion_stage
-
-
-         if "swinunetr" in self.arch.lower():
-             if "emb" not in self.arch.lower():
-                 ckpt_path = '/projectnb/ivc-ml/dlteif/pretrained_models/model_swinvit.pt'
-                 ckpt = torch.load(ckpt_path, map_location='cpu')
-                 self.img_model = SwinUNETR(
-                     in_channels=1,
-                     out_channels=1,
-                     img_size=128,
-                     feature_size=48,
-                     use_checkpoint=True,
-                 )
-                 ckpt["state_dict"] = {k.replace("swinViT.", "module."): v for k, v in ckpt["state_dict"].items()}
-                 ic(ckpt["state_dict"].keys())
-                 self.img_model.load_from(ckpt)
-             self.dim = 768
-
-         elif "vit" in self.arch.lower():
-             if "emb" not in self.arch.lower():
-                 # Initialize image model
-                 self.img_model = nn.__dict__[self.arch](
-                     in_channels = 1,
-                     img_size = self.img_size,
-                     patch_size = self.patch_size,
-                 )
-
-                 if self.ckpt_path:
-                     self.img_model.load(self.ckpt_path, map_location=self.device)
-                 self.dim = self.img_model.hidden_size
-             else:
-                 self.dim = 768
-
-         if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
-             dim = self.dim
-             if self.fusion_stage == 'middle':
-                 downsample = torch.nn.ModuleList()
-                 # print('Number of layers: ', self.layers)
-                 for i in range(self.layers):
-                     if i == self.layers - 1:
-                         dim_out = self.out_dim
-                         # print(layers)
-                         ks = 2
-                         stride = 2
-                     else:
-                         dim_out = dim // 2
-                         ks = 2
-                         stride = 2
-
-                     downsample.append(
-                         torch.nn.Conv1d(in_channels=dim, out_channels=dim_out, kernel_size=ks, stride=stride)
-                     )
-
-                     dim = dim_out
-
-                 downsample.append(
-                     torch.nn.BatchNorm1d(dim)
-                 )
-                 downsample.append(
-                     torch.nn.ReLU()
-                 )
-
-
-                 self.downsample = torch.nn.Sequential(*downsample)
-             elif self.fusion_stage == 'late':
-                 self.downsample = torch.nn.Identity()
-             else:
-                 pass
-
-             # print('Downsample layers: ', self.downsample)
-
-         elif "densenet" in self.arch.lower():
-             if "emb" not in self.arch.lower():
-                 self.img_model = model.ImagingModel.from_ckpt(self.ckpt_path, device=self.device, img_backend=self.arch, load_from_ckpt=True).net_
-
-             self.downsample = torch.nn.Linear(3900, self.out_dim)
-
-             # randomly initialize weights for downsample block
-             for p in self.downsample.parameters():
-                 if p.dim() > 1:
-                     torch.nn.init.xavier_uniform_(p)
-                 p.requires_grad = True
-
-         if "emb" not in self.arch.lower():
-             # freeze imaging model parameters
-             if "densenet" in self.arch.lower():
-                 for n, p in self.img_model.features.named_parameters():
-                     if not self.train_backbone:
-                         p.requires_grad = False
-                     else:
-                         p.requires_grad = True
-                 for n, p in self.img_model.tgt.named_parameters():
-                     p.requires_grad = False
-             else:
-                 for n, p in self.img_model.named_parameters():
-                     # print(n, p.requires_grad)
-                     if not self.train_backbone:
-                         p.requires_grad = False
-                     else:
-                         p.requires_grad = True
-
-     def forward(self, x):
-         # print("--------ImagingModelWrapper forward--------")
-         if "emb" not in self.arch.lower():
-             if "swinunetr" in self.arch.lower():
-                 # print(x.size())
-                 out = self.img_model(x)
-                 # print(out.size())
-                 out = self.downsample(out)
-                 # print(out.size())
-                 out = torch.mean(out, dim=-1)
-                 # print(out.size())
-             elif "vit" in self.arch.lower():
-                 out = self.img_model(x, return_emb=True)
-                 ic(out.size())
-                 out = self.downsample(out)
-                 out = torch.mean(out, dim=-1)
-             elif "densenet" in self.arch.lower():
-                 out = torch.nn.Sequential(*list(self.img_model.features.children()))(x)
-                 # print(out.size())
-                 out = torch.flatten(out, 1)
-                 out = self.downsample(out)
-         else:
-             # print(x.size())
-             if "swinunetr" in self.arch.lower():
-                 x = torch.squeeze(x, dim=1)
-                 x = x.view(x.size(0), self.dim, -1)
-             # print('x: ', x.size())
-             out = self.downsample(x)
-             # print('out: ', out.size())
-             if self.fusion_stage == 'middle':
-                 if "vit" in self.arch.lower() or "swinunetr" in self.arch.lower():
-                     out = torch.mean(out, dim=-1)
-                 else:
-                     out = torch.squeeze(out, dim=1)
-             elif self.fusion_stage == 'late':
-                 pass
-
-         return out
-
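
A minimal usage sketch: the branching above only tests substrings of `arch`, so any name containing "vit" and "emb" selects the precomputed-embedding path (no backbone is built). Shapes here are illustrative assumptions, (batch, hidden=768, tokens):

    import torch

    wrapper = ImagingModelWrapper(arch='ViTEmb', out_dim=128, layers=2, fusion_stage='middle')
    emb = torch.randn(4, 768, 64)     # precomputed ViT token embeddings
    feat = wrapper(emb)               # Conv1d downsampling stack + token-mean pooling
    print(feat.shape)                 # -> torch.Size([4, 128])
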
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adrd/nn/net_resnet3d.py DELETED
@@ -1,338 +0,0 @@
- """
- Created on Sat Nov 21 10:49:39 2021
-
- @author: cxue2
- """
-
- import torch.nn as nn
-
-
- __all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']
-
-
- class Conv3DSimple(nn.Conv3d):
-     def __init__(self,
-                  in_planes,
-                  out_planes,
-                  midplanes=None,
-                  stride=1,
-                  padding=1):
-
-         super(Conv3DSimple, self).__init__(
-             in_channels=in_planes,
-             out_channels=out_planes,
-             kernel_size=(3, 3, 3),
-             stride=stride,
-             padding=padding,
-             bias=False)
-
-     @staticmethod
-     def get_downsample_stride(stride):
-         return stride, stride, stride
-
-
- class Conv2Plus1D(nn.Sequential):
-
-     def __init__(self,
-                  in_planes,
-                  out_planes,
-                  midplanes,
-                  stride=1,
-                  padding=1):
-         super(Conv2Plus1D, self).__init__(
-             nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
-                       stride=(1, stride, stride), padding=(0, padding, padding),
-                       bias=False),
-             nn.BatchNorm3d(midplanes),
-             nn.ReLU(inplace=True),
-             nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
-                       stride=(stride, 1, 1), padding=(padding, 0, 0),
-                       bias=False))
-
-     @staticmethod
-     def get_downsample_stride(stride):
-         return stride, stride, stride
-
-
- class Conv3DNoTemporal(nn.Conv3d):
-
-     def __init__(self,
-                  in_planes,
-                  out_planes,
-                  midplanes=None,
-                  stride=1,
-                  padding=1):
-
-         super(Conv3DNoTemporal, self).__init__(
-             in_channels=in_planes,
-             out_channels=out_planes,
-             kernel_size=(1, 3, 3),
-             stride=(1, stride, stride),
-             padding=(0, padding, padding),
-             bias=False)
-
-     @staticmethod
-     def get_downsample_stride(stride):
-         return 1, stride, stride
-
-
- class BasicBlock(nn.Module):
-
-     expansion = 1
-
-     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
-         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
-
-         super(BasicBlock, self).__init__()
-         self.conv1 = nn.Sequential(
-             conv_builder(inplanes, planes, midplanes, stride),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-         self.conv2 = nn.Sequential(
-             conv_builder(planes, planes, midplanes),
-             nn.BatchNorm3d(planes)
-         )
-         self.relu = nn.ReLU(inplace=True)
-         self.downsample = downsample
-         self.stride = stride
-
-     def forward(self, x):
-         residual = x
-
-         out = self.conv1(x)
-         out = self.conv2(out)
-         if self.downsample is not None:
-             residual = self.downsample(x)
-
-         out += residual
-         out = self.relu(out)
-
-         return out
-
-
- class Bottleneck(nn.Module):
-     expansion = 4
-
-     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
-
-         super(Bottleneck, self).__init__()
-         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
-
-         # 1x1x1
-         self.conv1 = nn.Sequential(
-             nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-         # Second kernel
-         self.conv2 = nn.Sequential(
-             conv_builder(planes, planes, midplanes, stride),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-
-         # 1x1x1
-         self.conv3 = nn.Sequential(
-             nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
-             nn.BatchNorm3d(planes * self.expansion)
-         )
-         self.relu = nn.ReLU(inplace=True)
-         self.downsample = downsample
-         self.stride = stride
-
-     def forward(self, x):
-         residual = x
-
-         out = self.conv1(x)
-         out = self.conv2(out)
-         out = self.conv3(out)
-
-         if self.downsample is not None:
-             residual = self.downsample(x)
-
-         out += residual
-         out = self.relu(out)
-
-         return out
-
-
- class BasicStem(nn.Sequential):
-     """The default conv-batchnorm-relu stem
-     """
-     def __init__(self):
-         super(BasicStem, self).__init__(
-             nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
-                       padding=(3, 3, 3), bias=False),
-             nn.BatchNorm3d(64),
-             nn.ReLU(inplace=True))
-
-
- class R2Plus1dStem(nn.Sequential):
-     """R(2+1)D stem is different than the default one as it uses separated 3D convolution
-     """
-     def __init__(self):
-         super(R2Plus1dStem, self).__init__(
-             nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
-                       stride=(1, 2, 2), padding=(0, 3, 3),
-                       bias=False),
-             nn.BatchNorm3d(45),
-             nn.ReLU(inplace=True),
-             nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
-                       stride=(1, 1, 1), padding=(1, 0, 0),
-                       bias=False),
-             nn.BatchNorm3d(64),
-             nn.ReLU(inplace=True))
-
-
- class VideoResNet(nn.Module):
-
-     def __init__(self, block, conv_makers, layers,
-                  stem, num_classes=16,
-                  zero_init_residual=False):
-         """Generic resnet video generator.
-         Args:
-             block (nn.Module): resnet building block
-             conv_makers (list(functions)): generator function for each layer
-             layers (List[int]): number of blocks per layer
-             stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
-             num_classes (int, optional): Dimension of the final FC layer. Defaults to 16.
-             zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
-         """
-         super(VideoResNet, self).__init__()
-         self.inplanes = 64
-
-         self.stem = stem()
-
-         self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
-         self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
-         self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
-         self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)
-
-         self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
-         self.fc = nn.Linear(256 * block.expansion, num_classes)
-
-         # init weights
-         self._initialize_weights()
-
-         if zero_init_residual:
-             for m in self.modules():
-                 if isinstance(m, Bottleneck):
-                     nn.init.constant_(m.bn3.weight, 0)
-
-     def forward(self, x):
-         x = self.stem(x)
-
-         x = self.layer1(x)
-         x = self.layer2(x)
-         x = self.layer3(x)
-         x = self.layer4(x)
-
-         x = self.avgpool(x)
-         # Flatten the layer to fc
-         x = x.flatten(1)
-         x = self.fc(x)
-
-         return x
-
-     def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
-         downsample = None
-
-         if stride != 1 or self.inplanes != planes * block.expansion:
-             ds_stride = conv_builder.get_downsample_stride(stride)
-             downsample = nn.Sequential(
-                 nn.Conv3d(self.inplanes, planes * block.expansion,
-                           kernel_size=1, stride=ds_stride, bias=False),
-                 nn.BatchNorm3d(planes * block.expansion)
-             )
-         layers = []
-         layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
-
-         self.inplanes = planes * block.expansion
-         for i in range(1, blocks):
-             layers.append(block(self.inplanes, planes, conv_builder))
-
-         return nn.Sequential(*layers)
-
-     def _initialize_weights(self):
-         for m in self.modules():
-             if isinstance(m, nn.Conv3d):
-                 nn.init.kaiming_normal_(m.weight, mode='fan_out',
-                                         nonlinearity='relu')
-                 if m.bias is not None:
-                     nn.init.constant_(m.bias, 0)
-             elif isinstance(m, nn.BatchNorm3d):
-                 nn.init.constant_(m.weight, 1)
-                 nn.init.constant_(m.bias, 0)
-             elif isinstance(m, nn.Linear):
-                 nn.init.normal_(m.weight, 0, 0.01)
-                 nn.init.constant_(m.bias, 0)
-
-
- def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
-     # NOTE: `pretrained` and `progress` are accepted for API compatibility but
-     # are not used here; no weights are downloaded.
-     model = VideoResNet(**kwargs)
-
-     return model
-
-
- def r3d_18(pretrained=False, progress=True, **kwargs):
-     """Construct 18 layer Resnet3D model as in
-     https://arxiv.org/abs/1711.11248
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on Kinetics-400
-         progress (bool): If True, displays a progress bar of the download to stderr
-     Returns:
-         nn.Module: R3D-18 network
-     """
-
-     return _video_resnet('r3d_18',
-                          pretrained, progress,
-                          block=BasicBlock,
-                          conv_makers=[Conv3DSimple] * 4,
-                          layers=[2, 2, 2, 2],
-                          stem=BasicStem, **kwargs)
-
-
- def mc3_18(pretrained=False, progress=True, **kwargs):
-     """Constructor for 18 layer Mixed Convolution network as in
-     https://arxiv.org/abs/1711.11248
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on Kinetics-400
-         progress (bool): If True, displays a progress bar of the download to stderr
-     Returns:
-         nn.Module: MC3 Network definition
-     """
-     return _video_resnet('mc3_18',
-                          pretrained, progress,
-                          block=BasicBlock,
-                          conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
-                          layers=[2, 2, 2, 2],
-                          stem=BasicStem, **kwargs)
-
-
- def r2plus1d_18(pretrained=False, progress=True, **kwargs):
-     """Constructor for the 18 layer deep R(2+1)D network as in
-     https://arxiv.org/abs/1711.11248
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on Kinetics-400
-         progress (bool): If True, displays a progress bar of the download to stderr
-     Returns:
-         nn.Module: R(2+1)D-18 network
-     """
-     return _video_resnet('r2plus1d_18',
-                          pretrained, progress,
-                          block=BasicBlock,
-                          conv_makers=[Conv2Plus1D] * 4,
-                          layers=[2, 2, 2, 2],
-                          stem=R2Plus1dStem, **kwargs)
-
-
- if __name__ == '__main__':
-
-     import torch
-
-     net = r3d_18().to(0)
-     x = torch.zeros(3, 1, 182, 218, 182).to(0)
-
-     print(net(x).shape)
-     print(net)
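
A CPU-only sketch of the same smoke test; note that `BasicStem` takes single-channel MRI volumes while `R2Plus1dStem` expects 3-channel input, so `r3d_18` and `r2plus1d_18` are not interchangeable on this data:

    import torch

    net = r3d_18(num_classes=16).eval()
    with torch.no_grad():
        print(net(torch.zeros(1, 1, 182, 218, 182)).shape)   # -> torch.Size([1, 16])
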
adrd/nn/resnet3d.py DELETED
@@ -1,256 +0,0 @@
- """
- Simplified from torchvision.models.video.r3d_18. The citation information is
- shown below.
-
- @article{DBLP:journals/corr/abs-1711-11248,
-   author    = {Du Tran and
-                Heng Wang and
-                Lorenzo Torresani and
-                Jamie Ray and
-                Yann LeCun and
-                Manohar Paluri},
-   title     = {A Closer Look at Spatiotemporal Convolutions for Action Recognition},
-   journal   = {CoRR},
-   volume    = {abs/1711.11248},
-   year      = {2017},
-   url       = {http://arxiv.org/abs/1711.11248},
-   archivePrefix = {arXiv},
-   eprint    = {1711.11248},
-   timestamp = {Mon, 13 Aug 2018 16:46:39 +0200},
-   biburl    = {https://dblp.org/rec/journals/corr/abs-1711-11248.bib},
-   bibsource = {dblp computer science bibliography, https://dblp.org}
- }
- """
-
- import torch.nn as nn
-
-
- class Conv3DSimple(nn.Conv3d):
-     def __init__(self,
-                  in_planes,
-                  out_planes,
-                  midplanes=None,
-                  stride=1,
-                  padding=1):
-
-         super().__init__(
-             in_channels=in_planes,
-             out_channels=out_planes,
-             kernel_size=(3, 3, 3),
-             stride=stride,
-             padding=padding,
-             bias=False)
-
-     @staticmethod
-     def get_downsample_stride(stride):
-         return stride, stride, stride
-
-
- class BasicBlock(nn.Module):
-
-     expansion = 1
-
-     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
-         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
-
-         super(BasicBlock, self).__init__()
-         self.conv1 = nn.Sequential(
-             conv_builder(inplanes, planes, midplanes, stride),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-         self.conv2 = nn.Sequential(
-             conv_builder(planes, planes, midplanes),
-             nn.BatchNorm3d(planes)
-         )
-         self.relu = nn.ReLU(inplace=True)
-         self.downsample = downsample
-         self.stride = stride
-
-     def forward(self, x):
-         residual = x
-
-         out = self.conv1(x)
-         out = self.conv2(out)
-         if self.downsample is not None:
-             residual = self.downsample(x)
-
-         out += residual
-         out = self.relu(out)
-
-         return out
-
-
- class Bottleneck(nn.Module):
-     expansion = 4
-
-     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
-
-         super(Bottleneck, self).__init__()
-         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
-
-         # 1x1x1
-         self.conv1 = nn.Sequential(
-             nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-         # Second kernel
-         self.conv2 = nn.Sequential(
-             conv_builder(planes, planes, midplanes, stride),
-             nn.BatchNorm3d(planes),
-             nn.ReLU(inplace=True)
-         )
-
-         # 1x1x1
-         self.conv3 = nn.Sequential(
-             nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
-             nn.BatchNorm3d(planes * self.expansion)
-         )
-         self.relu = nn.ReLU(inplace=True)
-         self.downsample = downsample
-         self.stride = stride
-
-     def forward(self, x):
-         residual = x
-
-         out = self.conv1(x)
-         out = self.conv2(out)
-         out = self.conv3(out)
-
-         if self.downsample is not None:
-             residual = self.downsample(x)
-
-         out += residual
-         out = self.relu(out)
-
-         return out
-
-
- class BasicStem(nn.Sequential):
-     """The default conv-batchnorm-relu stem
-     """
-     def __init__(self):
-         super(BasicStem, self).__init__(
-             nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2),
-                       padding=(3, 3, 3), bias=False),
-             nn.BatchNorm3d(64),
-             nn.ReLU(inplace=True))
-
-
- class VideoResNet(nn.Module):
-
-     def __init__(self, block, conv_makers, layers,
-                  stem, num_classes=2,
-                  zero_init_residual=False):
-         """Generic resnet video generator.
-         Args:
-             block (nn.Module): resnet building block
-             conv_makers (list(functions)): generator function for each layer
-             layers (List[int]): number of blocks per layer
-             stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
-             num_classes (int, optional): Dimension of the final FC layer. Defaults to 2 (unused; the FC head is disabled below).
-             zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
-         """
-         super(VideoResNet, self).__init__()
-         self.inplanes = 64
-
-         self.stem = stem()
-
-         self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
-         self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
-         self.layer3 = self._make_layer(block, conv_makers[2], 192, layers[2], stride=2)
-         self.layer4 = self._make_layer(block, conv_makers[3], 256, layers[3], stride=2)
-
-         self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
-         # self.fc = nn.Linear(256 * block.expansion, num_classes)
-
-         # init weights
-         self._initialize_weights()
-
-         if zero_init_residual:
-             for m in self.modules():
-                 if isinstance(m, Bottleneck):
-                     nn.init.constant_(m.bn3.weight, 0)
-
-     def forward(self, x):
-         x = self.stem(x)
-
-         x = self.layer1(x)
-         x = self.layer2(x)
-         x = self.layer3(x)
-         x = self.layer4(x)
-
-         x = self.avgpool(x)
-         # Flatten the layer to fc; the fc head stays disabled, so this
-         # returns pooled backbone features rather than class logits
-         x = x.flatten(1)
-         # x = self.fc(x)
-
-         return x
-
-     def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
-         downsample = None
-
-         if stride != 1 or self.inplanes != planes * block.expansion:
-             ds_stride = conv_builder.get_downsample_stride(stride)
-             downsample = nn.Sequential(
-                 nn.Conv3d(self.inplanes, planes * block.expansion,
-                           kernel_size=1, stride=ds_stride, bias=False),
-                 nn.BatchNorm3d(planes * block.expansion)
-             )
-         layers = []
-         layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
-
-         self.inplanes = planes * block.expansion
-         for i in range(1, blocks):
-             layers.append(block(self.inplanes, planes, conv_builder))
-
-         return nn.Sequential(*layers)
-
-     def _initialize_weights(self):
-         for m in self.modules():
-             if isinstance(m, nn.Conv3d):
-                 nn.init.kaiming_normal_(m.weight, mode='fan_out',
-                                         nonlinearity='relu')
-                 if m.bias is not None:
-                     nn.init.constant_(m.bias, 0)
-             elif isinstance(m, nn.BatchNorm3d):
-                 nn.init.constant_(m.weight, 1)
-                 nn.init.constant_(m.bias, 0)
-             elif isinstance(m, nn.Linear):
-                 nn.init.normal_(m.weight, 0, 0.01)
-                 nn.init.constant_(m.bias, 0)
-
-
- def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
-     model = VideoResNet(**kwargs)
-
-     return model
-
-
- def r3d_18(pretrained=False, progress=True, **kwargs):
-     """Construct 18 layer Resnet3D model as in
-     https://arxiv.org/abs/1711.11248
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on Kinetics-400
-         progress (bool): If True, displays a progress bar of the download to stderr
-     Returns:
-         nn.Module: R3D-18 network
-     """
-
-     return _video_resnet('r3d_18',
-                          pretrained, progress,
-                          block=BasicBlock,
-                          conv_makers=[Conv3DSimple] * 4,
-                          layers=[2, 2, 2, 2],
-                          stem=BasicStem, **kwargs)
-
-
- if __name__ == '__main__':
-     """ ... """
-     import torch
-
-     net = r3d_18().to('cuda:1')
-     x = torch.zeros(8, 1, 182, 218, 182).to('cuda:1')
-
-     print(net(x).shape)
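
Because this variant disables the FC head, it behaves as a featurizer; a CPU-only sketch of what the forward pass returns:

    import torch

    net = r3d_18().eval()                          # featurizer: fc head is commented out
    with torch.no_grad():
        feats = net(torch.zeros(1, 1, 182, 218, 182))
    print(feats.shape)                             # -> torch.Size([1, 256]) after global avg pooling
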
adrd/nn/resnet_img_model.py DELETED
@@ -1,81 +0,0 @@
- import torch
- import torch.nn as nn
- import sys
- from icecream import ic
- # sys.path.append('/home/skowshik/ADRD_repo/adrd_tool/adrd/')
- from .net_resnet3d import r3d_18
- # from dev.data.dataset_csv import CSVDataset
-
-
- class ResNetModel(nn.Module):
-     ''' ... '''
-     def __init__(
-         self,
-         tgt_modalities,
-         mri_feature = 'img_MRI_T1',
-     ):
-         ''' ... '''
-         super().__init__()
-
-         self.mri_feature = mri_feature
-
-         self.img_net_ = r3d_18()
-
-         # self.modules_emb_src = nn.Sequential(
-         #     nn.BatchNorm1d(9),
-         #     nn.Linear(9, d_model)
-         # )
-
-         # classifiers (binary only)
-         self.modules_cls = nn.ModuleDict()
-         for k, info in tgt_modalities.items():
-             if info['type'] == 'categorical' and info['num_categories'] == 2:
-                 # categorical
-                 # NOTE: this head expects a 64-dim backbone output, but r3d_18()
-                 # as defined in net_resnet3d.py defaults to num_classes=16; the
-                 # two only line up if the backbone is built with a 64-dim head.
-                 self.modules_cls[k] = nn.Linear(64, 1)
-
-             else:
-                 # unrecognized
-                 raise ValueError
-
-     def forward(self, x):
-         ''' ... '''
-         tgt_iter = self.modules_cls.keys()
-
-         img_x_batch = x[self.mri_feature]
-         img_out = self.img_net_(img_x_batch)
-
-         # ic(img_out.shape)
-
-         # run linear classifiers
-         out = [self.modules_cls[k](img_out).squeeze(1) for i, k in enumerate(tgt_iter)]
-         out = torch.stack(out, dim=1)
-
-         # ic(out.shape)
-
-         # out to dict
-         out = {k: out[:, i] for i, k in enumerate(tgt_iter)}
-
-         return out
-
-
- if __name__ == '__main__':
-     ''' for testing purpose only '''
-     # import torch
-     # import numpy as np
-
-     # seed = 0
-     # print('Loading training dataset ... ')
-     # dat_trn = CSVDataset(mode=0, split=[1, 700], seed=seed)
-     # print(len(dat_trn))
-     # tgt_modalities = dat_trn.label_modalities
-     # net = ResNetModel(tgt_modalities).to('cuda')
-     # x = dat_trn.features
-     # x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])).to('cuda') for k in x[0]}
-     # ic(x)
-
-
-     # # print(net(x).shape)
-     # print(net(x))
-
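
A construction-only sketch with an illustrative `tgt_modalities` dict (the forward pass is omitted because of the head-dimension mismatch noted in the comment above):

    tgt = {'NC': {'type': 'categorical', 'num_categories': 2}}
    model = ResNetModel(tgt_modalities=tgt)
    # forward usage would be: model({'img_MRI_T1': vol}) -> {'NC': logits of shape (batch,)}
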
adrd/nn/selfattention.py DELETED
@@ -1,62 +0,0 @@
- # Copyright (c) MONAI Consortium
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #     http://www.apache.org/licenses/LICENSE-2.0
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from monai.utils import optional_import
- import torch
- import torch.nn as nn
-
-
- Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
-
-
- class SABlock(nn.Module):
-     """
-     A self-attention block, based on: "Dosovitskiy et al.,
-     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
-     """
-
-     def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
-         """
-         Args:
-             hidden_size: dimension of hidden layer.
-             num_heads: number of attention heads.
-             dropout_rate: fraction of the input units to drop.
-             qkv_bias: bias term for the qkv linear layer.
-
-         """
-
-         super().__init__()
-
-         if not (0 <= dropout_rate <= 1):
-             raise ValueError("dropout_rate should be between 0 and 1.")
-
-         if hidden_size % num_heads != 0:
-             raise ValueError("hidden size should be divisible by num_heads.")
-
-         self.num_heads = num_heads
-         self.out_proj = nn.Linear(hidden_size, hidden_size)
-         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
-         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
-         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
-         self.drop_output = nn.Dropout(dropout_rate)
-         self.drop_weights = nn.Dropout(dropout_rate)
-         self.head_dim = hidden_size // num_heads
-         self.scale = self.head_dim**-0.5
-
-     def forward(self, x):
-         output = self.input_rearrange(self.qkv(x))
-         q, k, v = output[0], output[1], output[2]
-         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
-         att_mat = self.drop_weights(att_mat)
-         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
-         x = self.out_rearrange(x)
-         x = self.out_proj(x)
-         x = self.drop_output(x)
-         return x, att_mat
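
A minimal usage sketch, assuming einops is installed (MONAI imports `Rearrange` from it lazily); the block returns both the projected tokens and the attention matrix:

    import torch

    blk = SABlock(hidden_size=128, num_heads=8, dropout_rate=0.1)
    tokens = torch.randn(2, 50, 128)      # (batch, sequence, hidden)
    out, att = blk(tokens)
    print(out.shape, att.shape)           # (2, 50, 128) and (2, 8, 50, 50)
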
adrd/nn/transformer.py DELETED
@@ -1,268 +0,0 @@
- import torch
- import numpy as np
- from .. import nn
- # from ..nn import ImagingModelWrapper
- from .net_resnet3d import r3d_18
- from typing import Any, Type
- import math
- Tensor = Type[torch.Tensor]
- from icecream import ic
- ic.disable()
-
- class Transformer(torch.nn.Module):
-     ''' ... '''
-     def __init__(self,
-         src_modalities: dict[str, dict[str, Any]],
-         tgt_modalities: dict[str, dict[str, Any]],
-         d_model: int,
-         nhead: int,
-         num_encoder_layers: int = 1,
-         num_decoder_layers: int = 1,
-         device: str = 'cpu',
-         cuda_devices: list = [3],
-         img_net: str | None = None,
-         layers: int = 3,
-         img_size: int | None = 128,
-         patch_size: int | None = 16,
-         imgnet_ckpt: str | None = None,
-         train_imgnet: bool = False,
-         fusion_stage: str = 'middle',
-     ) -> None:
-         ''' ... '''
-         super().__init__()
-
-         self.d_model = d_model
-         self.nhead = nhead
-         self.num_encoder_layers = num_encoder_layers
-         self.num_decoder_layers = num_decoder_layers
-         self.img_net = img_net
-         self.img_size = img_size
-         self.patch_size = patch_size
-         self.imgnet_ckpt = imgnet_ckpt
-         self.train_imgnet = train_imgnet
-         self.layers = layers
-         self.src_modalities = src_modalities
-         self.tgt_modalities = tgt_modalities
-         self.device = device
-         self.fusion_stage = fusion_stage
-
-         # embedding modules for source
-
-         self.modules_emb_src = torch.nn.ModuleDict()
-         print('Downsample layers: ', self.layers)
-         self.img_model = nn.ImagingModelWrapper(arch=self.img_net, img_size=self.img_size, patch_size=self.patch_size, ckpt_path=self.imgnet_ckpt, train_backbone=self.train_imgnet, layers=self.layers, out_dim=self.d_model, device=self.device, fusion_stage=self.fusion_stage)
-
-         for k, info in src_modalities.items():
-             # ic(k)
-             # for key, val in info.items():
-             #     ic(key, val)
-             if info['type'] == 'categorical':
-                 self.modules_emb_src[k] = torch.nn.Embedding(info['num_categories'], d_model)
-             elif info['type'] == 'numerical':
-                 self.modules_emb_src[k] = torch.nn.Sequential(
-                     torch.nn.BatchNorm1d(info['shape'][0]),
-                     torch.nn.Linear(info['shape'][0], d_model)
-                 )
-             elif info['type'] == 'imaging':
-                 # print(info['shape'], info['img_shape'])
-                 if self.img_net:
-                     self.modules_emb_src[k] = self.img_model
-
-             else:
-                 # unrecognized
-                 raise ValueError('{} is an unrecognized data modality'.format(k))
-
-         # positional encoding
-         self.pe = PositionalEncoding(d_model)
-
-         # auxiliary embedding vectors for targets
-         self.emb_aux = torch.nn.Parameter(
-             torch.zeros(len(tgt_modalities), 1, d_model),
-             requires_grad = True,
-         )
-
-         # transformer
-         enc = torch.nn.TransformerEncoderLayer(
-             self.d_model, self.nhead,
-             dim_feedforward = self.d_model,
-             activation = 'gelu',
-             dropout = 0.3,
-         )
-         self.transformer = torch.nn.TransformerEncoder(enc, self.num_encoder_layers)
-
-
-         # classifiers (binary only)
-         self.modules_cls = torch.nn.ModuleDict()
-         for k, info in tgt_modalities.items():
-             if info['type'] == 'categorical' and info['num_categories'] == 2:
-                 self.modules_cls[k] = torch.nn.Linear(d_model, 1)
-             else:
-                 # unrecognized
-                 raise ValueError
-
-         # for n,p in self.named_parameters():
-         #     print(n, p.requires_grad)
-
-     def forward(self,
-         x: dict[str, Tensor],
-         mask: dict[str, Tensor],
-         # x_img: dict[str, Tensor] | Any = None,
-         skip_embedding: dict[str, bool] | None = None,
-         return_out_emb: bool = False,
-     ) -> dict[str, Tensor]:
-         """ ... """
-
-         out_emb = self.forward_emb(x, mask, skip_embedding)
-         if self.fusion_stage == "late":
-             # NOTE: the imaging embeddings must be split off *before* out_emb is
-             # overwritten, otherwise img_out_emb would always be empty.
-             img_out_emb = {k: v for k, v in out_emb.items() if "img_MRI" in k}
-             out_emb = {k: v for k, v in out_emb.items() if "img_MRI" not in k}
-             # for k, v in out_emb.items():
-             #     print(k, v.size())
-             mask_nonimg = {k: v for k, v in mask.items() if "img_MRI" not in k}
-             out_trf = self.forward_trf(out_emb, mask_nonimg)  # (8,128) + (8,50,128)
-             # print("out_trf: ", out_trf.size())
-             # NOTE: the late-fusion merge below was left unfinished in this file:
-             # torch.concatenate() is called without the tensors (presumably
-             # out_trf and img_out_emb) it was meant to join, so this branch
-             # raises at runtime.
-             out_trf = torch.concatenate()
-         else:
-             out_trf = self.forward_trf(out_emb, mask)
-
-         out_cls = self.forward_cls(out_trf)
-
-         if return_out_emb:
-             return out_emb, out_cls
-         return out_cls
-
-     def forward_emb(self,
-         x: dict[str, Tensor],
-         mask: dict[str, Tensor],
-         skip_embedding: dict[str, bool] | None = None,
-         # x_img: dict[str, Tensor] | Any = None,
-     ) -> dict[str, Tensor]:
-         """ ... """
-         # print("-------forward_emb--------")
-         out_emb = dict()
-         for k in self.modules_emb_src.keys():
-             if skip_embedding is None or k not in skip_embedding or not skip_embedding[k]:
-                 if "img_MRI" in k:
-                     # print("img_MRI in ", k)
-                     if torch.all(mask[k]):
-                         if "swinunetr" in self.img_net.lower() and self.fusion_stage == 'late':
-                             out_emb[k] = torch.zeros((1, 768, 4, 4, 4))
-                         else:
-                             if 'cuda' in self.device:
-                                 device = x[k].device
-                                 # print(device)
-                             else:
-                                 device = self.device
-                             out_emb[k] = torch.zeros((mask[k].shape[0], self.d_model)).to(device, non_blocking=True)
-                             # print("mask is True, out_emb[k]: ", out_emb[k].size())
-                     else:
-                         # print("calling modules_emb_src...")
-                         out_emb[k] = self.modules_emb_src[k](x[k])
-                         # print("mask is False, out_emb[k]: ", out_emb[k].size())
-
-                 else:
-                     out_emb[k] = self.modules_emb_src[k](x[k])
-
-                 # out_emb[k] = self.modules_emb_src[k](x[k])
-             else:
-                 out_emb[k] = x[k]
-         return out_emb
-
-     def forward_trf(self,
-         out_emb: dict[str, Tensor],
-         mask: dict[str, Tensor],
-     ) -> dict[str, Tensor]:
-         """ ... """
-         # print('-----------forward_trf----------')
-         N = len(next(iter(out_emb.values())))  # batch size
-         S = len(self.modules_emb_src)  # number of sources
-         T = len(self.modules_cls)  # number of targets
-         if self.fusion_stage == 'late':
-             src_iter = [k for k in self.modules_emb_src.keys() if "img_MRI" not in k]
-             S = len(src_iter)  # number of sources
-
-         else:
-             src_iter = self.modules_emb_src.keys()
-         tgt_iter = self.modules_cls.keys()
-
-         emb_src = torch.stack([o for o in out_emb.values()], dim=0)
-         # print('emb_src: ', emb_src.size())
-
-         self.pe.index = -1
-         emb_src = self.pe(emb_src)
-         # print('emb_src + pe: ', emb_src.size())
-
-         # target embedding
-         # print('emb_aux: ', self.emb_aux.size())
-         emb_tgt = self.emb_aux.repeat(1, N, 1)
-         # print('emb_tgt: ', emb_tgt.size())
-
-         # concatenate source embeddings and target embeddings
-         emb_all = torch.concatenate((emb_tgt, emb_src), dim=0)
-
-         # combine masks
-         mask_src = [mask[k] for k in src_iter]
-         mask_src = torch.stack(mask_src, dim=1)
-
-         # target masks
-         mask_tgt = torch.zeros((N, T), dtype=torch.bool, device=self.emb_aux.device)
-
-         # concatenate source masks and target masks
-         mask_all = torch.concatenate((mask_tgt, mask_src), dim=1)
-
-         # repeat mask_all to fit transformer
-         mask_all = mask_all.unsqueeze(1).expand(-1, S + T, -1).repeat(self.nhead, 1, 1)
-
-         # run transformer
-         out_trf = self.transformer(
-             src = emb_all,
-             mask = mask_all,
-         )[0]
-         # print('out_trf: ', out_trf.size())
-         # out_trf = {k: out_trf[i] for i, k in enumerate(tgt_iter)}
-         return out_trf
-
-     def forward_cls(self,
-         out_trf: dict[str, Tensor],
-     ) -> dict[str, Tensor]:
-         """ ... """
-         tgt_iter = self.modules_cls.keys()
-         out_cls = {k: self.modules_cls[k](out_trf).squeeze(1) for k in tgt_iter}
-         return out_cls
-
- class PositionalEncoding(torch.nn.Module):
-
-     def __init__(self,
-         d_model: int,
-         max_len: int = 512
-     ):
-         """ ... """
-         super().__init__()
-         position = torch.arange(max_len).unsqueeze(1)
-         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
-         pe = torch.zeros(max_len, 1, d_model)
-         pe[:, 0, 0::2] = torch.sin(position * div_term)
-         pe[:, 0, 1::2] = torch.cos(position * div_term)
-         self.register_buffer('pe', pe)
-         self.index = -1
-
-     def forward(self, x: Tensor, pe_type: str = 'non_img') -> Tensor:
-         """
-         Arguments:
-             x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
-         """
-         # print('pe: ', self.pe.size())
-         # print('x: ', x.size())
-         if pe_type == 'img':
-             self.index += 1
-             return x + self.pe[self.index]
-         else:
-             self.index += 1
-             return x + self.pe[self.index:x.size(0)+self.index]
-
-
- if __name__ == '__main__':
-     ''' for testing purpose only '''
-     pass
-
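
A minimal sketch of the stateful positional encoding above; the internal cursor must be reset before each pass, exactly as forward_trf does with `self.pe.index = -1`:

    import torch

    pe = PositionalEncoding(d_model=64)
    seq = torch.zeros(10, 4, 64)          # (seq_len, batch, d_model)
    pe.index = -1                         # reset the cursor before each pass
    out = pe(seq)                         # adds pe[0:10] to the sequence
    print(out.shape)                      # -> torch.Size([10, 4, 64])
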
adrd/nn/unet.py DELETED
@@ -1,232 +0,0 @@
- import numpy as np
- import torch
- import torch.nn as nn
- import torchvision
- from torchvision import models
- from torch.nn import init
- import torch.nn.functional as F
- from icecream import ic
-
-
- class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
-     def _check_input_dim(self, input):
-
-         if input.dim() != 5:
-             raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
-         # super(ContBatchNorm3d, self)._check_input_dim(input)
-
-     def forward(self, input):
-         self._check_input_dim(input)
-         # always normalizes with batch statistics (training=True), even in eval mode
-         return F.batch_norm(
-             input, self.running_mean, self.running_var, self.weight, self.bias,
-             True, self.momentum, self.eps)
-
-
- class LUConv(nn.Module):
-     def __init__(self, in_chan, out_chan, act):
-         super(LUConv, self).__init__()
-         self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
-         self.bn1 = ContBatchNorm3d(out_chan)
-
-         if act == 'relu':
-             # nn.ReLU takes an `inplace` flag; the original passed `out_chan`,
-             # which a truthy int silently turned into inplace=True
-             self.activation = nn.ReLU(inplace=True)
-         elif act == 'prelu':
-             self.activation = nn.PReLU(out_chan)
-         elif act == 'elu':
-             self.activation = nn.ELU(inplace=True)
-         else:
-             raise ValueError('unsupported activation: {}'.format(act))
-
-     def forward(self, x):
-         out = self.activation(self.bn1(self.conv1(x)))
-         return out
-
-
- def _make_nConv(in_channel, depth, act, double_channel=False):
-     if double_channel:
-         layer1 = LUConv(in_channel, 32 * (2 ** (depth+1)), act)
-         layer2 = LUConv(32 * (2 ** (depth+1)), 32 * (2 ** (depth+1)), act)
-     else:
-         layer1 = LUConv(in_channel, 32*(2**depth), act)
-         layer2 = LUConv(32*(2**depth), 32*(2**depth)*2, act)
-
-     return nn.Sequential(layer1, layer2)
-
-
- class DownTransition(nn.Module):
-     def __init__(self, in_channel, depth, act):
-         super(DownTransition, self).__init__()
-         self.ops = _make_nConv(in_channel, depth, act)
-         self.maxpool = nn.MaxPool3d(2)
-         self.current_depth = depth
-
-     def forward(self, x):
-         if self.current_depth == 3:
-             out = self.ops(x)
-             out_before_pool = out
-         else:
-             out_before_pool = self.ops(x)
-             out = self.maxpool(out_before_pool)
-         return out, out_before_pool
-
- class UpTransition(nn.Module):
-     def __init__(self, inChans, outChans, depth, act):
-         super(UpTransition, self).__init__()
-         self.depth = depth
-         self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
-         self.ops = _make_nConv(inChans + outChans//2, depth, act, double_channel=True)
-
-     def forward(self, x, skip_x):
-         out_up_conv = self.up_conv(x)
-         concat = torch.cat((out_up_conv, skip_x), 1)
-         out = self.ops(concat)
-         return out
-
- class OutputTransition(nn.Module):
-     def __init__(self, inChans, n_labels):
-
-         super(OutputTransition, self).__init__()
-         self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)
-         self.sigmoid = nn.Sigmoid()
-
-     def forward(self, x):
-         out = self.sigmoid(self.final_conv(x))
-         return out
-
- class ConvLayer(nn.Module):
-     def __init__(self, in_channels, out_channels, drop_rate, kernel, pooling, BN=True, relu_type='leaky'):
-         super().__init__()
-         kernel_size, kernel_stride, kernel_padding = kernel
-         pool_kernel, pool_stride, pool_padding = pooling
-         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, kernel_stride, kernel_padding, bias=False)
-         self.pooling = nn.MaxPool3d(pool_kernel, pool_stride, pool_padding)
-         self.BN = nn.BatchNorm3d(out_channels)
-         self.relu = nn.LeakyReLU(inplace=False) if relu_type=='leaky' else nn.ReLU(inplace=False)
-         self.dropout = nn.Dropout(drop_rate, inplace=False)
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.pooling(x)
-         x = self.BN(x)
-         x = self.relu(x)
-         x = self.dropout(x)
-         return x
-
- class AttentionModule(nn.Module):
-     def __init__(self, in_channels, out_channels, drop_rate=0.1):
-         super(AttentionModule, self).__init__()
-         self.conv = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=False)
-         self.attention = ConvLayer(in_channels, out_channels, drop_rate, (1, 1, 0), (1, 1, 0))
-
-     def forward(self, x, return_attention=True):
-         feats = self.conv(x)
-         # softmax over channels (dim=1), matching the old implicit-dim behavior
-         att = F.softmax(self.attention(x), dim=1)
-
-         out = feats * att
-
-         if return_attention:
-             return att, out
-
-         return out
-
- class UNet3D(nn.Module):
-     # the number of convolutions in each layer corresponds
-     # to what is in the actual prototxt, not the intent
-     def __init__(self, n_class=1, act='relu', pretrained=False, input_size=(1,1,182,218,182), attention=False, drop_rate=0.1, blocks=4):
-         super(UNet3D, self).__init__()
-
-         self.blocks = blocks
-         self.down_tr64 = DownTransition(1, 0, act)
-         self.down_tr128 = DownTransition(64, 1, act)
-         self.down_tr256 = DownTransition(128, 2, act)
-         self.down_tr512 = DownTransition(256, 3, act)
-
-         self.up_tr256 = UpTransition(512, 512, 2, act)
-         self.up_tr128 = UpTransition(256, 256, 1, act)
-         self.up_tr64 = UpTransition(128, 128, 0, act)
-         self.out_tr = OutputTransition(64, 1)
-
-         self.pretrained = pretrained
-         self.attention = attention
-         if pretrained:
-             print("Using image pretrained model checkpoint")
-             weight_dir = '/home/skowshik/ADRD_repo/img_pretrained_ckpt/Genesis_Chest_CT.pt'
-             checkpoint = torch.load(weight_dir)
-             state_dict = checkpoint['state_dict']
-             unParalled_state_dict = {}
-             for key in state_dict.keys():
-                 unParalled_state_dict[key.replace("module.", "")] = state_dict[key]
-             self.load_state_dict(unParalled_state_dict)
-             # the decoder half is only needed to load the full-UNet checkpoint
-             del self.up_tr256
-             del self.up_tr128
-             del self.up_tr64
-             del self.out_tr
-
-         if self.blocks == 5:
-             self.down_tr1024 = DownTransition(512, 4, act)
-
-
-         # self.conv1 = nn.Conv3d(512, 256, 1, 1, 0, bias=False)
-         # self.conv2 = nn.Conv3d(256, 128, 1, 1, 0, bias=False)
-         # self.conv3 = nn.Conv3d(128, 64, 1, 1, 0, bias=False)
-
-         if attention:
-             self.attention_module = AttentionModule(1024 if self.blocks==5 else 512, n_class, drop_rate=drop_rate)
-         # Output.
-         self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))
-
-         dummy_inp = torch.rand(input_size)
-         dummy_feats = self.forward(dummy_inp, stage='get_features')
-         dummy_feats = dummy_feats[0]
-         self.in_features = list(dummy_feats.shape)
-         ic(self.in_features)
-
-         self._init_weights()
-
-     def _init_weights(self):
-         if not self.pretrained:
-             for m in self.modules():
-                 if isinstance(m, nn.Conv3d):
-                     init.kaiming_normal_(m.weight)
-                 elif isinstance(m, ContBatchNorm3d):
-                     init.constant_(m.weight, 1)
-                     init.constant_(m.bias, 0)
-                 elif isinstance(m, nn.Linear):
-                     init.kaiming_normal_(m.weight)
-                     init.constant_(m.bias, 0)
-         elif self.attention:
-             for m in self.attention_module.modules():
-                 if isinstance(m, nn.Conv3d):
-                     init.kaiming_normal_(m.weight)
-                 elif isinstance(m, nn.BatchNorm3d):
-                     init.constant_(m.weight, 1)
-                     init.constant_(m.bias, 0)
-         else:
-             pass
-         # Zero initialize the last batchnorm in each residual branch.
-         # for m in self.modules():
-         #     if isinstance(m, BottleneckBlock):
-         #         init.constant_(m.out_conv.bn.weight, 0)
-
-     def forward(self, x, stage='normal', attention=False):
-         ic('backbone forward')
-         self.out64, self.skip_out64 = self.down_tr64(x)
-         self.out128, self.skip_out128 = self.down_tr128(self.out64)
-         self.out256, self.skip_out256 = self.down_tr256(self.out128)
-         self.out512, self.skip_out512 = self.down_tr512(self.out256)
-         if self.blocks == 5:
-             self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
-             ic(self.out1024.shape)
-         # self.out = self.conv1(self.out512)
-         # self.out = self.conv2(self.out)
-         # self.out = self.conv3(self.out)
-         # self.out = self.conv(self.out)
-         ic(hasattr(self, 'attention_module'))
-         if hasattr(self, 'attention_module'):
-             att, feats = self.attention_module(self.out1024 if self.blocks==5 else self.out512)
-         else:
-             feats = self.out1024 if self.blocks==5 else self.out512
-         ic(feats.shape)
-         if attention:
-             return att, feats
-         return feats
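
A minimal encoder-only sketch; note that the constructor probes itself with `input_size` to record `in_features`, so a small illustrative volume keeps the probe cheap:

    import torch

    net = UNet3D(pretrained=False, attention=True, input_size=(1, 1, 64, 64, 64))
    att, feats = net(torch.rand(1, 1, 64, 64, 64), attention=True)
    print(net.in_features, feats.shape)   # e.g. [1, 8, 8, 8] and torch.Size([1, 1, 8, 8, 8])
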
adrd/nn/unet_3d.py DELETED
@@ -1,63 +0,0 @@
- import sys
- sys.path.append('..')
- # from feature_extractor.for_image_data.backbone import CNN_GAP, ResNet3D, UNet3D
- import torch
- import torch.nn as nn
- from torchvision import models
- import torch.nn.functional as F
- # from . import UNet3D
- from .unet import UNet3D
- from icecream import ic
-
-
- class UNet3DBase(nn.Module):
-     def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
-         super(UNet3DBase, self).__init__()
-         model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)
-
-         self.blocks = blocks
-
-         self.down_tr64 = model.down_tr64
-         self.down_tr128 = model.down_tr128
-         self.down_tr256 = model.down_tr256
-         self.down_tr512 = model.down_tr512
-         if self.blocks == 5:
-             self.down_tr1024 = model.down_tr1024
-         # self.block_modules = nn.ModuleList([self.down_tr64, self.down_tr128, self.down_tr256, self.down_tr512])
-
-         self.in_features = model.in_features
-         # ic(attention)
-         if attention:
-             self.attention_module = model.attention_module
-             # self.attention_module = AttentionModule(512, n_class, drop_rate=drop_rate)
-         # self.avgpool = nn.AvgPool3d((6,7,6), stride=(6,6,6))
-
-     def forward(self, x, stage='normal', attention=False):
-         # ic('UNet3DBase forward')
-         self.out64, self.skip_out64 = self.down_tr64(x)
-         # ic(self.out64.shape, self.skip_out64.shape)
-         self.out128, self.skip_out128 = self.down_tr128(self.out64)
-         # ic(self.out128.shape, self.skip_out128.shape)
-         self.out256, self.skip_out256 = self.down_tr256(self.out128)
-         # ic(self.out256.shape, self.skip_out256.shape)
-         self.out512, self.skip_out512 = self.down_tr512(self.out256)
-         # ic(self.out512.shape, self.skip_out512.shape)
-         if self.blocks == 5:
-             self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
-             # ic(self.out1024.shape, self.skip_out1024.shape)
-         # ic(hasattr(self, 'attention_module'))
-         if hasattr(self, 'attention_module'):
-             att, feats = self.attention_module(self.out1024 if self.blocks == 5 else self.out512)
-         else:
-             feats = self.out1024 if self.blocks == 5 else self.out512
-         # ic(feats.shape)
-         if attention:
-             return att, feats
-         return feats
-
-         # self.out_up_256 = self.up_tr256(self.out512, self.skip_out256)
-         # self.out_up_128 = self.up_tr128(self.out_up_256, self.skip_out128)
-         # self.out_up_64 = self.up_tr64(self.out_up_128, self.skip_out64)
-         # self.out = self.out_tr(self.out_up_64)
-
-         # return self.out
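
A minimal usage sketch; note that constructing `UNet3DBase` builds a throwaway `UNet3D`, whose constructor probes itself with a full 182×218×182 volume, so instantiation is memory-heavy:

    import torch

    base = UNet3DBase(attention=True, pretrained=False)
    feats = base(torch.rand(1, 1, 182, 218, 182))
    print(base.in_features, feats.shape)   # bottleneck features from the encoder half
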
adrd/nn/unet_img_model.py DELETED
@@ -1,211 +0,0 @@
- from pyexpat import features
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.cuda.amp import autocast
- import numpy as np
- import re
- from icecream import ic
- import math
- import torch.nn.utils.weight_norm as weightNorm
-
- # from . import UNet3DBase
- from .unet_3d import UNet3DBase
-
-
- def init_weights(m):
-     classname = m.__class__.__name__
-     if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
-         nn.init.kaiming_uniform_(m.weight)
-         nn.init.zeros_(m.bias)
-     elif classname.find('BatchNorm') != -1:
-         nn.init.normal_(m.weight, 1.0, 0.02)
-         nn.init.zeros_(m.bias)
-     elif classname.find('Linear') != -1:
-         nn.init.xavier_normal_(m.weight)
-         nn.init.zeros_(m.bias)
-
- class feat_classifier(nn.Module):
-     def __init__(self, class_num, bottleneck_dim=256, type="linear"):
-         super(feat_classifier, self).__init__()
-         self.type = type
-         # if type in ['conv', 'gap'] and len(bottleneck_dim) > 3:
-         #     bottleneck_dim = bottleneck_dim[-3:]
-         ic(bottleneck_dim)
-         if type == 'wn':
-             self.layer = weightNorm(
-                 nn.Linear(math.prod(bottleneck_dim[1:]), class_num), name="weight")  # in_features must be an int
-             # self.fc.apply(init_weights)
-         elif type == 'gap':
-             if len(bottleneck_dim) > 3:
-                 bottleneck_dim = bottleneck_dim[-3:]
-             self.layer = nn.AvgPool3d(bottleneck_dim, stride=(1, 1, 1))
-         elif type == 'conv':
-             if len(bottleneck_dim) > 3:
-                 bottleneck_dim = bottleneck_dim[-4:]
-             ic(bottleneck_dim)
-             self.layer = nn.Conv3d(bottleneck_dim[0], class_num, kernel_size=bottleneck_dim[1:])
-             ic(self.layer)
-         else:
-             print('bottleneck dim: ', bottleneck_dim)
-             self.layer = nn.Sequential(
-                 torch.nn.Flatten(start_dim=1, end_dim=-1),
-                 nn.Linear(math.prod(bottleneck_dim), class_num)
-             )
-         self.layer.apply(init_weights)
-
-     def forward(self, x):
-         # print('=> feat_classifier forward')
-         # ic(x.size())
-         x = self.layer(x)
-         # ic(x.size())
-         if self.type in ['gap', 'conv']:
-             x = torch.squeeze(x)
-             if len(x.shape) < 2:
-                 x = torch.unsqueeze(x, 0)
-         # print('returning x: ', x.size())
-         return x
-
- class ImageModel(nn.Module):
-     """
-     Empirical Risk Minimization (ERM)
-     """
-
-     def __init__(
-         self,
-         counts=None,
-         classifier='gap',
-         accum_iter=8,
-         save_emb=False,
-         # ssl,
-         num_classes=1,
-         load_img_ckpt=False,
-     ):
-         super(ImageModel, self).__init__()
-         if counts is not None:
-             if isinstance(counts[0], list):
-                 counts = np.stack(counts, axis=0).sum(axis=0)
-                 print('counts: ', counts)
-                 total = np.sum(counts)
-                 print(total / counts)
-                 self.weight = total / torch.FloatTensor(counts)
-             else:
-                 total = sum(counts)
-                 self.weight = torch.FloatTensor([total / c for c in counts])
-         else:
-             self.weight = None
-         print('weight: ', self.weight)
-         # device = torch.device(f'cuda:{args.gpu_id}' if args.gpu_id is not None else 'cpu')
-         self.criterion = nn.CrossEntropyLoss(weight=self.weight)
-         # if ssl:
-         #     # add contrastive loss
-         #     # self.ssl_criterion =
-         #     pass
-
-         self.featurizer = UNet3DBase(n_class=num_classes, attention=True, pretrained=load_img_ckpt)
-         self.classifier = feat_classifier(
-             num_classes, self.featurizer.in_features, classifier)
-
-         self.network = nn.Sequential(
-             self.featurizer, self.classifier)
-         self.accum_iter = accum_iter
-         self.acc_steps = 0
-         self.save_embedding = save_emb
-
-     def update(self, minibatches, opt, sch, scaler):
-         print('--------------def update----------------')
-         device = list(self.parameters())[0].device
-         all_x = torch.cat([data[1].to(device).float() for data in minibatches])
-         all_y = torch.cat([data[2].to(device).long() for data in minibatches])
-         print('all_x: ', all_x.size())
-         # all_p = self.predict(all_x)
-         # all_probs =
-         label_list = all_y.tolist()
-         count = float(len(label_list))
-         ic(count)
-
-         uniques = sorted(list(set(label_list)))
-         ic(uniques)
-         counts = [float(label_list.count(i)) for i in uniques]
-         ic(counts)
-
-         weights = [count / c for c in counts]
-         ic(weights)
-
-         with autocast():
-             loss = self.criterion(self.predict(all_x), all_y)
-             self.acc_steps += 1
-             print('class: ', loss.item())
-
-         scaler.scale(loss / self.accum_iter).backward()
-
-         if self.acc_steps == self.accum_iter:
-             scaler.step(opt)
-             if sch:
-                 sch.step()
-             scaler.update()
-             self.zero_grad()
-             self.acc_steps = 0
-             torch.cuda.empty_cache()
-
-         del all_x
-         del all_y
-         return {'class': loss.item()}, sch
-
-     def forward(self, *args, **kwargs):
-         return self.network(*args, **kwargs)
-
-     def predict(self, x, stage='normal', attention=False):
-         # print('network device: ', list(self.network.parameters())[0].device)
-         # print('x device: ', x.device)
-         if stage == 'get_features' or self.save_embedding:
-             feats = self.network[0](x, attention=attention)
-             output = self.network[1](feats[-1] if attention else feats)
-             return feats, output
-         else:
-             return self.network(x)
-
-     def extract_features(self, x, attention=False):
-         feats = self.network[0](x, attention=attention)
-         return feats
-
-     def load_checkpoint(self, state_dict):
-         try:
-             self.load_checkpoint_helper(state_dict)
-         except Exception:
-             featurizer_dict = {}
-             net_dict = {}
-             for key, val in state_dict.items():
-                 if 'featurizer' in key:
-                     featurizer_dict[key] = val
-                 elif 'network' in key:
-                     net_dict[key] = val
-             self.featurizer.load_state_dict(featurizer_dict)
-             self.classifier.load_state_dict(net_dict)
-
-     def load_checkpoint_helper(self, state_dict):
-         try:
-             self.load_state_dict(state_dict)
-             print('try: loaded')
-         except RuntimeError as e:
-             print('--> except')
-             if 'Missing key(s) in state_dict:' in str(e):
-                 state_dict = {
-                     key.replace('module.', '', 1): value
-                     for key, value in state_dict.items()
-                 }
-                 state_dict = {
-                     key.replace('featurizer.', '', 1).replace('classifier.', '', 1): value
-                     for key, value in state_dict.items()
-                 }
-                 state_dict = {
-                     re.sub(r'network\.[0-9]\.', '', key): value
-                     for key, value in state_dict.items()
-                 }
-                 try:
-                     del state_dict['criterion.weight']
-                 except KeyError:
-                     pass
-                 self.load_state_dict(state_dict)
-
-                 print('except: loaded')
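The ImageModel removed here wires the UNet3DBase featurizer and a feat_classifier head into one nn.Sequential and trains under mixed precision with gradient accumulation: each update() call backpropagates loss / accum_iter, and the optimizer, scheduler, and scaler only step on every accum_iter-th call. A hedged driver sketch follows, assuming a CUDA device, the pre-deletion module path, (id, image, label) minibatch tuples to match the data[1] / data[2] indexing inside update(), and that UNet3DBase.in_features is the bottleneck feature shape expected by feat_classifier; all tensor sizes are illustrative.

import torch
from torch.cuda.amp import GradScaler
from adrd.nn.unet_img_model import ImageModel  # pre-deletion path; removed by the commit

model = ImageModel(num_classes=2, classifier='linear', accum_iter=8).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

# One accumulation step; update() returns the loss dict and the (optional) scheduler.
minibatch = [(None, torch.randn(2, 1, 96, 96, 96), torch.randint(0, 2, (2,)))]
losses, sch = model.update(minibatch, opt, sch=None, scaler=scaler)

load_checkpoint() is the tolerant loader: it first attempts a strict load, then progressively strips the module., featurizer. / classifier., and network.<i>. prefixes from checkpoint keys before retrying, and finally falls back to loading the featurizer and classifier state dicts separately.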
adrd/nn/vitautoenc.py DELETED
@@ -1,163 +0,0 @@
- # Copyright (c) MONAI Consortium
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #     http://www.apache.org/licenses/LICENSE-2.0
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
- from monai.networks.layers import Conv
- from monai.utils import ensure_tuple_rep
-
- from typing import Sequence, Union
- import torch
- import torch.nn as nn
-
- from ..nn.blocks import TransformerBlock
- from icecream import ic
- ic.disable()
-
- __all__ = ["ViTAutoEnc"]
-
-
- class ViTAutoEnc(nn.Module):
-     """
-     Vision Transformer (ViT), based on: "Dosovitskiy et al.,
-     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
-
-     Modified to produce outputs with the same spatial size as the input image.
-     """
-
-     def __init__(
-         self,
-         in_channels: int,
-         img_size: Union[Sequence[int], int],
-         patch_size: Union[Sequence[int], int],
-         out_channels: int = 1,
-         deconv_chns: int = 16,
-         hidden_size: int = 768,
-         mlp_dim: int = 3072,
-         num_layers: int = 12,
-         num_heads: int = 12,
-         pos_embed: str = "conv",
-         dropout_rate: float = 0.0,
-         spatial_dims: int = 3,
-     ) -> None:
-         """
-         Args:
-             in_channels: number of input channels.
-             img_size: dimension of input image.
-             patch_size: dimension of patch size.
-             hidden_size: dimension of hidden layer.
-             out_channels: number of output channels.
-             deconv_chns: number of channels for the deconvolution layers.
-             mlp_dim: dimension of feedforward layer.
-             num_layers: number of transformer blocks.
-             num_heads: number of attention heads.
-             pos_embed: position embedding layer type.
-             dropout_rate: fraction of the input units to drop.
-             spatial_dims: number of spatial dimensions.
-
-         Examples::
-
-             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
-             # It will provide an output of same size as that of the input
-             >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv')
-
-             # for 3-channel with image size of (128,128,128), output will be same size as of input
-             >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv')
-
-         """
-
-         super().__init__()
-
-         self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
-         self.spatial_dims = spatial_dims
-         self.hidden_size = hidden_size
-
-         self.patch_embedding = PatchEmbeddingBlock(
-             in_channels=in_channels,
-             img_size=img_size,
-             patch_size=patch_size,
-             hidden_size=hidden_size,
-             num_heads=num_heads,
-             pos_embed=pos_embed,
-             dropout_rate=dropout_rate,
-             spatial_dims=self.spatial_dims,
-         )
-         self.blocks = nn.ModuleList(
-             [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
-         )
-         self.norm = nn.LayerNorm(hidden_size)
-
-         new_patch_size = [4] * self.spatial_dims
-         conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
-         # self.conv3d_transpose* is to be compatible with existing 3d model weights.
-         self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
-         self.conv3d_transpose_1 = conv_trans(
-             in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
-         )
-
-     def forward(self, x, return_emb=False, return_hiddens=False):
-         """
-         Args:
-             x: input tensor must have isotropic spatial dimensions,
-                 such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
-         """
-         spatial_size = x.shape[2:]
-         x = self.patch_embedding(x)
-         hidden_states_out = []
-         for blk in self.blocks:
-             x = blk(x)
-             hidden_states_out.append(x)
-         x = self.norm(x)
-         x = x.transpose(1, 2)
-         if return_emb:
-             return x
-         d = [s // p for s, p in zip(spatial_size, self.patch_size)]
-         x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
-         x = self.conv3d_transpose(x)
-         x = self.conv3d_transpose_1(x)
-         if return_hiddens:
-             return x, hidden_states_out
-         return x
-
-     def get_last_selfattention(self, x):
-         """
-         Args:
-             x: input tensor must have isotropic spatial dimensions,
-                 such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
-         """
-         x = self.patch_embedding(x)
-         ic(x.size())
-         for i, blk in enumerate(self.blocks):
-             if i < len(self.blocks) - 1:
-                 x = blk(x)
-             else:
-                 return blk(x, return_attention=True)
-
-     def load(self, ckpt_path, map_location='cpu', checkpoint_key='state_dict'):
-         """
-         Args:
-             ckpt_path: path to the pretrained weights
-             map_location: device to load the checkpoint on
-         """
-         state_dict = torch.load(ckpt_path, map_location=map_location)
-         ic(state_dict['epoch'], state_dict['train_loss'])
-         if checkpoint_key in state_dict:
-             print(f"Take key {checkpoint_key} in provided checkpoint dict")
-             state_dict = state_dict[checkpoint_key]
-         # remove `module.` prefix
-         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
-         # remove `backbone.` prefix induced by multicrop wrapper
-         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
-         msg = self.load_state_dict(state_dict, strict=False)
-         print('Pretrained weights found at {} and loaded with msg: {}'.format(ckpt_path, msg))
-
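Finally, the removed ViTAutoEnc runs patch embedding, num_layers transformer blocks, and then two stride-4 transposed convolutions, so the total spatial upsampling factor is 4 * 4 = 16 and the reconstruction only matches the input resolution when the patch size is 16 along each dimension. A minimal sketch under that assumption, using the pre-deletion module path:

import torch
from adrd.nn.vitautoenc import ViTAutoEnc  # pre-deletion path; removed by the commit

net = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
x = torch.randn(1, 1, 96, 96, 96)
recon = net(x)                                # reconstruction, (1, 1, 96, 96, 96)
emb = net(x, return_emb=True)                 # (1, hidden_size, n_patches) token embeddings
recon, hiddens = net(x, return_hiddens=True)  # plus per-block hidden states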