DuyTa commited on
Commit
3b80ddd
1 Parent(s): 44168e6

Upload folder using huggingface_hub

Browse files
__pycache__/trainer.cpython-39.pyc ADDED
Binary file (7.55 kB). View file
 
checkpoints/SegTransVAE/Epoch 59-MeanDiceScore0.7753.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:633f9f9bc994612e455727ff7d5b10ff0b7a3fe369749e55673d4c0c0cc1cc50
3
+ size 711206989
checkpoints/SegTransVAE/Epoch 69-MeanDiceScore0.7905.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9ac5e8e0ccb79ac91b2e5d3dbcb4fe43d501349315e1e38e3791b9ec73ea248
3
+ size 711206989
checkpoints/SegTransVAE/Epoch 79-MeanDiceScore0.7905.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e215f6e96d7bd150f78bf896ddb41d5bd979ef4a4faad3a3185b4bec8944b65
3
+ size 711206989
checkpoints/SegTransVAE/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e215f6e96d7bd150f78bf896ddb41d5bd979ef4a4faad3a3185b4bec8944b65
3
+ size 711206989
dataset/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.66 kB). View file
 
dataset/brats.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from monai.transforms import MapTransform
3
+
4
+
5
+ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
6
+ """
7
+ Convert labels to multi channels based on brats classes:
8
+ label 1 is the necrotic and non-enhancing tumor core
9
+ label 2 is the peritumoral edema
10
+ label 4 is the GD-enhancing tumor
11
+ The possible classes are TC (Tumor core), WT (Whole tumor)
12
+ and ET (Enhancing tumor).
13
+
14
+ """
15
+
16
+ def __call__(self, data):
17
+ d = dict(data)
18
+ for key in self.keys:
19
+ result = []
20
+ # merge label 1 and label 4 to construct TC
21
+ result.append(np.logical_or(d[key] == 1, d[key] == 4))
22
+ # merge labels 1, 2 and 4 to construct WT
23
+ result.append(
24
+ np.logical_or(
25
+ np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
26
+ )
27
+ )
28
+ # label 4 is ET
29
+ result.append(d[key] == 4)
30
+ d[key] = np.stack(result, axis=0).astype(np.float32)
31
+ return d
dataset/utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from sklearn.model_selection import train_test_split
4
+
5
+ from monai.data import DataLoader, Dataset
6
+ from monai import transforms
7
+
8
+ def datafold_read(datalist, basedir, fold=0, key="training"):
9
+ with open(datalist) as f:
10
+ json_data = json.load(f)
11
+
12
+ json_data = json_data[key]
13
+
14
+ for d in json_data:
15
+ for k in d:
16
+ if isinstance(d[k], list):
17
+ d[k] = [os.path.join(basedir, iv) for iv in d[k]]
18
+ elif isinstance(d[k], str):
19
+ d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
20
+
21
+ tr = []
22
+ val = []
23
+ for d in json_data:
24
+ if "fold" in d and d["fold"] == fold:
25
+ val.append(d)
26
+ else:
27
+ tr.append(d)
28
+
29
+ return tr, val
30
+
31
+
32
+ def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :
33
+ train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
34
+ if volume != None :
35
+ train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)
36
+
37
+ train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)
38
+
39
+ validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)
40
+ return train_files, validation_files, test_files
41
+
42
+
43
+ def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):
44
+ train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)
45
+
46
+ train_transform = transforms.Compose(
47
+ [
48
+ transforms.LoadImaged(keys=["image", "label"]),
49
+ transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
50
+ transforms.CropForegroundd(
51
+ keys=["image", "label"],
52
+ source_key="image",
53
+ k_divisible=[roi[0], roi[1], roi[2]],
54
+ ),
55
+ transforms.RandSpatialCropd(
56
+ keys=["image", "label"],
57
+ roi_size=[roi[0], roi[1], roi[2]],
58
+ random_size=False,
59
+ ),
60
+ transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
61
+ transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
62
+ transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
63
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
64
+ transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
65
+ transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
66
+ ]
67
+ )
68
+ val_transform = transforms.Compose(
69
+ [
70
+ transforms.LoadImaged(keys=["image", "label"]),
71
+ transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
72
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
73
+ ]
74
+ )
75
+
76
+ train_ds = Dataset(data=train_files, transform=train_transform)
77
+ train_loader = DataLoader(
78
+ train_ds,
79
+ batch_size=batch_size,
80
+ shuffle=True,
81
+ num_workers=2,
82
+ pin_memory=True,
83
+ )
84
+ val_ds = Dataset(data=validation_files, transform=val_transform)
85
+ val_loader = DataLoader(
86
+ val_ds,
87
+ batch_size=1,
88
+ shuffle=False,
89
+ num_workers=2,
90
+ pin_memory=True,
91
+ )
92
+ test_ds = Dataset(data=test_files, transform=val_transform)
93
+ test_loader = DataLoader(
94
+ test_ds,
95
+ batch_size=1,
96
+ shuffle=False,
97
+ num_workers=2,
98
+ pin_memory=True,
99
+ )
100
+ return train_loader, val_loader,test_loader
logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff29460308634cd2b1327302a88efa226a2d9575614414ccefe089147587478a
3
+ size 3448488
logs/SegTransVAE/version_0/hparams.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
logs/SegTransVAE/version_0/metric_log.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Epoch,Mean Dice Score,Dice TC,Dice WT,Dice ET
2
+ 0,0.004601036664098501,0.0006361556006595492,0.012770041823387146,0.0003969123645219952
3
+ 9,0.6198188066482544,0.574246346950531,0.7122330069541931,0.5729770660400391
4
+ 19,0.6669135689735413,0.6340484023094177,0.7448429465293884,0.6218492984771729
5
+ 29,0.7387320399284363,0.7157909274101257,0.8078205585479736,0.6925845146179199
6
+ 39,0.7600616812705994,0.745855987071991,0.8239226937294006,0.7104061245918274
7
+ 49,0.7556017637252808,0.746627688407898,0.8140978813171387,0.7060797810554504
8
+ 59,0.7752631902694702,0.7715978622436523,0.8288023471832275,0.7253893613815308
9
+ 69,0.790486752986908,0.7867077589035034,0.8424317836761475,0.7423205375671387
10
+ 79,0.7904816269874573,0.788635790348053,0.8472502827644348,0.7355587482452393
loss/__init__.py ADDED
File without changes
loss/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (124 Bytes). View file
 
loss/__pycache__/loss.cpython-39.pyc ADDED
Binary file (2.16 kB). View file
 
loss/loss.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class Loss_VAE(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.mse = nn.MSELoss(reduction='sum')
8
+
9
+ def forward(self, recon_x, x, mu, log_var):
10
+ mse = self.mse(recon_x, x)
11
+ kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
12
+ loss = mse + kld
13
+ return loss
14
+
15
+
16
+ def DiceScore(
17
+ y_pred: torch.Tensor,
18
+ y: torch.Tensor,
19
+ include_background: bool = True,
20
+ ) -> torch.Tensor:
21
+ """Computes Dice score metric from full size Tensor and collects average.
22
+ Args:
23
+ y_pred: input data to compute, typical segmentation model output.
24
+ It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
25
+ should be binarized.
26
+ y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
27
+ The values should be binarized.
28
+ include_background: whether to skip Dice computation on the first channel of
29
+ the predicted output. Defaults to True.
30
+ Returns:
31
+ Dice scores per batch and per class, (shape [batch_size, num_classes]).
32
+ Raises:
33
+ ValueError: when `y_pred` and `y` have different shapes.
34
+ """
35
+
36
+ y = y.float()
37
+ y_pred = y_pred.float()
38
+
39
+ if y.shape != y_pred.shape:
40
+ raise ValueError("y_pred and y should have same shapes.")
41
+
42
+ # reducing only spatial dimensions (not batch nor channels)
43
+ n_len = len(y_pred.shape)
44
+ reduce_axis = list(range(2, n_len))
45
+ intersection = torch.sum(y * y_pred, dim=reduce_axis)
46
+
47
+ y_o = torch.sum(y, reduce_axis)
48
+ y_pred_o = torch.sum(y_pred, dim=reduce_axis)
49
+ denominator = y_o + y_pred_o
50
+
51
+ return torch.where(
52
+ denominator > 0,
53
+ (2.0 * intersection) / denominator,
54
+ torch.tensor(float("1"), device=y_o.device),
55
+ )
models/SegTranVAE/SegTranVAE.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ import torch
5
+ from einops import rearrange
6
+ import torch
7
+ import torch.nn as nn
8
+ ###########Resnet Block############
9
+ def normalization(planes, norm = 'instance'):
10
+ if norm == 'bn':
11
+ m = nn.BatchNorm3d(planes)
12
+ elif norm == 'gn':
13
+ m = nn.GroupNorm(8, planes)
14
+ elif norm == 'instance':
15
+ m = nn.InstanceNorm3d(planes)
16
+ else:
17
+ raise ValueError("Does not support this kind of norm.")
18
+ return m
19
+ class ResNetBlock(nn.Module):
20
+ def __init__(self, in_channels, norm = 'instance'):
21
+ super().__init__()
22
+ self.resnetblock = nn.Sequential(
23
+ normalization(in_channels, norm = norm),
24
+ nn.LeakyReLU(0.2, inplace=True),
25
+ nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
26
+ normalization(in_channels, norm = norm),
27
+ nn.LeakyReLU(0.2, inplace=True),
28
+ nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
29
+ )
30
+
31
+ def forward(self, x):
32
+ y = self.resnetblock(x)
33
+ return y + x
34
+
35
+
36
+ ##############VAE###############
37
+ def calculate_total_dimension(a):
38
+ res = 1
39
+ for x in a:
40
+ res *= x
41
+ return res
42
+
43
+ class VAE(nn.Module):
44
+ def __init__(self, input_shape, latent_dim, num_channels):
45
+ super().__init__()
46
+ self.input_shape = input_shape
47
+ self.in_channels = input_shape[1] #input_shape[0] is batch size
48
+ self.latent_dim = latent_dim
49
+ self.encoder_channels = self.in_channels // 16
50
+
51
+ #Encoder
52
+ self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
53
+ kernel_size = 3, stride = 2, padding=1)
54
+ # self.VAE_reshape = nn.Sequential(
55
+ # nn.GroupNorm(8, self.in_channels),
56
+ # nn.ReLU(),
57
+ # nn.Conv3d(self.in_channels, self.encoder_channels,
58
+ # kernel_size = 3, stride = 2, padding=1),
59
+ # )
60
+
61
+ flatten_input_shape = calculate_total_dimension(input_shape)
62
+ flatten_input_shape_after_vae_reshape = \
63
+ flatten_input_shape * self.encoder_channels // (8 * self.in_channels)
64
+
65
+ #Convert from total dimension to latent space
66
+ self.to_latent_space = nn.Linear(
67
+ flatten_input_shape_after_vae_reshape // self.in_channels, 1)
68
+
69
+ self.mean = nn.Linear(self.in_channels, self.latent_dim)
70
+ self.logvar = nn.Linear(self.in_channels, self.latent_dim)
71
+ # self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
72
+
73
+ #Decoder
74
+ self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
75
+ self.Reconstruct = nn.Sequential(
76
+ nn.LeakyReLU(0.2, inplace=True),
77
+ nn.Conv3d(
78
+ self.encoder_channels, self.in_channels,
79
+ stride = 1, kernel_size = 1),
80
+ nn.Upsample(scale_factor=2, mode = 'nearest'),
81
+
82
+ nn.Conv3d(
83
+ self.in_channels, self.in_channels // 2,
84
+ stride = 1, kernel_size = 1),
85
+ nn.Upsample(scale_factor=2, mode = 'nearest'),
86
+ ResNetBlock(self.in_channels // 2),
87
+
88
+ nn.Conv3d(
89
+ self.in_channels // 2, self.in_channels // 4,
90
+ stride = 1, kernel_size = 1),
91
+ nn.Upsample(scale_factor=2, mode = 'nearest'),
92
+ ResNetBlock(self.in_channels // 4),
93
+
94
+ nn.Conv3d(
95
+ self.in_channels // 4, self.in_channels // 8,
96
+ stride = 1, kernel_size = 1),
97
+ nn.Upsample(scale_factor=2, mode = 'nearest'),
98
+ ResNetBlock(self.in_channels // 8),
99
+
100
+ nn.InstanceNorm3d(self.in_channels // 8),
101
+ nn.LeakyReLU(0.2, inplace=True),
102
+ nn.Conv3d(
103
+ self.in_channels // 8, num_channels,
104
+ kernel_size = 3, padding = 1),
105
+ # nn.Sigmoid()
106
+ )
107
+
108
+
109
+ def forward(self, x): #x has shape = input_shape
110
+ #Encoder
111
+ # print(x.shape)
112
+ x = self.VAE_reshape(x)
113
+ shape = x.shape
114
+
115
+ x = x.view(self.in_channels, -1)
116
+ x = self.to_latent_space(x)
117
+ x = x.view(1, self.in_channels)
118
+
119
+ mean = self.mean(x)
120
+ logvar = self.logvar(x)
121
+ # sigma = torch.exp(0.5 * logvar)
122
+ # Reparameter
123
+ epsilon = torch.randn_like(logvar)
124
+ sample = mean + epsilon * torch.exp(0.5*logvar)
125
+
126
+ #Decoder
127
+ y = self.to_original_dimension(sample)
128
+ y = y.view(*shape)
129
+ return self.Reconstruct(y), mean, logvar
130
+ def total_params(self):
131
+ total = sum(p.numel() for p in self.parameters())
132
+ return format(total, ',')
133
+
134
+ def total_trainable_params(self):
135
+ total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
136
+ return format(total_trainable, ',')
137
+
138
+
139
+ # x = torch.rand((1, 256, 16, 16, 16))
140
+ # vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
141
+ # y = vae(x)
142
+ # print(y[0].shape, y[1].shape, y[2].shape)
143
+ # print(vae.total_trainable_params())
144
+
145
+
146
+ ### Decoder ####
147
+
148
+
149
+
150
+ class Upsample(nn.Module):
151
+ def __init__(self, in_channel, out_channel):
152
+ super().__init__()
153
+ self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
154
+ self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
155
+ self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)
156
+
157
+ def forward(self, prev, x):
158
+ x = self.deconv(self.conv1(x))
159
+ y = torch.cat((prev, x), dim = 1)
160
+ return self.conv2(y)
161
+
162
+ class FinalConv(nn.Module): # Input channels are equal to output channels
163
+ def __init__(self, in_channels, out_channels=32, norm="instance"):
164
+ super(FinalConv, self).__init__()
165
+ if norm == "batch":
166
+ norm_layer = nn.BatchNorm3d(num_features=in_channels)
167
+ elif norm == "group":
168
+ norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
169
+ elif norm == 'instance':
170
+ norm_layer = nn.InstanceNorm3d(in_channels)
171
+
172
+ self.layer = nn.Sequential(
173
+ norm_layer,
174
+ nn.LeakyReLU(0.2, inplace=True),
175
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
176
+ )
177
+ def forward(self, x):
178
+ return self.layer(x)
179
+
180
+ class Decoder(nn.Module):
181
+ def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
182
+ super().__init__()
183
+ self.img_dim = img_dim
184
+ self.patch_dim = patch_dim
185
+ self.embedding_dim = embedding_dim
186
+
187
+ self.decoder_upsample_1 = Upsample(128, 64)
188
+ self.decoder_block_1 = ResNetBlock(64)
189
+
190
+ self.decoder_upsample_2 = Upsample(64, 32)
191
+ self.decoder_block_2 = ResNetBlock(32)
192
+
193
+ self.decoder_upsample_3 = Upsample(32, 16)
194
+ self.decoder_block_3 = ResNetBlock(16)
195
+
196
+ self.endconv = FinalConv(16, num_classes)
197
+ # self.normalize = nn.Sigmoid()
198
+
199
+ def forward(self, x1, x2, x3, x):
200
+ x = self.decoder_upsample_1(x3, x)
201
+ x = self.decoder_block_1(x)
202
+
203
+ x = self.decoder_upsample_2(x2, x)
204
+ x = self.decoder_block_2(x)
205
+
206
+ x = self.decoder_upsample_3(x1, x)
207
+ x = self.decoder_block_3(x)
208
+
209
+ y = self.endconv(x)
210
+ return y
211
+
212
+
213
+
214
+ ###############Encoder##############
215
+ class InitConv(nn.Module):
216
+ def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
217
+ super().__init__()
218
+ self.layer = nn.Sequential(
219
+ nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
220
+ nn.Dropout3d(dropout)
221
+ )
222
+ def forward(self, x):
223
+ y = self.layer(x)
224
+ return y
225
+
226
+
227
+ class DownSample(nn.Module):
228
+ def __init__(self, in_channels, out_channels):
229
+ super().__init__()
230
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
231
+ def forward(self, x):
232
+ return self.conv(x)
233
+
234
+ class Encoder(nn.Module):
235
+ def __init__(self, in_channels, base_channels, dropout = 0.2):
236
+ super().__init__()
237
+
238
+ self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
239
+ self.encoder_block1 = ResNetBlock(in_channels = base_channels)
240
+ self.encoder_down1 = DownSample(base_channels, base_channels * 2)
241
+
242
+ self.encoder_block2_1 = ResNetBlock(base_channels * 2)
243
+ self.encoder_block2_2 = ResNetBlock(base_channels * 2)
244
+ self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)
245
+
246
+ self.encoder_block3_1 = ResNetBlock(base_channels * 4)
247
+ self.encoder_block3_2 = ResNetBlock(base_channels * 4)
248
+ self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)
249
+
250
+ self.encoder_block4_1 = ResNetBlock(base_channels * 8)
251
+ self.encoder_block4_2 = ResNetBlock(base_channels * 8)
252
+ self.encoder_block4_3 = ResNetBlock(base_channels * 8)
253
+ self.encoder_block4_4 = ResNetBlock(base_channels * 8)
254
+ # self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
255
+ def forward(self, x):
256
+ x = self.init_conv(x) #(1, 16, 128, 128, 128)
257
+
258
+ x1 = self.encoder_block1(x)
259
+ x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)
260
+
261
+ x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
262
+ x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)
263
+
264
+ x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
265
+ x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)
266
+
267
+ output = self.encoder_block4_4(
268
+ self.encoder_block4_3(
269
+ self.encoder_block4_2(
270
+ self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
271
+ return x1, x2, x3, output
272
+
273
+ # x = torch.rand((1, 4, 128, 128, 128))
274
+ # Enc = Encoder(4, 32)
275
+ # _, _, _, y = Enc(x)
276
+ # print(y.shape) (1,256,16,16)
277
+
278
+
279
+ ###############FeatureMapping###############
280
+
281
+ class FeatureMapping(nn.Module):
282
+ def __init__(self, in_channel, out_channel, norm = 'instance'):
283
+ super().__init__()
284
+ if norm == 'bn':
285
+ norm_layer_1 = nn.BatchNorm3d(out_channel)
286
+ norm_layer_2 = nn.BatchNorm3d(out_channel)
287
+ elif norm == 'gn':
288
+ norm_layer_1 = nn.GroupNorm(8, out_channel)
289
+ norm_layer_2 = nn.GroupNorm(8, out_channel)
290
+ elif norm == 'instance':
291
+ norm_layer_1 = nn.InstanceNorm3d(out_channel)
292
+ norm_layer_2 = nn.InstanceNorm3d(out_channel)
293
+ self.feature_mapping = nn.Sequential(
294
+ nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
295
+ norm_layer_1,
296
+ nn.LeakyReLU(0.2, inplace=True),
297
+ nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
298
+ norm_layer_2,
299
+ nn.LeakyReLU(0.2, inplace=True)
300
+ )
301
+
302
+ def forward(self, x):
303
+ return self.feature_mapping(x)
304
+
305
+
306
+ class FeatureMapping1(nn.Module):
307
+ def __init__(self, in_channel, norm = 'instance'):
308
+ super().__init__()
309
+ if norm == 'bn':
310
+ norm_layer_1 = nn.BatchNorm3d(in_channel)
311
+ norm_layer_2 = nn.BatchNorm3d(in_channel)
312
+ elif norm == 'gn':
313
+ norm_layer_1 = nn.GroupNorm(8, in_channel)
314
+ norm_layer_2 = nn.GroupNorm(8, in_channel)
315
+ elif norm == 'instance':
316
+ norm_layer_1 = nn.InstanceNorm3d(in_channel)
317
+ norm_layer_2 = nn.InstanceNorm3d(in_channel)
318
+ self.feature_mapping1 = nn.Sequential(
319
+ nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
320
+ norm_layer_1,
321
+ nn.LeakyReLU(0.2, inplace=True),
322
+ nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
323
+ norm_layer_2,
324
+ nn.LeakyReLU(0.2, inplace=True)
325
+ )
326
+ def forward(self, x):
327
+ y = self.feature_mapping1(x)
328
+ return x + y #Resnet Like
329
+
330
+ ################Transformer#######################
331
+
332
+
333
+ def pair(t):
334
+ return t if isinstance(t, tuple) else (t, t)
335
+
336
+
337
+ class PreNorm(nn.Module):
338
+ def __init__(self, dim, function):
339
+ super().__init__()
340
+ self.norm = nn.LayerNorm(dim)
341
+ self.function = function
342
+
343
+ def forward(self, x):
344
+ return self.function(self.norm(x))
345
+
346
+
347
+ class FeedForward(nn.Module):
348
+ def __init__(self, dim, hidden_dim, dropout = 0.0):
349
+ super().__init__()
350
+ self.net = nn.Sequential(
351
+ nn.Linear(dim, hidden_dim),
352
+ nn.GELU(),
353
+ nn.Dropout(dropout),
354
+ nn.Linear(hidden_dim, dim),
355
+ nn.Dropout(dropout)
356
+ )
357
+
358
+ def forward(self, x):
359
+ return self.net(x)
360
+
361
+ class Attention(nn.Module):
362
+ def __init__(self, dim, heads, dim_head, dropout = 0.0):
363
+ super().__init__()
364
+ all_head_size = heads * dim_head
365
+ project_out = not (heads == 1 and dim_head == dim)
366
+
367
+ self.heads = heads
368
+ self.scale = dim_head ** -0.5
369
+
370
+ self.softmax = nn.Softmax(dim = -1)
371
+ self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)
372
+
373
+ self.to_out = nn.Sequential(
374
+ nn.Linear(all_head_size, dim),
375
+ nn.Dropout(dropout)
376
+ ) if project_out else nn.Identity()
377
+
378
+ def forward(self, x):
379
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
380
+ #(batch, heads * dim_head) -> (batch, all_head_size)
381
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
382
+
383
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
384
+
385
+ atten = self.softmax(dots)
386
+
387
+ out = torch.matmul(atten, v)
388
+ out = rearrange(out, 'b h n d -> b n (h d)')
389
+ return self.to_out(out)
390
+
391
+ class Transformer(nn.Module):
392
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
393
+ super().__init__()
394
+ self.layers = nn.ModuleList([])
395
+ for _ in range(depth):
396
+ self.layers.append(nn.ModuleList([
397
+ PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
398
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
399
+ ]))
400
+ def forward(self, x):
401
+ for attention, feedforward in self.layers:
402
+ x = attention(x) + x
403
+ x = feedforward(x) + x
404
+ return x
405
+
406
+ class FixedPositionalEncoding(nn.Module):
407
+ def __init__(self, embedding_dim, max_length=768):
408
+ super(FixedPositionalEncoding, self).__init__()
409
+
410
+ pe = torch.zeros(max_length, embedding_dim)
411
+ position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
412
+ div_term = torch.exp(
413
+ torch.arange(0, embedding_dim, 2).float()
414
+ * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
415
+ )
416
+ pe[:, 0::2] = torch.sin(position * div_term)
417
+ pe[:, 1::2] = torch.cos(position * div_term)
418
+ pe = pe.unsqueeze(0).transpose(0, 1)
419
+ self.register_buffer('pe', pe)
420
+
421
+ def forward(self, x):
422
+ x = x + self.pe[: x.size(0), :]
423
+ return x
424
+
425
+
426
+ class LearnedPositionalEncoding(nn.Module):
427
+ def __init__(self, embedding_dim, seq_length):
428
+ super(LearnedPositionalEncoding, self).__init__()
429
+ self.seq_length = seq_length
430
+ self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x
431
+
432
+ def forward(self, x, position_ids=None):
433
+ position_embeddings = self.position_embeddings
434
+ # print(x.shape, self.position_embeddings.shape)
435
+ return x + position_embeddings
436
+
437
+
438
+
439
+
440
+
441
+ ###############Main model#################
442
+
443
+ class SegTransVAE(nn.Module):
444
+ def __init__(self, img_dim, patch_dim, num_channels, num_classes,
445
+ embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
446
+ dropout = 0.0, attention_dropout = 0.0,
447
+ conv_patch_representation = True, positional_encoding = 'learned',
448
+ use_VAE = False):
449
+
450
+ super().__init__()
451
+ assert embedding_dim % num_heads == 0
452
+ assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0
453
+
454
+ self.img_dim = img_dim
455
+ self.embedding_dim = embedding_dim
456
+ self.num_heads = num_heads
457
+ self.num_classes = num_classes
458
+ self.patch_dim = patch_dim
459
+ self.num_channels = num_channels
460
+ self.in_channels_vae = in_channels_vae
461
+ self.dropout = dropout
462
+ self.attention_dropout = attention_dropout
463
+ self.conv_patch_representation = conv_patch_representation
464
+ self.use_VAE = use_VAE
465
+
466
+ self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
467
+ self.seq_length = self.num_patches
468
+ self.flatten_dim = 128 * num_channels
469
+
470
+ self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
471
+ if positional_encoding == "learned":
472
+ self.position_encoding = LearnedPositionalEncoding(
473
+ self.embedding_dim, self.seq_length
474
+ )
475
+ elif positional_encoding == "fixed":
476
+ self.position_encoding = FixedPositionalEncoding(
477
+ self.embedding_dim,
478
+ )
479
+ self.pe_dropout = nn.Dropout(self.dropout)
480
+
481
+ self.transformer = Transformer(
482
+ embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
483
+ )
484
+ self.pre_head_ln = nn.LayerNorm(embedding_dim)
485
+
486
+ if self.conv_patch_representation:
487
+ self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
488
+ self.encoder = Encoder(self.num_channels, 16)
489
+ self.bn = nn.InstanceNorm3d(128)
490
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
491
+ self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
492
+ self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
493
+ self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)
494
+
495
+ self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
496
+ if use_VAE:
497
+ self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
498
+ def encode(self, x):
499
+ if self.conv_patch_representation:
500
+ x1, x2, x3, x = self.encoder(x)
501
+ x = self.bn(x)
502
+ x = self.relu(x)
503
+ x = self.conv_x(x)
504
+ x = x.permute(0, 2, 3, 4, 1).contiguous()
505
+ x = x.view(x.size(0), -1, self.embedding_dim)
506
+ x = self.position_encoding(x)
507
+ x = self.pe_dropout(x)
508
+ x = self.transformer(x)
509
+ x = self.pre_head_ln(x)
510
+
511
+ return x1, x2, x3, x
512
+
513
+ def decode(self, x1, x2, x3, x):
514
+ #x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
515
+ # print("In decode...")
516
+ # print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
517
+ # break
518
+ return self.decoder(x1, x2, x3, x)
519
+
520
+ def forward(self, x, is_validation = True):
521
+ x1, x2, x3, x = self.encode(x)
522
+ x = x.view( x.size(0),
523
+ self.img_dim[0] // self.patch_dim,
524
+ self.img_dim[1] // self.patch_dim,
525
+ self.img_dim[2] // self.patch_dim,
526
+ self.embedding_dim)
527
+ x = x.permute(0, 4, 1, 2, 3).contiguous()
528
+ x = self.FeatureMapping(x)
529
+ x = self.FeatureMapping1(x)
530
+ if self.use_VAE and not is_validation:
531
+ vae_out, mu, sigma = self.vae(x)
532
+ y = self.decode(x1, x2, x3, x)
533
+ if self.use_VAE and not is_validation:
534
+ return y, vae_out, mu, sigma
535
+ else:
536
+ return y
537
+
538
+
models/SegTranVAE/__init__.py ADDED
File without changes
models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc ADDED
Binary file (16.3 kB). View file
 
models/SegTranVAE/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (137 Bytes). View file
 
train.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from monai.utils import set_determinism
4
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
5
+ import os
6
+ from pytorch_lightning.loggers import TensorBoardLogger
7
+ from trainer import BRATS
8
+ from dataset.utils import get_loader
9
+ import pytorch_lightning as pl
10
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ set_determinism(seed=0)
14
+
15
+ os.system('cls||clear')
16
+ print("Training ...")
17
+
18
+ data_dir = "/app/brats_2021_task1"
19
+ json_list = "/app/info.json"
20
+ roi = (128, 128, 128)
21
+ batch_size = 1
22
+ fold = 1
23
+ max_epochs = 500
24
+ val_every = 10
25
+ train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2)
26
+ print("Done initialize dataloader !! ")
27
+
28
+ model = BRATS(use_VAE = True, train_loader = train_loader,val_loader = val_loader, test_loader=test_loader )
29
+ checkpoint_callback = ModelCheckpoint(
30
+ monitor='val/MeanDiceScore',
31
+ dirpath='./checkpoints/{}'.format("SegTransVAE"),
32
+ filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
33
+ save_top_k=3,
34
+ mode='max',
35
+ save_last= True,
36
+ auto_insert_metric_name=False
37
+ )
38
+ early_stop_callback = EarlyStopping(
39
+ monitor='val/MeanDiceScore',
40
+ min_delta=0.0001,
41
+ patience=15,
42
+ verbose=False,
43
+ mode='max'
44
+ )
45
+ tensorboardlogger = TensorBoardLogger(
46
+ 'logs',
47
+ name = "SegTransVAE",
48
+ default_hp_metric = None
49
+ )
50
+ trainer = pl.Trainer(#fast_dev_run = 10,
51
+ # accelerator='ddp',
52
+ #overfit_batches=5,
53
+ devices = [0],
54
+ precision=16,
55
+ max_epochs = max_epochs,
56
+ enable_progress_bar=True,
57
+ callbacks=[checkpoint_callback, early_stop_callback],
58
+ # auto_lr_find=True,
59
+ num_sanity_val_steps=1,
60
+ logger = tensorboardlogger,
61
+ check_val_every_n_epoch = 10,
62
+ # limit_train_batches=0.01,
63
+ # limit_val_batches=0.01
64
+ )
65
+ # trainer.tune(model)
66
+ trainer.fit(model)
67
+
68
+
69
+
trainer.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytorch_lightning as pl
3
+ import matplotlib.pyplot as plt
4
+ import csv
5
+ import torch
6
+ from monai.transforms import AsDiscrete, Activations, Compose, EnsureType
7
+ from models.SegTranVAE.SegTranVAE import SegTransVAE
8
+ from loss.loss import Loss_VAE, DiceScore
9
+ from monai.losses import DiceLoss
10
+ import pytorch_lightning as pl
11
+ from monai.inferers import sliding_window_inference
12
+
13
+
14
+
15
+
16
+
17
+ class BRATS(pl.LightningModule):
18
+ def __init__(self,train_loader,val_loader,test_loader, use_VAE = True, lr = 1e-4 ):
19
+ super().__init__()
20
+ self.train_loader = train_loader
21
+ self.val_loader = val_loader
22
+ self.test_loader = test_loader
23
+ self.use_vae = use_VAE
24
+ self.lr = lr
25
+ self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)
26
+
27
+ self.loss_vae = Loss_VAE()
28
+ self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
29
+ self.post_trans_images = Compose(
30
+ [EnsureType(),
31
+ Activations(sigmoid=True),
32
+ AsDiscrete(threshold_values=True),
33
+ ]
34
+ )
35
+
36
+ self.best_val_dice = 0
37
+
38
+ self.training_step_outputs = []
39
+ self.val_step_loss = []
40
+ self.val_step_dice = []
41
+ self.val_step_dice_tc = []
42
+ self.val_step_dice_wt = []
43
+ self.val_step_dice_et = []
44
+ self.test_step_loss = []
45
+ self.test_step_dice = []
46
+ self.test_step_dice_tc = []
47
+ self.test_step_dice_wt = []
48
+ self.test_step_dice_et = []
49
+
50
+ def forward(self, x, is_validation = True):
51
+ return self.model(x, is_validation)
52
+ def training_step(self, batch, batch_index):
53
+ inputs, labels = (batch['image'], batch['label'])
54
+
55
+ if not self.use_vae:
56
+ outputs = self.forward(inputs, is_validation=False)
57
+ loss = self.dice_loss(outputs, labels)
58
+ else:
59
+ outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
60
+
61
+ vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
62
+ dice_loss = self.dice_loss(outputs, labels)
63
+ loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
64
+ self.training_step_outputs.append(loss)
65
+ self.log('train/vae_loss', vae_loss)
66
+ self.log('train/dice_loss', dice_loss)
67
+ if batch_index == 10:
68
+
69
+ tensorboard = self.logger.experiment
70
+ fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
71
+
72
+
73
+ ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
74
+ ax[0].set_title("Input")
75
+
76
+ ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
77
+ ax[1].set_title("Reconstruction")
78
+
79
+ ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
80
+ ax[2].set_title("Labels TC")
81
+
82
+ ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')
83
+ ax[3].set_title("TC")
84
+
85
+ ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')
86
+ ax[4].set_title("Labels ET")
87
+
88
+ ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')
89
+ ax[5].set_title("ET")
90
+
91
+
92
+ tensorboard.add_figure('train_visualize', fig, self.current_epoch)
93
+
94
+ self.log('train/loss', loss)
95
+
96
+ return loss
97
+
98
+ def on_train_epoch_end(self):
99
+ ## F1 Macro all epoch saving outputs and target per batch
100
+
101
+ # free up the memory
102
+ # --> HERE STEP 3 <--
103
+ epoch_average = torch.stack(self.training_step_outputs).mean()
104
+ self.log("training_epoch_average", epoch_average)
105
+ self.training_step_outputs.clear() # free memory
106
+
107
+ def validation_step(self, batch, batch_index):
108
+ inputs, labels = (batch['image'], batch['label'])
109
+ roi_size = (128, 128, 128)
110
+ sw_batch_size = 1
111
+ outputs = sliding_window_inference(
112
+ inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
113
+ loss = self.dice_loss(outputs, labels)
114
+
115
+
116
+ val_outputs = self.post_trans_images(outputs)
117
+
118
+
119
+ metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
120
+ metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
121
+ metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
122
+ mean_val_dice = (metric_tc + metric_wt + metric_et)/3
123
+ self.val_step_loss.append(loss)
124
+ self.val_step_dice.append(mean_val_dice)
125
+ self.val_step_dice_tc.append(metric_tc)
126
+ self.val_step_dice_wt.append(metric_wt)
127
+ self.val_step_dice_et.append(metric_et)
128
+ return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
129
+ 'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
130
+
131
+ def on_validation_epoch_end(self):
132
+
133
+ loss = torch.stack(self.val_step_loss).mean()
134
+ mean_val_dice = torch.stack(self.val_step_dice).mean()
135
+ metric_tc = torch.stack(self.val_step_dice_tc).mean()
136
+ metric_wt = torch.stack(self.val_step_dice_wt).mean()
137
+ metric_et = torch.stack(self.val_step_dice_et).mean()
138
+ self.log('val/Loss', loss)
139
+ self.log('val/MeanDiceScore', mean_val_dice)
140
+ self.log('val/DiceTC', metric_tc)
141
+ self.log('val/DiceWT', metric_wt)
142
+ self.log('val/DiceET', metric_et)
143
+ os.makedirs(self.logger.log_dir, exist_ok=True)
144
+ if self.current_epoch == 0:
145
+ with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
146
+ writer = csv.writer(f)
147
+ writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
148
+ with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
149
+ writer = csv.writer(f)
150
+ writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])
151
+
152
+ if mean_val_dice > self.best_val_dice:
153
+ self.best_val_dice = mean_val_dice
154
+ self.best_val_epoch = self.current_epoch
155
+ print(
156
+ f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
157
+ f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
158
+ f"\n Best mean dice: {self.best_val_dice}"
159
+ f" at epoch: {self.best_val_epoch}"
160
+ )
161
+
162
+ self.val_step_loss.clear()
163
+ self.val_step_dice.clear()
164
+ self.val_step_dice_tc.clear()
165
+ self.val_step_dice_wt.clear()
166
+ self.val_step_dice_et.clear()
167
+ return {'val_MeanDiceScore': mean_val_dice}
168
+ def test_step(self, batch, batch_index):
169
+ inputs, labels = (batch['image'], batch['label'])
170
+
171
+ roi_size = (128, 128, 128)
172
+ sw_batch_size = 1
173
+ test_outputs = sliding_window_inference(
174
+ inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
175
+ loss = self.dice_loss(test_outputs, labels)
176
+ test_outputs = self.post_trans_images(test_outputs)
177
+ metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
178
+ metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
179
+ metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
180
+ mean_test_dice = (metric_tc + metric_wt + metric_et)/3
181
+
182
+ self.test_step_loss.append(loss)
183
+ self.test_step_dice.append(mean_test_dice)
184
+ self.test_step_dice_tc.append(metric_tc)
185
+ self.test_step_dice_wt.append(metric_wt)
186
+ self.test_step_dice_et.append(metric_et)
187
+
188
+ return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
189
+ 'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
190
+
191
+ def test_epoch_end(self):
192
+ loss = torch.stack(self.test_step_loss).mean()
193
+ mean_test_dice = torch.stack(self.test_step_dice).mean()
194
+ metric_tc = torch.stack(self.test_step_dice_tc).mean()
195
+ metric_wt = torch.stack(self.test_step_dice_wt).mean()
196
+ metric_et = torch.stack(self.test_step_dice_et).mean()
197
+ self.log('test/Loss', loss)
198
+ self.log('test/MeanDiceScore', mean_test_dice)
199
+ self.log('test/DiceTC', metric_tc)
200
+ self.log('test/DiceWT', metric_wt)
201
+ self.log('test/DiceET', metric_et)
202
+
203
+ with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
204
+ writer = csv.writer(f)
205
+ writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
206
+ writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])
207
+
208
+ self.test_step_loss.clear()
209
+ self.test_step_dice.clear()
210
+ self.test_step_dice_tc.clear()
211
+ self.test_step_dice_wt.clear()
212
+ self.test_step_dice_et.clear()
213
+ return {'test_MeanDiceScore': mean_test_dice}
214
+
215
+
216
+ def configure_optimizers(self):
217
+ optimizer = torch.optim.Adam(
218
+ self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
219
+ )
220
+ # optimizer = AdaBelief(self.model.parameters(),
221
+ # lr=self.lr, eps=1e-16,
222
+ # betas=(0.9,0.999), weight_decouple = True,
223
+ # rectify = False)
224
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
225
+ return [optimizer], [scheduler]
226
+
227
+ def train_dataloader(self):
228
+ return self.train_loader
229
+ def val_dataloader(self):
230
+ return self.val_loader
231
+
232
+ def test_dataloader(self):
233
+ return self.test_loader