Upload folder using huggingface_hub
Browse files- __pycache__/trainer.cpython-39.pyc +0 -0
- checkpoints/SegTransVAE/Epoch 59-MeanDiceScore0.7753.ckpt +3 -0
- checkpoints/SegTransVAE/Epoch 69-MeanDiceScore0.7905.ckpt +3 -0
- checkpoints/SegTransVAE/Epoch 79-MeanDiceScore0.7905.ckpt +3 -0
- checkpoints/SegTransVAE/last.ckpt +3 -0
- dataset/__pycache__/utils.cpython-39.pyc +0 -0
- dataset/brats.py +31 -0
- dataset/utils.py +100 -0
- logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 +3 -0
- logs/SegTransVAE/version_0/hparams.yaml +1 -0
- logs/SegTransVAE/version_0/metric_log.csv +10 -0
- loss/__init__.py +0 -0
- loss/__pycache__/__init__.cpython-39.pyc +0 -0
- loss/__pycache__/loss.cpython-39.pyc +0 -0
- loss/loss.py +55 -0
- models/SegTranVAE/SegTranVAE.py +538 -0
- models/SegTranVAE/__init__.py +0 -0
- models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc +0 -0
- models/SegTranVAE/__pycache__/__init__.cpython-39.pyc +0 -0
- train.py +69 -0
- trainer.py +233 -0
__pycache__/trainer.cpython-39.pyc
ADDED
Binary file (7.55 kB). View file
|
|
checkpoints/SegTransVAE/Epoch 59-MeanDiceScore0.7753.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:633f9f9bc994612e455727ff7d5b10ff0b7a3fe369749e55673d4c0c0cc1cc50
|
3 |
+
size 711206989
|
checkpoints/SegTransVAE/Epoch 69-MeanDiceScore0.7905.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a9ac5e8e0ccb79ac91b2e5d3dbcb4fe43d501349315e1e38e3791b9ec73ea248
|
3 |
+
size 711206989
|
checkpoints/SegTransVAE/Epoch 79-MeanDiceScore0.7905.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e215f6e96d7bd150f78bf896ddb41d5bd979ef4a4faad3a3185b4bec8944b65
|
3 |
+
size 711206989
|
checkpoints/SegTransVAE/last.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e215f6e96d7bd150f78bf896ddb41d5bd979ef4a4faad3a3185b4bec8944b65
|
3 |
+
size 711206989
|
dataset/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.66 kB). View file
|
|
dataset/brats.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from monai.transforms import MapTransform
|
3 |
+
|
4 |
+
|
5 |
+
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
|
6 |
+
"""
|
7 |
+
Convert labels to multi channels based on brats classes:
|
8 |
+
label 1 is the necrotic and non-enhancing tumor core
|
9 |
+
label 2 is the peritumoral edema
|
10 |
+
label 4 is the GD-enhancing tumor
|
11 |
+
The possible classes are TC (Tumor core), WT (Whole tumor)
|
12 |
+
and ET (Enhancing tumor).
|
13 |
+
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __call__(self, data):
|
17 |
+
d = dict(data)
|
18 |
+
for key in self.keys:
|
19 |
+
result = []
|
20 |
+
# merge label 1 and label 4 to construct TC
|
21 |
+
result.append(np.logical_or(d[key] == 1, d[key] == 4))
|
22 |
+
# merge labels 1, 2 and 4 to construct WT
|
23 |
+
result.append(
|
24 |
+
np.logical_or(
|
25 |
+
np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
|
26 |
+
)
|
27 |
+
)
|
28 |
+
# label 4 is ET
|
29 |
+
result.append(d[key] == 4)
|
30 |
+
d[key] = np.stack(result, axis=0).astype(np.float32)
|
31 |
+
return d
|
dataset/utils.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from sklearn.model_selection import train_test_split
|
4 |
+
|
5 |
+
from monai.data import DataLoader, Dataset
|
6 |
+
from monai import transforms
|
7 |
+
|
8 |
+
def datafold_read(datalist, basedir, fold=0, key="training"):
|
9 |
+
with open(datalist) as f:
|
10 |
+
json_data = json.load(f)
|
11 |
+
|
12 |
+
json_data = json_data[key]
|
13 |
+
|
14 |
+
for d in json_data:
|
15 |
+
for k in d:
|
16 |
+
if isinstance(d[k], list):
|
17 |
+
d[k] = [os.path.join(basedir, iv) for iv in d[k]]
|
18 |
+
elif isinstance(d[k], str):
|
19 |
+
d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
|
20 |
+
|
21 |
+
tr = []
|
22 |
+
val = []
|
23 |
+
for d in json_data:
|
24 |
+
if "fold" in d and d["fold"] == fold:
|
25 |
+
val.append(d)
|
26 |
+
else:
|
27 |
+
tr.append(d)
|
28 |
+
|
29 |
+
return tr, val
|
30 |
+
|
31 |
+
|
32 |
+
def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :
|
33 |
+
train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
|
34 |
+
if volume != None :
|
35 |
+
train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)
|
36 |
+
|
37 |
+
train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)
|
38 |
+
|
39 |
+
validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)
|
40 |
+
return train_files, validation_files, test_files
|
41 |
+
|
42 |
+
|
43 |
+
def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):
|
44 |
+
train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)
|
45 |
+
|
46 |
+
train_transform = transforms.Compose(
|
47 |
+
[
|
48 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
49 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
50 |
+
transforms.CropForegroundd(
|
51 |
+
keys=["image", "label"],
|
52 |
+
source_key="image",
|
53 |
+
k_divisible=[roi[0], roi[1], roi[2]],
|
54 |
+
),
|
55 |
+
transforms.RandSpatialCropd(
|
56 |
+
keys=["image", "label"],
|
57 |
+
roi_size=[roi[0], roi[1], roi[2]],
|
58 |
+
random_size=False,
|
59 |
+
),
|
60 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
|
61 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
|
62 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
|
63 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
64 |
+
transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
|
65 |
+
transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
|
66 |
+
]
|
67 |
+
)
|
68 |
+
val_transform = transforms.Compose(
|
69 |
+
[
|
70 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
71 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
72 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
73 |
+
]
|
74 |
+
)
|
75 |
+
|
76 |
+
train_ds = Dataset(data=train_files, transform=train_transform)
|
77 |
+
train_loader = DataLoader(
|
78 |
+
train_ds,
|
79 |
+
batch_size=batch_size,
|
80 |
+
shuffle=True,
|
81 |
+
num_workers=2,
|
82 |
+
pin_memory=True,
|
83 |
+
)
|
84 |
+
val_ds = Dataset(data=validation_files, transform=val_transform)
|
85 |
+
val_loader = DataLoader(
|
86 |
+
val_ds,
|
87 |
+
batch_size=1,
|
88 |
+
shuffle=False,
|
89 |
+
num_workers=2,
|
90 |
+
pin_memory=True,
|
91 |
+
)
|
92 |
+
test_ds = Dataset(data=test_files, transform=val_transform)
|
93 |
+
test_loader = DataLoader(
|
94 |
+
test_ds,
|
95 |
+
batch_size=1,
|
96 |
+
shuffle=False,
|
97 |
+
num_workers=2,
|
98 |
+
pin_memory=True,
|
99 |
+
)
|
100 |
+
return train_loader, val_loader,test_loader
|
logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff29460308634cd2b1327302a88efa226a2d9575614414ccefe089147587478a
|
3 |
+
size 3448488
|
logs/SegTransVAE/version_0/hparams.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
logs/SegTransVAE/version_0/metric_log.csv
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Epoch,Mean Dice Score,Dice TC,Dice WT,Dice ET
|
2 |
+
0,0.004601036664098501,0.0006361556006595492,0.012770041823387146,0.0003969123645219952
|
3 |
+
9,0.6198188066482544,0.574246346950531,0.7122330069541931,0.5729770660400391
|
4 |
+
19,0.6669135689735413,0.6340484023094177,0.7448429465293884,0.6218492984771729
|
5 |
+
29,0.7387320399284363,0.7157909274101257,0.8078205585479736,0.6925845146179199
|
6 |
+
39,0.7600616812705994,0.745855987071991,0.8239226937294006,0.7104061245918274
|
7 |
+
49,0.7556017637252808,0.746627688407898,0.8140978813171387,0.7060797810554504
|
8 |
+
59,0.7752631902694702,0.7715978622436523,0.8288023471832275,0.7253893613815308
|
9 |
+
69,0.790486752986908,0.7867077589035034,0.8424317836761475,0.7423205375671387
|
10 |
+
79,0.7904816269874573,0.788635790348053,0.8472502827644348,0.7355587482452393
|
loss/__init__.py
ADDED
File without changes
|
loss/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (124 Bytes). View file
|
|
loss/__pycache__/loss.cpython-39.pyc
ADDED
Binary file (2.16 kB). View file
|
|
loss/loss.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class Loss_VAE(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.mse = nn.MSELoss(reduction='sum')
|
8 |
+
|
9 |
+
def forward(self, recon_x, x, mu, log_var):
|
10 |
+
mse = self.mse(recon_x, x)
|
11 |
+
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
12 |
+
loss = mse + kld
|
13 |
+
return loss
|
14 |
+
|
15 |
+
|
16 |
+
def DiceScore(
|
17 |
+
y_pred: torch.Tensor,
|
18 |
+
y: torch.Tensor,
|
19 |
+
include_background: bool = True,
|
20 |
+
) -> torch.Tensor:
|
21 |
+
"""Computes Dice score metric from full size Tensor and collects average.
|
22 |
+
Args:
|
23 |
+
y_pred: input data to compute, typical segmentation model output.
|
24 |
+
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
|
25 |
+
should be binarized.
|
26 |
+
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
|
27 |
+
The values should be binarized.
|
28 |
+
include_background: whether to skip Dice computation on the first channel of
|
29 |
+
the predicted output. Defaults to True.
|
30 |
+
Returns:
|
31 |
+
Dice scores per batch and per class, (shape [batch_size, num_classes]).
|
32 |
+
Raises:
|
33 |
+
ValueError: when `y_pred` and `y` have different shapes.
|
34 |
+
"""
|
35 |
+
|
36 |
+
y = y.float()
|
37 |
+
y_pred = y_pred.float()
|
38 |
+
|
39 |
+
if y.shape != y_pred.shape:
|
40 |
+
raise ValueError("y_pred and y should have same shapes.")
|
41 |
+
|
42 |
+
# reducing only spatial dimensions (not batch nor channels)
|
43 |
+
n_len = len(y_pred.shape)
|
44 |
+
reduce_axis = list(range(2, n_len))
|
45 |
+
intersection = torch.sum(y * y_pred, dim=reduce_axis)
|
46 |
+
|
47 |
+
y_o = torch.sum(y, reduce_axis)
|
48 |
+
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
|
49 |
+
denominator = y_o + y_pred_o
|
50 |
+
|
51 |
+
return torch.where(
|
52 |
+
denominator > 0,
|
53 |
+
(2.0 * intersection) / denominator,
|
54 |
+
torch.tensor(float("1"), device=y_o.device),
|
55 |
+
)
|
models/SegTranVAE/SegTranVAE.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
###########Resnet Block############
|
9 |
+
def normalization(planes, norm = 'instance'):
|
10 |
+
if norm == 'bn':
|
11 |
+
m = nn.BatchNorm3d(planes)
|
12 |
+
elif norm == 'gn':
|
13 |
+
m = nn.GroupNorm(8, planes)
|
14 |
+
elif norm == 'instance':
|
15 |
+
m = nn.InstanceNorm3d(planes)
|
16 |
+
else:
|
17 |
+
raise ValueError("Does not support this kind of norm.")
|
18 |
+
return m
|
19 |
+
class ResNetBlock(nn.Module):
|
20 |
+
def __init__(self, in_channels, norm = 'instance'):
|
21 |
+
super().__init__()
|
22 |
+
self.resnetblock = nn.Sequential(
|
23 |
+
normalization(in_channels, norm = norm),
|
24 |
+
nn.LeakyReLU(0.2, inplace=True),
|
25 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
|
26 |
+
normalization(in_channels, norm = norm),
|
27 |
+
nn.LeakyReLU(0.2, inplace=True),
|
28 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
y = self.resnetblock(x)
|
33 |
+
return y + x
|
34 |
+
|
35 |
+
|
36 |
+
##############VAE###############
|
37 |
+
def calculate_total_dimension(a):
|
38 |
+
res = 1
|
39 |
+
for x in a:
|
40 |
+
res *= x
|
41 |
+
return res
|
42 |
+
|
43 |
+
class VAE(nn.Module):
|
44 |
+
def __init__(self, input_shape, latent_dim, num_channels):
|
45 |
+
super().__init__()
|
46 |
+
self.input_shape = input_shape
|
47 |
+
self.in_channels = input_shape[1] #input_shape[0] is batch size
|
48 |
+
self.latent_dim = latent_dim
|
49 |
+
self.encoder_channels = self.in_channels // 16
|
50 |
+
|
51 |
+
#Encoder
|
52 |
+
self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
|
53 |
+
kernel_size = 3, stride = 2, padding=1)
|
54 |
+
# self.VAE_reshape = nn.Sequential(
|
55 |
+
# nn.GroupNorm(8, self.in_channels),
|
56 |
+
# nn.ReLU(),
|
57 |
+
# nn.Conv3d(self.in_channels, self.encoder_channels,
|
58 |
+
# kernel_size = 3, stride = 2, padding=1),
|
59 |
+
# )
|
60 |
+
|
61 |
+
flatten_input_shape = calculate_total_dimension(input_shape)
|
62 |
+
flatten_input_shape_after_vae_reshape = \
|
63 |
+
flatten_input_shape * self.encoder_channels // (8 * self.in_channels)
|
64 |
+
|
65 |
+
#Convert from total dimension to latent space
|
66 |
+
self.to_latent_space = nn.Linear(
|
67 |
+
flatten_input_shape_after_vae_reshape // self.in_channels, 1)
|
68 |
+
|
69 |
+
self.mean = nn.Linear(self.in_channels, self.latent_dim)
|
70 |
+
self.logvar = nn.Linear(self.in_channels, self.latent_dim)
|
71 |
+
# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
|
72 |
+
|
73 |
+
#Decoder
|
74 |
+
self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
|
75 |
+
self.Reconstruct = nn.Sequential(
|
76 |
+
nn.LeakyReLU(0.2, inplace=True),
|
77 |
+
nn.Conv3d(
|
78 |
+
self.encoder_channels, self.in_channels,
|
79 |
+
stride = 1, kernel_size = 1),
|
80 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
81 |
+
|
82 |
+
nn.Conv3d(
|
83 |
+
self.in_channels, self.in_channels // 2,
|
84 |
+
stride = 1, kernel_size = 1),
|
85 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
86 |
+
ResNetBlock(self.in_channels // 2),
|
87 |
+
|
88 |
+
nn.Conv3d(
|
89 |
+
self.in_channels // 2, self.in_channels // 4,
|
90 |
+
stride = 1, kernel_size = 1),
|
91 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
92 |
+
ResNetBlock(self.in_channels // 4),
|
93 |
+
|
94 |
+
nn.Conv3d(
|
95 |
+
self.in_channels // 4, self.in_channels // 8,
|
96 |
+
stride = 1, kernel_size = 1),
|
97 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
98 |
+
ResNetBlock(self.in_channels // 8),
|
99 |
+
|
100 |
+
nn.InstanceNorm3d(self.in_channels // 8),
|
101 |
+
nn.LeakyReLU(0.2, inplace=True),
|
102 |
+
nn.Conv3d(
|
103 |
+
self.in_channels // 8, num_channels,
|
104 |
+
kernel_size = 3, padding = 1),
|
105 |
+
# nn.Sigmoid()
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self, x): #x has shape = input_shape
|
110 |
+
#Encoder
|
111 |
+
# print(x.shape)
|
112 |
+
x = self.VAE_reshape(x)
|
113 |
+
shape = x.shape
|
114 |
+
|
115 |
+
x = x.view(self.in_channels, -1)
|
116 |
+
x = self.to_latent_space(x)
|
117 |
+
x = x.view(1, self.in_channels)
|
118 |
+
|
119 |
+
mean = self.mean(x)
|
120 |
+
logvar = self.logvar(x)
|
121 |
+
# sigma = torch.exp(0.5 * logvar)
|
122 |
+
# Reparameter
|
123 |
+
epsilon = torch.randn_like(logvar)
|
124 |
+
sample = mean + epsilon * torch.exp(0.5*logvar)
|
125 |
+
|
126 |
+
#Decoder
|
127 |
+
y = self.to_original_dimension(sample)
|
128 |
+
y = y.view(*shape)
|
129 |
+
return self.Reconstruct(y), mean, logvar
|
130 |
+
def total_params(self):
|
131 |
+
total = sum(p.numel() for p in self.parameters())
|
132 |
+
return format(total, ',')
|
133 |
+
|
134 |
+
def total_trainable_params(self):
|
135 |
+
total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
136 |
+
return format(total_trainable, ',')
|
137 |
+
|
138 |
+
|
139 |
+
# x = torch.rand((1, 256, 16, 16, 16))
|
140 |
+
# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
|
141 |
+
# y = vae(x)
|
142 |
+
# print(y[0].shape, y[1].shape, y[2].shape)
|
143 |
+
# print(vae.total_trainable_params())
|
144 |
+
|
145 |
+
|
146 |
+
### Decoder ####
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
class Upsample(nn.Module):
|
151 |
+
def __init__(self, in_channel, out_channel):
|
152 |
+
super().__init__()
|
153 |
+
self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
|
154 |
+
self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
|
155 |
+
self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)
|
156 |
+
|
157 |
+
def forward(self, prev, x):
|
158 |
+
x = self.deconv(self.conv1(x))
|
159 |
+
y = torch.cat((prev, x), dim = 1)
|
160 |
+
return self.conv2(y)
|
161 |
+
|
162 |
+
class FinalConv(nn.Module): # Input channels are equal to output channels
|
163 |
+
def __init__(self, in_channels, out_channels=32, norm="instance"):
|
164 |
+
super(FinalConv, self).__init__()
|
165 |
+
if norm == "batch":
|
166 |
+
norm_layer = nn.BatchNorm3d(num_features=in_channels)
|
167 |
+
elif norm == "group":
|
168 |
+
norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
|
169 |
+
elif norm == 'instance':
|
170 |
+
norm_layer = nn.InstanceNorm3d(in_channels)
|
171 |
+
|
172 |
+
self.layer = nn.Sequential(
|
173 |
+
norm_layer,
|
174 |
+
nn.LeakyReLU(0.2, inplace=True),
|
175 |
+
nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
176 |
+
)
|
177 |
+
def forward(self, x):
|
178 |
+
return self.layer(x)
|
179 |
+
|
180 |
+
class Decoder(nn.Module):
|
181 |
+
def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
|
182 |
+
super().__init__()
|
183 |
+
self.img_dim = img_dim
|
184 |
+
self.patch_dim = patch_dim
|
185 |
+
self.embedding_dim = embedding_dim
|
186 |
+
|
187 |
+
self.decoder_upsample_1 = Upsample(128, 64)
|
188 |
+
self.decoder_block_1 = ResNetBlock(64)
|
189 |
+
|
190 |
+
self.decoder_upsample_2 = Upsample(64, 32)
|
191 |
+
self.decoder_block_2 = ResNetBlock(32)
|
192 |
+
|
193 |
+
self.decoder_upsample_3 = Upsample(32, 16)
|
194 |
+
self.decoder_block_3 = ResNetBlock(16)
|
195 |
+
|
196 |
+
self.endconv = FinalConv(16, num_classes)
|
197 |
+
# self.normalize = nn.Sigmoid()
|
198 |
+
|
199 |
+
def forward(self, x1, x2, x3, x):
|
200 |
+
x = self.decoder_upsample_1(x3, x)
|
201 |
+
x = self.decoder_block_1(x)
|
202 |
+
|
203 |
+
x = self.decoder_upsample_2(x2, x)
|
204 |
+
x = self.decoder_block_2(x)
|
205 |
+
|
206 |
+
x = self.decoder_upsample_3(x1, x)
|
207 |
+
x = self.decoder_block_3(x)
|
208 |
+
|
209 |
+
y = self.endconv(x)
|
210 |
+
return y
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
###############Encoder##############
|
215 |
+
class InitConv(nn.Module):
|
216 |
+
def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
|
217 |
+
super().__init__()
|
218 |
+
self.layer = nn.Sequential(
|
219 |
+
nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
|
220 |
+
nn.Dropout3d(dropout)
|
221 |
+
)
|
222 |
+
def forward(self, x):
|
223 |
+
y = self.layer(x)
|
224 |
+
return y
|
225 |
+
|
226 |
+
|
227 |
+
class DownSample(nn.Module):
|
228 |
+
def __init__(self, in_channels, out_channels):
|
229 |
+
super().__init__()
|
230 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
|
231 |
+
def forward(self, x):
|
232 |
+
return self.conv(x)
|
233 |
+
|
234 |
+
class Encoder(nn.Module):
|
235 |
+
def __init__(self, in_channels, base_channels, dropout = 0.2):
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
|
239 |
+
self.encoder_block1 = ResNetBlock(in_channels = base_channels)
|
240 |
+
self.encoder_down1 = DownSample(base_channels, base_channels * 2)
|
241 |
+
|
242 |
+
self.encoder_block2_1 = ResNetBlock(base_channels * 2)
|
243 |
+
self.encoder_block2_2 = ResNetBlock(base_channels * 2)
|
244 |
+
self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)
|
245 |
+
|
246 |
+
self.encoder_block3_1 = ResNetBlock(base_channels * 4)
|
247 |
+
self.encoder_block3_2 = ResNetBlock(base_channels * 4)
|
248 |
+
self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)
|
249 |
+
|
250 |
+
self.encoder_block4_1 = ResNetBlock(base_channels * 8)
|
251 |
+
self.encoder_block4_2 = ResNetBlock(base_channels * 8)
|
252 |
+
self.encoder_block4_3 = ResNetBlock(base_channels * 8)
|
253 |
+
self.encoder_block4_4 = ResNetBlock(base_channels * 8)
|
254 |
+
# self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
|
255 |
+
def forward(self, x):
|
256 |
+
x = self.init_conv(x) #(1, 16, 128, 128, 128)
|
257 |
+
|
258 |
+
x1 = self.encoder_block1(x)
|
259 |
+
x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)
|
260 |
+
|
261 |
+
x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
|
262 |
+
x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)
|
263 |
+
|
264 |
+
x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
|
265 |
+
x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)
|
266 |
+
|
267 |
+
output = self.encoder_block4_4(
|
268 |
+
self.encoder_block4_3(
|
269 |
+
self.encoder_block4_2(
|
270 |
+
self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
|
271 |
+
return x1, x2, x3, output
|
272 |
+
|
273 |
+
# x = torch.rand((1, 4, 128, 128, 128))
|
274 |
+
# Enc = Encoder(4, 32)
|
275 |
+
# _, _, _, y = Enc(x)
|
276 |
+
# print(y.shape) (1,256,16,16)
|
277 |
+
|
278 |
+
|
279 |
+
###############FeatureMapping###############
|
280 |
+
|
281 |
+
class FeatureMapping(nn.Module):
|
282 |
+
def __init__(self, in_channel, out_channel, norm = 'instance'):
|
283 |
+
super().__init__()
|
284 |
+
if norm == 'bn':
|
285 |
+
norm_layer_1 = nn.BatchNorm3d(out_channel)
|
286 |
+
norm_layer_2 = nn.BatchNorm3d(out_channel)
|
287 |
+
elif norm == 'gn':
|
288 |
+
norm_layer_1 = nn.GroupNorm(8, out_channel)
|
289 |
+
norm_layer_2 = nn.GroupNorm(8, out_channel)
|
290 |
+
elif norm == 'instance':
|
291 |
+
norm_layer_1 = nn.InstanceNorm3d(out_channel)
|
292 |
+
norm_layer_2 = nn.InstanceNorm3d(out_channel)
|
293 |
+
self.feature_mapping = nn.Sequential(
|
294 |
+
nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
|
295 |
+
norm_layer_1,
|
296 |
+
nn.LeakyReLU(0.2, inplace=True),
|
297 |
+
nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
|
298 |
+
norm_layer_2,
|
299 |
+
nn.LeakyReLU(0.2, inplace=True)
|
300 |
+
)
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
return self.feature_mapping(x)
|
304 |
+
|
305 |
+
|
306 |
+
class FeatureMapping1(nn.Module):
|
307 |
+
def __init__(self, in_channel, norm = 'instance'):
|
308 |
+
super().__init__()
|
309 |
+
if norm == 'bn':
|
310 |
+
norm_layer_1 = nn.BatchNorm3d(in_channel)
|
311 |
+
norm_layer_2 = nn.BatchNorm3d(in_channel)
|
312 |
+
elif norm == 'gn':
|
313 |
+
norm_layer_1 = nn.GroupNorm(8, in_channel)
|
314 |
+
norm_layer_2 = nn.GroupNorm(8, in_channel)
|
315 |
+
elif norm == 'instance':
|
316 |
+
norm_layer_1 = nn.InstanceNorm3d(in_channel)
|
317 |
+
norm_layer_2 = nn.InstanceNorm3d(in_channel)
|
318 |
+
self.feature_mapping1 = nn.Sequential(
|
319 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
320 |
+
norm_layer_1,
|
321 |
+
nn.LeakyReLU(0.2, inplace=True),
|
322 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
323 |
+
norm_layer_2,
|
324 |
+
nn.LeakyReLU(0.2, inplace=True)
|
325 |
+
)
|
326 |
+
def forward(self, x):
|
327 |
+
y = self.feature_mapping1(x)
|
328 |
+
return x + y #Resnet Like
|
329 |
+
|
330 |
+
################Transformer#######################
|
331 |
+
|
332 |
+
|
333 |
+
def pair(t):
|
334 |
+
return t if isinstance(t, tuple) else (t, t)
|
335 |
+
|
336 |
+
|
337 |
+
class PreNorm(nn.Module):
|
338 |
+
def __init__(self, dim, function):
|
339 |
+
super().__init__()
|
340 |
+
self.norm = nn.LayerNorm(dim)
|
341 |
+
self.function = function
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
return self.function(self.norm(x))
|
345 |
+
|
346 |
+
|
347 |
+
class FeedForward(nn.Module):
|
348 |
+
def __init__(self, dim, hidden_dim, dropout = 0.0):
|
349 |
+
super().__init__()
|
350 |
+
self.net = nn.Sequential(
|
351 |
+
nn.Linear(dim, hidden_dim),
|
352 |
+
nn.GELU(),
|
353 |
+
nn.Dropout(dropout),
|
354 |
+
nn.Linear(hidden_dim, dim),
|
355 |
+
nn.Dropout(dropout)
|
356 |
+
)
|
357 |
+
|
358 |
+
def forward(self, x):
|
359 |
+
return self.net(x)
|
360 |
+
|
361 |
+
class Attention(nn.Module):
|
362 |
+
def __init__(self, dim, heads, dim_head, dropout = 0.0):
|
363 |
+
super().__init__()
|
364 |
+
all_head_size = heads * dim_head
|
365 |
+
project_out = not (heads == 1 and dim_head == dim)
|
366 |
+
|
367 |
+
self.heads = heads
|
368 |
+
self.scale = dim_head ** -0.5
|
369 |
+
|
370 |
+
self.softmax = nn.Softmax(dim = -1)
|
371 |
+
self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)
|
372 |
+
|
373 |
+
self.to_out = nn.Sequential(
|
374 |
+
nn.Linear(all_head_size, dim),
|
375 |
+
nn.Dropout(dropout)
|
376 |
+
) if project_out else nn.Identity()
|
377 |
+
|
378 |
+
def forward(self, x):
|
379 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
380 |
+
#(batch, heads * dim_head) -> (batch, all_head_size)
|
381 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
382 |
+
|
383 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
384 |
+
|
385 |
+
atten = self.softmax(dots)
|
386 |
+
|
387 |
+
out = torch.matmul(atten, v)
|
388 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
389 |
+
return self.to_out(out)
|
390 |
+
|
391 |
+
class Transformer(nn.Module):
|
392 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
|
393 |
+
super().__init__()
|
394 |
+
self.layers = nn.ModuleList([])
|
395 |
+
for _ in range(depth):
|
396 |
+
self.layers.append(nn.ModuleList([
|
397 |
+
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
398 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
399 |
+
]))
|
400 |
+
def forward(self, x):
|
401 |
+
for attention, feedforward in self.layers:
|
402 |
+
x = attention(x) + x
|
403 |
+
x = feedforward(x) + x
|
404 |
+
return x
|
405 |
+
|
406 |
+
class FixedPositionalEncoding(nn.Module):
|
407 |
+
def __init__(self, embedding_dim, max_length=768):
|
408 |
+
super(FixedPositionalEncoding, self).__init__()
|
409 |
+
|
410 |
+
pe = torch.zeros(max_length, embedding_dim)
|
411 |
+
position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
|
412 |
+
div_term = torch.exp(
|
413 |
+
torch.arange(0, embedding_dim, 2).float()
|
414 |
+
* (-torch.log(torch.tensor(10000.0)) / embedding_dim)
|
415 |
+
)
|
416 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
417 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
418 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
419 |
+
self.register_buffer('pe', pe)
|
420 |
+
|
421 |
+
def forward(self, x):
|
422 |
+
x = x + self.pe[: x.size(0), :]
|
423 |
+
return x
|
424 |
+
|
425 |
+
|
426 |
+
class LearnedPositionalEncoding(nn.Module):
|
427 |
+
def __init__(self, embedding_dim, seq_length):
|
428 |
+
super(LearnedPositionalEncoding, self).__init__()
|
429 |
+
self.seq_length = seq_length
|
430 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x
|
431 |
+
|
432 |
+
def forward(self, x, position_ids=None):
|
433 |
+
position_embeddings = self.position_embeddings
|
434 |
+
# print(x.shape, self.position_embeddings.shape)
|
435 |
+
return x + position_embeddings
|
436 |
+
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
|
441 |
+
###############Main model#################
|
442 |
+
|
443 |
+
class SegTransVAE(nn.Module):
|
444 |
+
def __init__(self, img_dim, patch_dim, num_channels, num_classes,
|
445 |
+
embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
|
446 |
+
dropout = 0.0, attention_dropout = 0.0,
|
447 |
+
conv_patch_representation = True, positional_encoding = 'learned',
|
448 |
+
use_VAE = False):
|
449 |
+
|
450 |
+
super().__init__()
|
451 |
+
assert embedding_dim % num_heads == 0
|
452 |
+
assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0
|
453 |
+
|
454 |
+
self.img_dim = img_dim
|
455 |
+
self.embedding_dim = embedding_dim
|
456 |
+
self.num_heads = num_heads
|
457 |
+
self.num_classes = num_classes
|
458 |
+
self.patch_dim = patch_dim
|
459 |
+
self.num_channels = num_channels
|
460 |
+
self.in_channels_vae = in_channels_vae
|
461 |
+
self.dropout = dropout
|
462 |
+
self.attention_dropout = attention_dropout
|
463 |
+
self.conv_patch_representation = conv_patch_representation
|
464 |
+
self.use_VAE = use_VAE
|
465 |
+
|
466 |
+
self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
|
467 |
+
self.seq_length = self.num_patches
|
468 |
+
self.flatten_dim = 128 * num_channels
|
469 |
+
|
470 |
+
self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
|
471 |
+
if positional_encoding == "learned":
|
472 |
+
self.position_encoding = LearnedPositionalEncoding(
|
473 |
+
self.embedding_dim, self.seq_length
|
474 |
+
)
|
475 |
+
elif positional_encoding == "fixed":
|
476 |
+
self.position_encoding = FixedPositionalEncoding(
|
477 |
+
self.embedding_dim,
|
478 |
+
)
|
479 |
+
self.pe_dropout = nn.Dropout(self.dropout)
|
480 |
+
|
481 |
+
self.transformer = Transformer(
|
482 |
+
embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
|
483 |
+
)
|
484 |
+
self.pre_head_ln = nn.LayerNorm(embedding_dim)
|
485 |
+
|
486 |
+
if self.conv_patch_representation:
|
487 |
+
self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
|
488 |
+
self.encoder = Encoder(self.num_channels, 16)
|
489 |
+
self.bn = nn.InstanceNorm3d(128)
|
490 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
491 |
+
self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
|
492 |
+
self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
|
493 |
+
self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)
|
494 |
+
|
495 |
+
self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
|
496 |
+
if use_VAE:
|
497 |
+
self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
|
498 |
+
def encode(self, x):
|
499 |
+
if self.conv_patch_representation:
|
500 |
+
x1, x2, x3, x = self.encoder(x)
|
501 |
+
x = self.bn(x)
|
502 |
+
x = self.relu(x)
|
503 |
+
x = self.conv_x(x)
|
504 |
+
x = x.permute(0, 2, 3, 4, 1).contiguous()
|
505 |
+
x = x.view(x.size(0), -1, self.embedding_dim)
|
506 |
+
x = self.position_encoding(x)
|
507 |
+
x = self.pe_dropout(x)
|
508 |
+
x = self.transformer(x)
|
509 |
+
x = self.pre_head_ln(x)
|
510 |
+
|
511 |
+
return x1, x2, x3, x
|
512 |
+
|
513 |
+
def decode(self, x1, x2, x3, x):
|
514 |
+
#x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
|
515 |
+
# print("In decode...")
|
516 |
+
# print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
|
517 |
+
# break
|
518 |
+
return self.decoder(x1, x2, x3, x)
|
519 |
+
|
520 |
+
def forward(self, x, is_validation = True):
|
521 |
+
x1, x2, x3, x = self.encode(x)
|
522 |
+
x = x.view( x.size(0),
|
523 |
+
self.img_dim[0] // self.patch_dim,
|
524 |
+
self.img_dim[1] // self.patch_dim,
|
525 |
+
self.img_dim[2] // self.patch_dim,
|
526 |
+
self.embedding_dim)
|
527 |
+
x = x.permute(0, 4, 1, 2, 3).contiguous()
|
528 |
+
x = self.FeatureMapping(x)
|
529 |
+
x = self.FeatureMapping1(x)
|
530 |
+
if self.use_VAE and not is_validation:
|
531 |
+
vae_out, mu, sigma = self.vae(x)
|
532 |
+
y = self.decode(x1, x2, x3, x)
|
533 |
+
if self.use_VAE and not is_validation:
|
534 |
+
return y, vae_out, mu, sigma
|
535 |
+
else:
|
536 |
+
return y
|
537 |
+
|
538 |
+
|
models/SegTranVAE/__init__.py
ADDED
File without changes
|
models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc
ADDED
Binary file (16.3 kB). View file
|
|
models/SegTranVAE/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (137 Bytes). View file
|
|
train.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from monai.utils import set_determinism
|
4 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
5 |
+
import os
|
6 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
7 |
+
from trainer import BRATS
|
8 |
+
from dataset.utils import get_loader
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
|
13 |
+
set_determinism(seed=0)
|
14 |
+
|
15 |
+
os.system('cls||clear')
|
16 |
+
print("Training ...")
|
17 |
+
|
18 |
+
data_dir = "/app/brats_2021_task1"
|
19 |
+
json_list = "/app/info.json"
|
20 |
+
roi = (128, 128, 128)
|
21 |
+
batch_size = 1
|
22 |
+
fold = 1
|
23 |
+
max_epochs = 500
|
24 |
+
val_every = 10
|
25 |
+
train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2)
|
26 |
+
print("Done initialize dataloader !! ")
|
27 |
+
|
28 |
+
model = BRATS(use_VAE = True, train_loader = train_loader,val_loader = val_loader, test_loader=test_loader )
|
29 |
+
checkpoint_callback = ModelCheckpoint(
|
30 |
+
monitor='val/MeanDiceScore',
|
31 |
+
dirpath='./checkpoints/{}'.format("SegTransVAE"),
|
32 |
+
filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
|
33 |
+
save_top_k=3,
|
34 |
+
mode='max',
|
35 |
+
save_last= True,
|
36 |
+
auto_insert_metric_name=False
|
37 |
+
)
|
38 |
+
early_stop_callback = EarlyStopping(
|
39 |
+
monitor='val/MeanDiceScore',
|
40 |
+
min_delta=0.0001,
|
41 |
+
patience=15,
|
42 |
+
verbose=False,
|
43 |
+
mode='max'
|
44 |
+
)
|
45 |
+
tensorboardlogger = TensorBoardLogger(
|
46 |
+
'logs',
|
47 |
+
name = "SegTransVAE",
|
48 |
+
default_hp_metric = None
|
49 |
+
)
|
50 |
+
trainer = pl.Trainer(#fast_dev_run = 10,
|
51 |
+
# accelerator='ddp',
|
52 |
+
#overfit_batches=5,
|
53 |
+
devices = [0],
|
54 |
+
precision=16,
|
55 |
+
max_epochs = max_epochs,
|
56 |
+
enable_progress_bar=True,
|
57 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
58 |
+
# auto_lr_find=True,
|
59 |
+
num_sanity_val_steps=1,
|
60 |
+
logger = tensorboardlogger,
|
61 |
+
check_val_every_n_epoch = 10,
|
62 |
+
# limit_train_batches=0.01,
|
63 |
+
# limit_val_batches=0.01
|
64 |
+
)
|
65 |
+
# trainer.tune(model)
|
66 |
+
trainer.fit(model)
|
67 |
+
|
68 |
+
|
69 |
+
|
trainer.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import csv
|
5 |
+
import torch
|
6 |
+
from monai.transforms import AsDiscrete, Activations, Compose, EnsureType
|
7 |
+
from models.SegTranVAE.SegTranVAE import SegTransVAE
|
8 |
+
from loss.loss import Loss_VAE, DiceScore
|
9 |
+
from monai.losses import DiceLoss
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from monai.inferers import sliding_window_inference
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class BRATS(pl.LightningModule):
|
18 |
+
def __init__(self,train_loader,val_loader,test_loader, use_VAE = True, lr = 1e-4 ):
|
19 |
+
super().__init__()
|
20 |
+
self.train_loader = train_loader
|
21 |
+
self.val_loader = val_loader
|
22 |
+
self.test_loader = test_loader
|
23 |
+
self.use_vae = use_VAE
|
24 |
+
self.lr = lr
|
25 |
+
self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)
|
26 |
+
|
27 |
+
self.loss_vae = Loss_VAE()
|
28 |
+
self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
|
29 |
+
self.post_trans_images = Compose(
|
30 |
+
[EnsureType(),
|
31 |
+
Activations(sigmoid=True),
|
32 |
+
AsDiscrete(threshold_values=True),
|
33 |
+
]
|
34 |
+
)
|
35 |
+
|
36 |
+
self.best_val_dice = 0
|
37 |
+
|
38 |
+
self.training_step_outputs = []
|
39 |
+
self.val_step_loss = []
|
40 |
+
self.val_step_dice = []
|
41 |
+
self.val_step_dice_tc = []
|
42 |
+
self.val_step_dice_wt = []
|
43 |
+
self.val_step_dice_et = []
|
44 |
+
self.test_step_loss = []
|
45 |
+
self.test_step_dice = []
|
46 |
+
self.test_step_dice_tc = []
|
47 |
+
self.test_step_dice_wt = []
|
48 |
+
self.test_step_dice_et = []
|
49 |
+
|
50 |
+
def forward(self, x, is_validation = True):
|
51 |
+
return self.model(x, is_validation)
|
52 |
+
def training_step(self, batch, batch_index):
|
53 |
+
inputs, labels = (batch['image'], batch['label'])
|
54 |
+
|
55 |
+
if not self.use_vae:
|
56 |
+
outputs = self.forward(inputs, is_validation=False)
|
57 |
+
loss = self.dice_loss(outputs, labels)
|
58 |
+
else:
|
59 |
+
outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
|
60 |
+
|
61 |
+
vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
|
62 |
+
dice_loss = self.dice_loss(outputs, labels)
|
63 |
+
loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
|
64 |
+
self.training_step_outputs.append(loss)
|
65 |
+
self.log('train/vae_loss', vae_loss)
|
66 |
+
self.log('train/dice_loss', dice_loss)
|
67 |
+
if batch_index == 10:
|
68 |
+
|
69 |
+
tensorboard = self.logger.experiment
|
70 |
+
fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
|
71 |
+
|
72 |
+
|
73 |
+
ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
|
74 |
+
ax[0].set_title("Input")
|
75 |
+
|
76 |
+
ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
77 |
+
ax[1].set_title("Reconstruction")
|
78 |
+
|
79 |
+
ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
80 |
+
ax[2].set_title("Labels TC")
|
81 |
+
|
82 |
+
ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
83 |
+
ax[3].set_title("TC")
|
84 |
+
|
85 |
+
ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
86 |
+
ax[4].set_title("Labels ET")
|
87 |
+
|
88 |
+
ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
89 |
+
ax[5].set_title("ET")
|
90 |
+
|
91 |
+
|
92 |
+
tensorboard.add_figure('train_visualize', fig, self.current_epoch)
|
93 |
+
|
94 |
+
self.log('train/loss', loss)
|
95 |
+
|
96 |
+
return loss
|
97 |
+
|
98 |
+
def on_train_epoch_end(self):
|
99 |
+
## F1 Macro all epoch saving outputs and target per batch
|
100 |
+
|
101 |
+
# free up the memory
|
102 |
+
# --> HERE STEP 3 <--
|
103 |
+
epoch_average = torch.stack(self.training_step_outputs).mean()
|
104 |
+
self.log("training_epoch_average", epoch_average)
|
105 |
+
self.training_step_outputs.clear() # free memory
|
106 |
+
|
107 |
+
def validation_step(self, batch, batch_index):
|
108 |
+
inputs, labels = (batch['image'], batch['label'])
|
109 |
+
roi_size = (128, 128, 128)
|
110 |
+
sw_batch_size = 1
|
111 |
+
outputs = sliding_window_inference(
|
112 |
+
inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
|
113 |
+
loss = self.dice_loss(outputs, labels)
|
114 |
+
|
115 |
+
|
116 |
+
val_outputs = self.post_trans_images(outputs)
|
117 |
+
|
118 |
+
|
119 |
+
metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
120 |
+
metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
121 |
+
metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
122 |
+
mean_val_dice = (metric_tc + metric_wt + metric_et)/3
|
123 |
+
self.val_step_loss.append(loss)
|
124 |
+
self.val_step_dice.append(mean_val_dice)
|
125 |
+
self.val_step_dice_tc.append(metric_tc)
|
126 |
+
self.val_step_dice_wt.append(metric_wt)
|
127 |
+
self.val_step_dice_et.append(metric_et)
|
128 |
+
return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
|
129 |
+
'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
|
130 |
+
|
131 |
+
def on_validation_epoch_end(self):
|
132 |
+
|
133 |
+
loss = torch.stack(self.val_step_loss).mean()
|
134 |
+
mean_val_dice = torch.stack(self.val_step_dice).mean()
|
135 |
+
metric_tc = torch.stack(self.val_step_dice_tc).mean()
|
136 |
+
metric_wt = torch.stack(self.val_step_dice_wt).mean()
|
137 |
+
metric_et = torch.stack(self.val_step_dice_et).mean()
|
138 |
+
self.log('val/Loss', loss)
|
139 |
+
self.log('val/MeanDiceScore', mean_val_dice)
|
140 |
+
self.log('val/DiceTC', metric_tc)
|
141 |
+
self.log('val/DiceWT', metric_wt)
|
142 |
+
self.log('val/DiceET', metric_et)
|
143 |
+
os.makedirs(self.logger.log_dir, exist_ok=True)
|
144 |
+
if self.current_epoch == 0:
|
145 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
|
146 |
+
writer = csv.writer(f)
|
147 |
+
writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
|
148 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
|
149 |
+
writer = csv.writer(f)
|
150 |
+
writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])
|
151 |
+
|
152 |
+
if mean_val_dice > self.best_val_dice:
|
153 |
+
self.best_val_dice = mean_val_dice
|
154 |
+
self.best_val_epoch = self.current_epoch
|
155 |
+
print(
|
156 |
+
f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
|
157 |
+
f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
|
158 |
+
f"\n Best mean dice: {self.best_val_dice}"
|
159 |
+
f" at epoch: {self.best_val_epoch}"
|
160 |
+
)
|
161 |
+
|
162 |
+
self.val_step_loss.clear()
|
163 |
+
self.val_step_dice.clear()
|
164 |
+
self.val_step_dice_tc.clear()
|
165 |
+
self.val_step_dice_wt.clear()
|
166 |
+
self.val_step_dice_et.clear()
|
167 |
+
return {'val_MeanDiceScore': mean_val_dice}
|
168 |
+
def test_step(self, batch, batch_index):
|
169 |
+
inputs, labels = (batch['image'], batch['label'])
|
170 |
+
|
171 |
+
roi_size = (128, 128, 128)
|
172 |
+
sw_batch_size = 1
|
173 |
+
test_outputs = sliding_window_inference(
|
174 |
+
inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
|
175 |
+
loss = self.dice_loss(test_outputs, labels)
|
176 |
+
test_outputs = self.post_trans_images(test_outputs)
|
177 |
+
metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
178 |
+
metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
179 |
+
metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
180 |
+
mean_test_dice = (metric_tc + metric_wt + metric_et)/3
|
181 |
+
|
182 |
+
self.test_step_loss.append(loss)
|
183 |
+
self.test_step_dice.append(mean_test_dice)
|
184 |
+
self.test_step_dice_tc.append(metric_tc)
|
185 |
+
self.test_step_dice_wt.append(metric_wt)
|
186 |
+
self.test_step_dice_et.append(metric_et)
|
187 |
+
|
188 |
+
return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
|
189 |
+
'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
|
190 |
+
|
191 |
+
def test_epoch_end(self):
|
192 |
+
loss = torch.stack(self.test_step_loss).mean()
|
193 |
+
mean_test_dice = torch.stack(self.test_step_dice).mean()
|
194 |
+
metric_tc = torch.stack(self.test_step_dice_tc).mean()
|
195 |
+
metric_wt = torch.stack(self.test_step_dice_wt).mean()
|
196 |
+
metric_et = torch.stack(self.test_step_dice_et).mean()
|
197 |
+
self.log('test/Loss', loss)
|
198 |
+
self.log('test/MeanDiceScore', mean_test_dice)
|
199 |
+
self.log('test/DiceTC', metric_tc)
|
200 |
+
self.log('test/DiceWT', metric_wt)
|
201 |
+
self.log('test/DiceET', metric_et)
|
202 |
+
|
203 |
+
with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
|
204 |
+
writer = csv.writer(f)
|
205 |
+
writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
|
206 |
+
writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])
|
207 |
+
|
208 |
+
self.test_step_loss.clear()
|
209 |
+
self.test_step_dice.clear()
|
210 |
+
self.test_step_dice_tc.clear()
|
211 |
+
self.test_step_dice_wt.clear()
|
212 |
+
self.test_step_dice_et.clear()
|
213 |
+
return {'test_MeanDiceScore': mean_test_dice}
|
214 |
+
|
215 |
+
|
216 |
+
def configure_optimizers(self):
|
217 |
+
optimizer = torch.optim.Adam(
|
218 |
+
self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
|
219 |
+
)
|
220 |
+
# optimizer = AdaBelief(self.model.parameters(),
|
221 |
+
# lr=self.lr, eps=1e-16,
|
222 |
+
# betas=(0.9,0.999), weight_decouple = True,
|
223 |
+
# rectify = False)
|
224 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
|
225 |
+
return [optimizer], [scheduler]
|
226 |
+
|
227 |
+
def train_dataloader(self):
|
228 |
+
return self.train_loader
|
229 |
+
def val_dataloader(self):
|
230 |
+
return self.val_loader
|
231 |
+
|
232 |
+
def test_dataloader(self):
|
233 |
+
return self.test_loader
|