Spaces:
Sleeping
Sleeping
Upload 12 files
Browse files- model development/run_best_model_notebook.ipynb +0 -0
- utils/data_preparation.py +241 -0
- utils/data_transforms.py +267 -0
- utils/inference.py +155 -0
- utils/loss.py +153 -0
- utils/models.py +670 -0
- utils/pipeline.py +501 -0
- utils/sliding_window.py +328 -0
- utils/tumor_features.py +55 -0
- utils/visualization.py +109 -0
model development/run_best_model_notebook.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils/data_preparation.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from sklearn.model_selection import train_test_split
|
3 |
+
import monai
|
4 |
+
from monai.data import Dataset, DataLoader
|
5 |
+
from data_transforms import define_transforms, define_transforms_loadonly
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from visualization import visualize_patient
|
9 |
+
from monai.data import list_data_collate
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
|
13 |
+
def prepare_clinical_data(data_file, predictors):
    """
    Load the clinical spreadsheet and return one numerically encoded row per patient.

    Args:
        data_file: path of the clinical Excel workbook (only its first sheet is read).
            An already loaded pandas DataFrame is accepted as well; it is copied, never modified.
        predictors (list): names of the clinical columns to keep.

    Returns:
        DataFrame with a "patient_id" column (renamed from "TCIA_ID") followed by the
        requested predictor columns, one row per distinct TCIA_ID.
    """
    # read data file
    if isinstance(data_file, pd.DataFrame):
        info = data_file.copy()
    else:
        info = pd.read_excel(data_file, sheet_name=0)

    # convert to numerical; any category that is absent from a mapping becomes NaN
    # NOTE(review): the BCLC mapping only matches the *string* '0'; if the sheet stores
    # that stage as a number it will be mapped to NaN - confirm against the workbook.
    ordinal_encodings = {
        'CPS': {'A': 1, 'B': 2, 'C': 3},
        'T_involvment': {'< or = 50%': 1, '>50%': 2},
        'CLIP_Score': {'Stage_0': 0, 'Stage_1': 1, 'Stage_2': 2, 'Stage_3': 3, 'Stage_4': 4, 'Stage_5': 5, 'Stage_6': 6},
        'Okuda': {'Stage I': 1, 'Stage II': 2, 'Stage III': 3},
        'TNM': {'Stage-I': 1, 'Stage-II': 2, 'Stage-IIIA': 3, 'Stage-IIIB': 4, 'Stage-IIIC': 5, 'Stage-IVA': 6, 'Stage-IVB': 7},
        'BCLC': {'0': 0, 'Stage-A': 1, 'Stage-B': 2, 'Stage-C': 3, 'Stage-D': 4},
    }
    for column, encoding in ordinal_encodings.items():
        info[column] = info[column].map(encoding)

    # remove duplicates: keep the first non-null value of every column per patient.
    # The result has to be assigned back (it was previously discarded), and TCIA_ID has
    # to stay a regular column because it is selected below.
    info = info.groupby("TCIA_ID", as_index=False, sort=False).first()

    # select columns
    info = info[['TCIA_ID'] + list(predictors)].rename(columns={'TCIA_ID': "patient_id"})

    return info
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
def preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1):
    """
    From a list of patients, split them into train and test and export list to .txt files

    Every entry of ``data_dir`` except the clinical spreadsheet is treated as one patient.
    The two lists are written, comma separated, to ``train-test-split-seed<seed>/train.txt``
    and ``train-test-split-seed<seed>/test.txt`` relative to the current working directory.

    Args:
        data_dir: folder that holds one sub-folder per patient.
        test_patient_ratio: fraction of patients assigned to the test list.
        seed: random state handed to ``train_test_split``; also part of the output folder name.
    """
    # Entries of data_dir that are not patients.
    non_patient_entries = {"HCC-TACE-Seg_clinical_data-V2.xlsx"}

    # split based on seed, write to txt files
    # Sorting matters: the iteration order of a set of strings changes from one Python
    # process to the next (hash randomisation), which would make the seeded split
    # irreproducible.
    patients = sorted(set(os.listdir(data_dir)) - non_patient_entries)

    # remove one patient with wrong labels
    if "HCC_017" in patients:
        patients.remove("HCC_017")
        print("The patient HCC_017 is removed due to label issues including necrosis.")

    print("Total patients:", len(patients))
    patients_train, patients_test = train_test_split(patients, test_size=test_patient_ratio, random_state=seed)
    print("   There are", len(patients_train), "patients in training")
    print("   There are", len(patients_test), "patients in test")

    # export a copy
    split_dir = 'train-test-split-seed' + str(seed)
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(split_dir, 'train.txt'), 'w') as f:
        f.write(','.join(patients_train))
    with open(os.path.join(split_dir, 'test.txt'), 'w') as f:
        f.write(','.join(patients_test))

    print("Files saved to", split_dir + '/train.txt and ' + split_dir + '/test.txt')
    return
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def extract_file_path(patient_id, data_folder):
    """
    Given one patient's ID, obtain the file path of the image and mask data.

    A file whose name contains "seg" or "Segmentation" is stored under the key "mask"
    (if several match, the last one in sorted order wins). Every other file is an image
    and is stored under "pre_1", "pre_2", ... in sorted file-name order.

    Args:
        patient_id: name of the patient's sub-folder.
        data_folder: folder that contains the patient sub-folders.

    Returns:
        dict mapping "mask" / "pre_<n>" to the full path of the corresponding file.
    """
    path = os.path.join(data_folder, patient_id)
    patient_files = {}
    count = 1
    # os.listdir returns entries in arbitrary order; sorting keeps the pre_<n> numbering
    # (and therefore the order of the resulting dataset) stable between runs and machines.
    for file in sorted(os.listdir(path)):
        if "seg" in file or "Segmentation" in file:
            patient_files["mask"] = os.path.join(path, file)
        else:
            patient_files["pre_" + str(count)] = os.path.join(path, file)
            count += 1
    return patient_files
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def get_patient_dictionaries(txt_file, data_dir):
    """
    From .txt file that stores list of patients, look through data folders and extract a dictionary of patient data

    Args:
        txt_file: path of a text file holding comma separated patient ids.
        data_dir: folder that contains one sub-folder per patient.

    Returns:
        list of dicts with the keys "patient_id", "image" and "mask"; one entry per
        image file of every patient, each paired with that patient's mask file.

    Raises:
        AssertionError: if ``txt_file`` does not exist.
        KeyError: if a patient folder contains images but no mask file.
    """
    assert os.path.isfile(txt_file), "The file " + txt_file + " was not found. Please check your file directory."

    # Close the handle deterministically instead of leaving it to the garbage collector.
    with open(txt_file, "r") as file:
        raw_ids = file.read().split(',')

    # Tolerate hand-edited files: surrounding whitespace, a trailing newline or a
    # trailing comma would otherwise produce ids that match no patient folder.
    patients = [patient_id.strip() for patient_id in raw_ids if patient_id.strip()]

    data_dict = []

    for patient_id in patients:

        # get directories for mask and images
        patient_files = extract_file_path(patient_id, data_dir)

        # pair up each image with the mask
        for key, image_path in patient_files.items():
            if key != "mask":
                data_dict.append(
                    {
                        "patient_id": patient_id,
                        "image": image_path,
                        "mask": patient_files["mask"]
                    }
                )

    print("   There are", len(data_dict), "image-masks in this dataset.")
    return data_dict
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
def build_dataset(config, get_clinical=False):
    """
    Build the train / validation / test data loaders described by ``config``.

    Config keys read here: TRAIN_PATIENTS_FILE, TEST_PATIENTS_FILE, DATA_DIR,
    ONESAMPLETESTRUN, VALID_PATIENT_RATIO, MODEL_NAME, BATCH_SIZE, NUM_WORKERS and,
    when ``get_clinical`` is true, CLINICAL_DATA_FILE and CLINICAL_PREDICTORS.

    Args:
        config (dict): experiment configuration.
        get_clinical (bool): also compute the per-image tumor ratio and merge it with
            the clinical predictors.

    Returns:
        tuple: (train_loader, valid_loader, test_loader, postprocessing_transforms,
        df_clinical_train). ``df_clinical_train`` is an empty DataFrame unless
        ``get_clinical`` is true, in which case it is indexed by patient_id.
    """
    is_3d_model = "3D" in config['MODEL_NAME']

    def custom_collate_fn(batch):
        """
        Combine the samples of one batch into a single image tensor and mask tensor.

        3D models: every sample is cropped to at most 512 along its 2nd and 3rd axes and
        the samples are stacked along a new first axis.
        Other models: up to BATCH_SIZE randomly chosen entries along the first axis of
        every sample are kept (presumably slices - the transposition happens in the
        preprocessing transforms), cropped to at most 512 along the last two axes and
        concatenated along the first axis.

        Args:
            batch (list): list of dictionaries with keys "image" and "mask".

        Returns:
            dict with the keys "image" and "mask", or the tuple ``(None, None)`` when the
            samples cannot be combined (e.g. differing shapes).
        """
        num_samples_to_select = config['BATCH_SIZE']

        # Extract images and masks from the batch
        images, masks = [], []
        for sample in batch:
            num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
            # Drawn for 3D models as well although only the 2D branch uses it: this keeps
            # the consumption of torch's random stream identical to earlier runs.
            random_indices = torch.randperm(num_samples)[:num_samples_to_select]
            if is_3d_model:
                images.append(sample["image"][:,:512,:512,:])  # ensure image and mask same size
                masks.append(sample["mask"][:,:512,:512,:])
            else:
                images.append(sample["image"][random_indices,:,:512,:512])  # ensure image and mask same size
                masks.append(sample["mask"][random_indices,:,:512,:512])

        # Stack images and masks along the first dimension
        try:
            if not is_3d_model:
                concatenated_images = torch.cat(images, dim=0)
                concatenated_masks = torch.cat(masks, dim=0)
            else:
                concatenated_images = torch.stack(images, dim=0)
                concatenated_masks = torch.stack(masks, dim=0)
        except Exception as e:
            # Report every shape; indexing element [1] directly would itself fail for a
            # batch that holds a single sample.
            print("WARNING: not all images/masks are 512 by 512. Please check. ",
                  [tuple(image.shape) for image in images],
                  [tuple(mask.shape) for mask in masks])
            return None, None

        # Return stacked images and masks as tensors
        return {"image": concatenated_images, "mask": concatenated_masks}

    # get list of training and test patient files
    train_data_dict = get_patient_dictionaries(config['TRAIN_PATIENTS_FILE'], config['DATA_DIR'])
    test_data_dict = get_patient_dictionaries(config['TEST_PATIENTS_FILE'], config['DATA_DIR'])
    if config['ONESAMPLETESTRUN']:
        train_data_dict = train_data_dict[:2]
    # shuffle must be False so the order keeps matching the clinical data
    ttrain_data_dict, valid_data_dict = train_test_split(train_data_dict, test_size=config['VALID_PATIENT_RATIO'], shuffle=False, random_state=1)
    print("   Training patients:", len(ttrain_data_dict), " Validation patients:", len(valid_data_dict))
    print("   Test patients:", len(test_data_dict))

    # define data transformations
    preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms = define_transforms(config)

    # create data loaders
    train_ds = Dataset(ttrain_data_dict, transform=preprocessing_transforms_train)
    valid_ds = Dataset(valid_data_dict, transform=preprocessing_transforms_test)
    test_ds = Dataset(test_data_dict, transform=preprocessing_transforms_test)

    # 3D models batch whole volumes; the other models load one volume at a time and the
    # collate function selects BATCH_SIZE entries from it.
    loader_batch_size = config['BATCH_SIZE'] if is_3d_model else 1
    train_loader = DataLoader(train_ds, batch_size=loader_batch_size, collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
    valid_loader = DataLoader(valid_ds, batch_size=loader_batch_size, collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
    test_loader = DataLoader(test_ds, batch_size=loader_batch_size, collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])

    # get clinical data
    df_clinical_train = pd.DataFrame()
    if get_clinical:
        # define transforms
        simple_transforms = define_transforms_loadonly()
        # NOTE: built from the full training list (training and validation entries).
        simple_train_ds = Dataset(train_data_dict, transform=simple_transforms)
        simple_train_loader = DataLoader(simple_train_ds, batch_size=config['BATCH_SIZE'], collate_fn=list_data_collate, shuffle=False, num_workers=config['NUM_WORKERS'])

        # compute tumor ratio within liver
        df_clinical_train['patient_id'] = [p["patient_id"] for p in train_data_dict]
        ratios_train = []
        for batch_data in simple_train_loader:
            labels = batch_data["mask"]
            ratio = torch.sum(labels == 2, dim=(1, 2, 3, 4)) / torch.sum(labels > 0, dim=(1, 2, 3, 4))
            # Keep the ratio of every sample of the batch, not only the first one;
            # otherwise the column is shorter than the patient list when BATCH_SIZE > 1.
            ratios_train.extend(ratio.cpu().numpy().reshape(-1))
        df_clinical_train['tumor_ratio'] = ratios_train

        # get clinical features
        info = prepare_clinical_data(config['CLINICAL_DATA_FILE'], config['CLINICAL_PREDICTORS'])
        df_clinical_train = pd.merge(df_clinical_train, info, on='patient_id', how="left")
        # numeric_only: patient_id is a string column, and pandas >= 2.0 raises when asked
        # for the median of a non-numeric column.
        df_clinical_train.fillna(df_clinical_train.median(numeric_only=True), inplace=True)
        df_clinical_train.set_index("patient_id", inplace=True)

    # visualize the data loader for one image to ensure correct formatting
    print("Example data transformations:")
    # NOTE(review): retries until the last mask channel of the drawn sample is non-empty;
    # this never terminates if the first training entry has no such voxels.
    while True:
        sample = preprocessing_transforms_train(train_data_dict[0])
        if isinstance(sample, list):  # depending on preprocessing, one sample may be [sample] or sample
            sample = sample[0]
        if torch.sum(sample['mask'][-1]) == 0:
            continue
        print(f"   image shape: {sample['image'].shape}")
        print(f"   mask shape: {sample['mask'].shape}")
        print(f"   mask values: {np.unique(sample['mask'])}")
        print(f"   image min max: {np.min(sample['image']), np.max(sample['image'])}")
        visualize_patient(sample['image'], sample['mask'], n_slices=3, z_dim_last=is_3d_model, mask_channel=-1)
        break

    temp = monai.utils.first(test_loader)
    print("Test loader shapes:", temp['image'].shape, temp['mask'].shape)

    return train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train
|
240 |
+
|
241 |
+
|
utils/data_transforms.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import monai
|
2 |
+
import cv2
|
3 |
+
from monai.transforms import MapTransform
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import morphsnakes as ms
|
8 |
+
import monai
|
9 |
+
import nrrd
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
from monai.transforms import (
|
12 |
+
Activations, AsDiscreteD, AsDiscrete, Compose, CastToTypeD, RandSpatialCropd,
|
13 |
+
ToTensorD, CropForegroundD, Resized, GaussianSmoothD,
|
14 |
+
LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
|
15 |
+
RandAffineD, ResizeWithPadOrCropd, ToTensor,
|
16 |
+
FillHoles, KeepLargestConnectedComponent, HistogramNormalizeD, NormalizeIntensityD
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def define_transforms_loadonly():
    """
    Build the minimal pipeline that loads a mask without any image preprocessing.

    Only the "mask" entry of a sample is read (NRRD reader, channel first); its labels
    are then converted with ``keep_classes=["liver", "tumor"]`` and the result is turned
    into a tensor.

    Returns:
        Compose: the three-step transform chain.
    """
    load_mask = LoadImageD(keys=["mask"], reader="NrrdReader", ensure_channel_first=True)
    relabel_mask = ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"])
    return Compose([load_mask, relabel_mask, ToTensor()])
|
28 |
+
|
29 |
+
|
30 |
+
def define_post_processing(config):
    """
    Build the transform chain applied to raw model outputs.

    Args:
        config (dict): experiment configuration; only ``KEEP_CLASSES`` is read.

    Returns:
        Compose: sigmoid -> argmax (one-hot encoded when more than two classes are kept)
        -> largest connected component of label 1 -> hole filling of label 1 -> tensor.
    """
    num_classes = len(config['KEEP_CLASSES'])
    # One-hot encode the argmax result only in the multi-class case.
    onehot_classes = num_classes if num_classes > 2 else None

    return Compose([
        # Squash the logits into (0, 1) with a sigmoid
        Activations(sigmoid=True),
        # Pick the most likely class per voxel
        AsDiscrete(argmax=True, to_onehot=onehot_classes),
        # Keep only the largest connected component of label 1
        KeepLargestConnectedComponent(applied_labels=[1]),
        # Fill holes inside label 1
        FillHoles(applied_labels=[1]),
        ToTensor(),
    ])
|
45 |
+
|
46 |
+
def define_transforms(config):
    """
    Build the training and test preprocessing pipelines and the post-processing pipeline.

    Config keys read here: MASKNONLIVER, HU_RANGE, PREPROCESSING, KEEP_CLASSES,
    MODEL_NAME, DATA_AUGMENTATION and ROI_SIZE.

    The test pipeline loads and orients image and mask, optionally masks out non-liver
    voxels, windows the intensities to [0, 1], optionally applies CLAHE or Gaussian
    smoothing, converts the mask labels and, for non-3D models, moves the last axis to
    the front. The training pipeline additionally crops to the image foreground (only
    when MASKNONLIVER is set), optionally applies a random affine augmentation and
    always ends with a random crop / pad to ROI_SIZE.

    Returns:
        tuple: (preprocessing_transforms_train, preprocessing_transforms_test,
        postprocessing_transforms), all ``Compose`` objects.
    """
    transformations_test = [
        LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
        # Orient up and down
        OrientationD(keys=["image", "mask"], axcodes="PLI"),
        ToTensorD(keys=["image", "mask"]),
    ]

    if config['MASKNONLIVER']:
        transformations_test.extend(
            [
                MaskOutNonliver(mask_key="mask"),
                CropForegroundD(keys=["image", "mask"], source_key="image", allow_smaller=True),
            ]
        )

    transformations_test.append(
        # Windowing based on liver parameters
        ScaleIntensityRangeD(keys=["image"],
            a_min=config['HU_RANGE'][0],
            a_max=config['HU_RANGE'][1],
            b_min=0.0, b_max=1.0, clip=True
        )
    )

    if config['PREPROCESSING'] == "clihe":
        transformations_test.append(CLIHE(keys=["image"]))

    elif config['PREPROCESSING'] == "gaussian":
        transformations_test.append(GaussianSmoothD(keys=["image"], sigma=0.5))

    # convert labels to 0,1,2 instead of 0,1,2,3,4
    transformations_test.append(ConvertMaskValues(keys=["mask"], keep_classes=config['KEEP_CLASSES']))

    if len(config['KEEP_CLASSES']) > 2:  # NEEDED FOR MULTICLASS https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb
        transformations_test.append(AsDiscreteD(keys=["mask"], to_onehot=len(config['KEEP_CLASSES'])))  # (N, C, H, W) 2d; (1, C, H, W, Z)

    if "3D" not in config['MODEL_NAME']:
        transformations_test.append(TransposeD(keys=["image", "mask"], indices=(3,0,1,2)))

    # training transforms include data augmentation
    transformations_train = transformations_test.copy()
    if config['MASKNONLIVER']:
        # do not crop to liver foreground at test time; the step is removed by type so
        # that this keeps working if the list above is ever reordered
        transformations_test = [t for t in transformations_test if not isinstance(t, CropForegroundD)]

    if config['DATA_AUGMENTATION']:
        # One interpolation mode per key: bilinear for the image, nearest for the mask,
        # so that resampling cannot blend neighbouring label values.
        affine_modes = ("bilinear", "nearest")
        if "3D" in config["MODEL_NAME"]:
            transformations_train.append(
                RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
                            mode=affine_modes, spatial_size=config['ROI_SIZE'],
                            rotate_range=(0.15,0.15,0.15),
                            scale_range=(0.1,0.1,0.1)))
        else:
            transformations_train.append(
                RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
                            mode=affine_modes,
                            rotate_range=(0.15,0.15),
                            scale_range=(0.1,0.1)))

    transformations_train.extend(
        [
            RandSpatialCropd(keys=["image", "mask"], roi_size=config['ROI_SIZE'], random_size=False),
            ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=config['ROI_SIZE'], method="end", mode='constant', value=0)
        ]
    )

    postprocessing_transforms = define_post_processing(config)
    preprocessing_transforms_test = Compose(transformations_test)
    preprocessing_transforms_train = Compose(transformations_train)
    # fixed seed so that the random transforms are reproducible between runs
    preprocessing_transforms_train.set_random_state(seed=1)
    preprocessing_transforms_test.set_random_state(seed=1)

    return preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
class CLIHE(MapTransform):
    """
    Dictionary transform applying OpenCV's CLAHE (contrast limited adaptive histogram
    equalisation) to the entries named in ``keys``.

    OpenCV's CLAHE works on 8-bit images. A floating point input whose values all lie in
    [0, 1] is therefore stretched to 0..255 before equalisation and scaled back to
    [0, 1] (float32) afterwards; casting it to uint8 directly would collapse it to the
    two values 0 and 1. Any other input is cast to uint8 as is and returned as uint8.
    """

    def __init__(self, keys, allow_missing_keys=False):
        # MapTransform expects (keys, allow_missing_keys) - in that order.
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        for key in self.key_iterator(data):
            if len(data[key].shape) > 3:  # channel + three spatial axes
                data[key] = self.apply_clahe_3d(data[key])
            else:
                data[key] = self.apply_clahe_2d(data[key])
        return data

    @staticmethod
    def _to_uint8(image):
        """Return ``(uint8 array, rescaled)``; ``rescaled`` tells whether [0, 1] was stretched to 0..255."""
        image = np.asarray(image)
        unit_range = (
            image.size > 0
            and image.dtype.kind == "f"
            and image.min() >= 0.0
            and image.max() <= 1.0
        )
        if unit_range:
            return np.round(image * 255.0).astype(np.uint8), True
        return image.astype(np.uint8), False

    @staticmethod
    def _restore_range(image, rescaled):
        """Undo the stretching done by ``_to_uint8`` and wrap the result in a tensor."""
        if rescaled:
            image = image.astype(np.float32) / 255.0
        return torch.from_numpy(image)

    def apply_clahe_3d(self, image):
        """
        Equalise every slice along the last axis of channel 0, then smooth it with a 2x2 mean filter.

        Only channel 0 of ``image`` is processed; the result carries a single leading
        channel axis.
        """
        image, rescaled = self._to_uint8(image)
        clahe = cv2.createCLAHE(clipLimit=1, tileGridSize=(16,16))
        kernel = np.ones((2,2), np.float32)/4

        clahe_slices = []
        for slice_idx in range(image.shape[-1]):
            # Extract the current slice (OpenCV needs contiguous memory)
            slice_2d = np.ascontiguousarray(image[0, :, :, slice_idx])

            # Apply CLAHE to the current slice, followed by the mean filter
            slice_2d = clahe.apply(slice_2d)
            slice_2d = cv2.filter2D(slice_2d, -1, kernel)

            # Append the CLAHE enhanced slice to the list
            clahe_slices.append(slice_2d)

        # Stack the CLAHE enhanced slices along the slice axis to form the 3D image
        clahe_image = np.stack(clahe_slices, axis=-1)

        return self._restore_range(clahe_image[None,:], rescaled)

    def apply_clahe_2d(self, image):
        """
        Equalise the first entry along the leading axis of ``image``.

        NOTE(review): the result has no leading axis any more - confirm that callers
        expect this.
        """
        image, rescaled = self._to_uint8(image)

        clahe = cv2.createCLAHE(clipLimit=5)
        clahe_slice = clahe.apply(np.ascontiguousarray(image[0]))

        return self._restore_range(clahe_slice, rescaled)
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
class GaussianFilter(MapTransform):
    """
    Dictionary transform that smooths the entries named in ``keys`` with a 3x3 mean filter.

    NOTE: despite the class name the kernel is a uniform 3x3 average, not a Gaussian,
    and the ``apply_clahe_*`` method names are historical - no histogram equalisation
    happens here.
    """

    def __init__(self, keys, allow_missing_keys=False):
        # MapTransform expects (keys, allow_missing_keys) - in that order.
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        for key in self.key_iterator(data):
            if len(data[key].shape) > 3:  # channel + three spatial axes
                data[key] = self.apply_clahe_3d(data[key])
            else:
                data[key] = self.apply_clahe_2d(data[key])
        return data

    def apply_clahe_3d(self, image):
        """
        Smooth every slice along the last axis of channel 0.

        Only channel 0 of ``image`` is processed; the result carries a single leading
        channel axis.
        """
        image = np.asarray(image)
        kernel = np.ones((3,3), np.float32)/9

        smoothed_slices = []
        for slice_idx in range(image.shape[-1]):
            # Extract the current slice
            slice_2d = image[0, :, :, slice_idx]

            # Apply the mean filter to the current slice
            slice_2d = cv2.filter2D(slice_2d, -1, kernel)

            smoothed_slices.append(slice_2d)

        # Stack the smoothed slices along the slice axis to form the 3D image
        smoothed_image = np.stack(smoothed_slices, axis=-1)

        return torch.from_numpy(smoothed_image[None,:])

    def apply_clahe_2d(self, image):
        """
        Smooth ``image`` in a single ``cv2.filter2D`` call.

        NOTE(review): the whole array, including any leading channel axis, is handed to
        OpenCV - confirm the expected input layout.
        """
        image = np.asarray(image)

        kernel = np.ones((3,3), np.float32)/9
        slice_2d = cv2.filter2D(image, -1, kernel)

        return torch.from_numpy(slice_2d)
|
218 |
+
|
219 |
+
|
220 |
+
class Morphsnakes(MapTransform):
    """
    Refine a mask with the morphological Chan-Vese active contour.

    Implementation: https://github.com/pmneila/morphsnakes/blob/master/morphsnakes.py

    The transform always reads ``data['image']`` and ``data['mask']``; it takes no
    ``keys`` argument.
    """

    def __init__(self, allow_missing_keys=False):
        # MapTransform expects (keys, allow_missing_keys); register the two entries this
        # transform actually touches instead of passing the flag as the key.
        super().__init__(["image", "mask"], allow_missing_keys)

    def __call__(self, data):
        # Only refine when the last mask channel contains foreground to start from.
        if np.sum(data['mask'][-1]) > 0:
            # Two iterations on the first image channel, initialised with the last mask channel.
            res = ms.morphological_chan_vese(data['image'][0], iterations=2, init_level_set=data['mask'][-1])
            # NOTE(review): the whole mask entry is replaced by the level set returned
            # for a single channel - confirm downstream code expects that.
            data['mask'] = res
        return data
|
230 |
+
|
231 |
+
|
232 |
+
class MaskOutNonliver(MapTransform):
    """
    Overwrite every image voxel outside the liver region with -1000.

    A voxel is kept when its label in ``data[mask_key]`` is 1, 2 or 3 (liver, tumor or
    portal vein in the original labelling); labels <= 0 and >= 4 are masked out.
    ``data['image']`` is modified in place.

    Args:
        allow_missing_keys: forwarded to ``MapTransform``.
        mask_key: name of the entry that holds the label volume.
    """

    def __init__(self, allow_missing_keys=False, mask_key="mask"):
        # MapTransform expects (keys, allow_missing_keys); register the two entries this
        # transform actually touches instead of passing the flag as the key.
        super().__init__(["image", mask_key], allow_missing_keys)
        self.mask_key = mask_key

    def __call__(self, data):
        # mask out non-liver regions of an image
        # liver regions are liver, tumor, or portal vein
        if data[self.mask_key].shape != data['image'].shape:
            # Boolean indexing needs identical shapes. The sample is passed through
            # unchanged, but say so - otherwise the masking is skipped unnoticed.
            import warnings

            warnings.warn(
                "MaskOutNonliver skipped: '" + str(self.mask_key) + "' has shape "
                + str(tuple(data[self.mask_key].shape)) + " but the image has shape "
                + str(tuple(data['image'].shape)) + "."
            )
            return data
        data['image'][data[self.mask_key] >= 4] = -1000
        data['image'][data[self.mask_key] <= 0] = -1000
        return data
|
245 |
+
|
246 |
+
|
247 |
+
class ConvertMaskValues(MapTransform):
    """
    Collapse the dataset's mask labels onto the classes listed in ``keep_classes``.

    Original labels: 0 for normal region, 1 for liver, 2 for tumor mass, 3 for portal
    vein, and 4 for abdominal aorta.
    With the default ``keep_classes`` the converted labels are: 0 for normal region and
    abdominal aorta, 1 for liver and portal vein, 2 for tumor mass.

    The masks are modified in place.

    Args:
        keys: names of the mask entries to convert.
        allow_missing_keys: skip keys that are absent from a sample instead of raising.
        keep_classes: class names to preserve; recognised names are "liver", "tumor",
            "portal vein" and "abdominal aorta". The list is only read, never modified.
    """

    def __init__(self, keys, allow_missing_keys=False, keep_classes=["normal", "liver", "tumor"]):
        super().__init__(keys, allow_missing_keys)
        self.keep_classes = keep_classes

    def __call__(self, data):
        # key_iterator skips absent keys when allow_missing_keys is set and raises a
        # KeyError otherwise, so the mask is never indexed before it is known to exist.
        for key in self.key_iterator(data):
            data[key][data[key] > 4] = 4  # one patient had class label = 5, converted to 4
            # The order of the steps below matters: each one reads labels that may have
            # been rewritten by the step before it.
            if "liver" not in self.keep_classes:
                data[key][data[key] == 1] = 0
            if "tumor" not in self.keep_classes:
                data[key][data[key] == 2] = 1
            if "portal vein" not in self.keep_classes:
                data[key][data[key] == 3] = 1
            if "abdominal aorta" not in self.keep_classes:
                data[key][data[key] >= 4] = 0
        return data
|
utils/inference.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from monai.transforms import (
|
4 |
+
Activations, AsDiscreteD, AsDiscrete, Compose, ToTensorD,
|
5 |
+
GaussianSmoothD, LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
|
6 |
+
ToTensor, FillHoles, KeepLargestConnectedComponent, NormalizeIntensityD
|
7 |
+
)
|
8 |
+
from nrrd import read
|
9 |
+
from visualization import visualize_results
|
10 |
+
from data_preparation import get_patient_dictionaries
|
11 |
+
from monai.data import Dataset, DataLoader
|
12 |
+
import os
|
13 |
+
from data_transforms import ConvertMaskValues, MaskOutNonliver
|
14 |
+
from pipeline import build_model, evaluate
|
15 |
+
|
16 |
+
def run_sequential_inference(txt_file, config_liver, config_tumor, eval_metrics, output_dir, only_tumor=False, export=True):
    """
    Run the liver model followed by the tumor model on the patients listed in ``txt_file``.

    Stage 1 (skipped when ``only_tumor`` is true) segments the liver and writes
    ``<output_dir>/<patient_id>/liver_segmentation.nrrd`` for every patient.
    Stage 2 loads that file as "pred_liver", masks out everything outside the predicted
    liver and segments the tumor. With ``only_tumor=True`` the liver segmentation files
    therefore have to exist already.

    Args:
        txt_file: text file with comma separated patient ids.
        config_liver (dict): configuration of the liver model; keys read here: DATA_DIR,
            HU_RANGE, BATCH_SIZE, NUM_WORKERS.
        config_tumor (dict): configuration of the tumor model; keys read here: HU_RANGE,
            BATCH_SIZE, NUM_WORKERS.
        eval_metrics: forwarded unchanged to ``evaluate``.
        output_dir: folder receiving one sub-folder per patient.
        only_tumor (bool): skip stage 1.
        export (bool): write the tumor segmentations to disk. The liver segmentations
            are always written because stage 2 reads them back.

    Returns:
        None. Metrics are printed and one example is visualised.
    """

    def custom_collate_fn(batch):
        """Crop each volume to at most 512 along its 2nd and 3rd axes and stack the batch."""
        # Extract images and masks from the batch, ensure image and mask same size
        images = [sample["image"][:,:512,:512,:] for sample in batch]
        masks = [sample["mask"][:,:512,:512,:] for sample in batch]

        # Stack images and masks along the first dimension
        try:
            concatenated_images = torch.stack(images, dim=0)
            concatenated_masks = torch.stack(masks, dim=0)
        except Exception as e:
            # Report every shape; indexing element [1] directly would itself fail for a
            # batch that holds a single sample.
            print("WARNING: not all images/masks are 512 by 512. Please check. ",
                  [tuple(image.shape) for image in images],
                  [tuple(mask.shape) for mask in masks])
            return None, None

        # Return stacked images and masks as tensors
        # NOTE(review): only the liver prediction of the LAST sample is forwarded, and it
        # is not stacked - this is only adequate for a batch size of 1.
        last_sample = batch[-1]
        if "pred_liver" in last_sample.keys():
            return {"image": concatenated_images, "mask": concatenated_masks, "pred_liver": last_sample["pred_liver"]}
        return {"image": concatenated_images, "mask": concatenated_masks}

    ### Model preparation
    print("")
    print("Loading models....")
    liver_model = build_model(config_liver)
    tumor_model = build_model(config_tumor)

    #### Data preparation
    print("")
    print("Loading test data....")
    test_data_dict = get_patient_dictionaries(txt_file=txt_file, data_dir=config_liver['DATA_DIR'])
    print("   Number of test patients:", len(test_data_dict))

    # assign output file names and paths
    export_file_metadata = []
    os.makedirs(output_dir, exist_ok=True)
    for patient_dict in test_data_dict:
        patient_folder = os.path.join(output_dir, patient_dict['patient_id'])
        os.makedirs(patient_folder, exist_ok=True)
        patient_dict['pred_liver'] = os.path.join(patient_folder, "liver_segmentation.nrrd")
        patient_dict['pred_tumor'] = os.path.join(patient_folder, "tumor_segmentation.nrrd")
        # keep the NRRD header of the input image so it can be reused for the exports
        export_file_metadata.append(read(patient_dict['image'])[1])

    #### Liver segmentation
    # define liver data loading and preprocessing
    if not only_tumor:
        print("")
        print("Producing liver segmentations....")
        liver_preprocessing = Compose([
            LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
            OrientationD(keys=["image", "mask"], axcodes="PLI"),
            ScaleIntensityRangeD(keys=["image"],
                a_min=config_liver['HU_RANGE'][0],
                a_max=config_liver['HU_RANGE'][1],
                b_min=0.0, b_max=1.0, clip=True
            ),
            ConvertMaskValues(keys=["mask"], keep_classes=["liver"]),
            ToTensorD(keys=["image", "mask"])
        ])

        liver_postprocessing = Compose([
            Activations(sigmoid=True),
            AsDiscrete(argmax=True, to_onehot=None),
            KeepLargestConnectedComponent(applied_labels=[1]),
            FillHoles(applied_labels=[1]),
            ToTensor()
        ])
        test_ds_liver = Dataset(test_data_dict, transform=liver_preprocessing)
        test_ds_liver = DataLoader(test_ds_liver, batch_size=config_liver['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config_liver['NUM_WORKERS'])

        # produce liver model results
        test_metrics_liver, sample_output_liver = evaluate(liver_model, test_ds_liver, eval_metrics, config_liver, postprocessing_transforms=liver_postprocessing, export_filenames = [p['pred_liver'] for p in test_data_dict], export_file_metadata=export_file_metadata)

        print("")
        print("==============================")
        print("Liver segmentation test performance ....")
        for key, value in test_metrics_liver.items():
            print(f'   {key.replace("_avg", "_liver")}: {value:.3f}')
        print("==============================")

    ##### Tumor segmentation
    print("")
    print("Producing tumor segmentations....")

    # define tumor loading and preprocessing
    tumor_preprocessing = Compose([
        LoadImageD(keys=["image", "mask", "pred_liver"], reader="NrrdReader", ensure_channel_first=True),
        OrientationD(keys=["image", "mask"], axcodes="PLI"),
        MaskOutNonliver(mask_key="pred_liver"),  # note that liver's predicted segmentation is used to crop to the liver region
        ScaleIntensityRangeD(keys=["image"],
            a_min=config_tumor['HU_RANGE'][0],
            a_max=config_tumor['HU_RANGE'][1],
            b_min=0.0, b_max=1.0, clip=True
        ),
        ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]),  # format mask for measuring test performance
        AsDiscreteD(keys=["mask"], to_onehot=3),  # format mask for measuring test performance
        ToTensorD(keys=["image", "mask", "pred_liver"])
    ])

    tumor_postprocessing = Compose([
        Activations(sigmoid=True),
        AsDiscrete(argmax=True, to_onehot=3),
        ToTensor()
    ])

    test_ds_tumor = Dataset(test_data_dict, transform=tumor_preprocessing)
    test_ds_tumor = DataLoader(test_ds_tumor, batch_size=config_tumor['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config_tumor['NUM_WORKERS'])

    test_metrics_tumor, sample_output_tumor = evaluate(tumor_model, test_ds_tumor, eval_metrics, config_tumor, tumor_postprocessing, use_liver_seg = True, export_filenames = [p['pred_tumor'] for p in test_data_dict] if export else [], export_file_metadata=export_file_metadata)

    print("")
    print("==============================")
    print("Tumor segmentation test performance ....")
    for key, value in test_metrics_tumor.items():
        if "class2" in key:
            print(f'   {key.replace("_class2", "_tumor")}: {value:.3f}')
    print("==============================")
    print("")

    #### Visualization

    # combine liver and tumor segmentations into one segmentation output
    if not only_tumor:
        sample_output_tumor[2][0][1] = sample_output_liver[2][0][0]

    # visualization
    print("")
    if not only_tumor:
        visualize_results(sample_output_liver[0][0].cpu(), sample_output_tumor[1][0].cpu(), sample_output_tumor[2][0].cpu(), n_slices=5, title="")
    else:
        visualize_results(sample_output_tumor[0][0].cpu(), sample_output_tumor[1][0].cpu(), sample_output_tumor[2][0].cpu(), n_slices=5, title="")

    return
|
154 |
+
|
155 |
+
|
utils/loss.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from collections.abc import Callable, Sequence
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.nn.modules.loss import _Loss
|
10 |
+
|
11 |
+
from monai.losses.dice import DiceLoss
|
12 |
+
from monai.losses.focal_loss import FocalLoss
|
13 |
+
from monai.networks import one_hot
|
14 |
+
from monai.utils import DiceCEReduction, LossReduction, Weight, deprecated_arg, look_up_option, pytorch_after
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
##### Adapted from Monai DiceFocalLoss
|
19 |
+
class WeaklyDiceFocalLoss(_Loss):
|
20 |
+
"""
|
21 |
+
Compute Dice loss, Focal Loss, and weakly supervised loss from clinical predictor, and return the weighted sum of these three losses.
|
22 |
+
|
23 |
+
``gamma`` and ``lambda_focal`` are only used for the focal loss.
|
24 |
+
``include_background``, ``weight`` and ``reduction`` are used for both losses
|
25 |
+
and other parameters are only used for dice loss.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
include_background: bool = True,
|
31 |
+
to_onehot_y: bool = False,
|
32 |
+
sigmoid: bool = False,
|
33 |
+
softmax: bool = False,
|
34 |
+
other_act: Callable | None = None,
|
35 |
+
squared_pred: bool = False,
|
36 |
+
jaccard: bool = False,
|
37 |
+
reduction: str = "mean",
|
38 |
+
smooth_nr: float = 1e-5,
|
39 |
+
smooth_dr: float = 1e-5,
|
40 |
+
batch: bool = False,
|
41 |
+
gamma: float = 2.0,
|
42 |
+
focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
|
43 |
+
weight: Sequence[float] | float | int | torch.Tensor | None = None,
|
44 |
+
lambda_dice: float = 1.0,
|
45 |
+
lambda_focal: float = 1.0,
|
46 |
+
lambda_weak: float = 1.0,
|
47 |
+
) -> None:
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
include_background: if False channel index 0 (background category) is excluded from the calculation.
|
51 |
+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
|
52 |
+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
|
53 |
+
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
|
54 |
+
don't need to specify activation function for `FocalLoss`.
|
55 |
+
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
|
56 |
+
don't need to specify activation function for `FocalLoss`.
|
57 |
+
other_act: callable function to execute other activation layers, Defaults to ``None``.
|
58 |
+
for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
|
59 |
+
squared_pred: use squared versions of targets and predictions in the denominator or not.
|
60 |
+
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
|
61 |
+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
|
62 |
+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
|
63 |
+
|
64 |
+
- ``"none"``: no reduction will be applied.
|
65 |
+
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
|
66 |
+
- ``"sum"``: the output will be summed.
|
67 |
+
|
68 |
+
smooth_nr: a small constant added to the numerator to avoid zero.
|
69 |
+
smooth_dr: a small constant added to the denominator to avoid nan.
|
70 |
+
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
|
71 |
+
Defaults to False, a Dice loss value is computed independently from each item in the batch
|
72 |
+
before any `reduction`.
|
73 |
+
gamma: value of the exponent gamma in the definition of the Focal loss.
|
74 |
+
weight: weights to apply to the voxels of each class. If None no weights are applied.
|
75 |
+
The input can be a single value (same weight for all classes), a sequence of values (the length
|
76 |
+
of the sequence should be the same as the number of classes).
|
77 |
+
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
|
78 |
+
Defaults to 1.0.
|
79 |
+
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
|
80 |
+
Defaults to 1.0.
|
81 |
+
lambda_weak: the trade-off weight value for weakly supervised loss. The value should be no less than 0.0
|
82 |
+
Defaults to 0.2.
|
83 |
+
|
84 |
+
"""
|
85 |
+
super().__init__()
|
86 |
+
weight = focal_weight if focal_weight is not None else weight
|
87 |
+
self.dice = DiceLoss(
|
88 |
+
include_background=include_background,
|
89 |
+
to_onehot_y=False,
|
90 |
+
sigmoid=sigmoid,
|
91 |
+
softmax=softmax,
|
92 |
+
other_act=other_act,
|
93 |
+
squared_pred=squared_pred,
|
94 |
+
jaccard=jaccard,
|
95 |
+
reduction=reduction,
|
96 |
+
smooth_nr=smooth_nr,
|
97 |
+
smooth_dr=smooth_dr,
|
98 |
+
batch=batch,
|
99 |
+
weight=weight,
|
100 |
+
)
|
101 |
+
self.focal = FocalLoss(
|
102 |
+
include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
|
103 |
+
)
|
104 |
+
if lambda_dice < 0.0:
|
105 |
+
raise ValueError("lambda_dice should be no less than 0.0.")
|
106 |
+
if lambda_focal < 0.0:
|
107 |
+
raise ValueError("lambda_focal should be no less than 0.0.")
|
108 |
+
if lambda_weak < 0.0:
|
109 |
+
raise ValueError("lambda_weak should be no less than 0.0.")
|
110 |
+
self.lambda_dice = lambda_dice
|
111 |
+
self.lambda_focal = lambda_focal
|
112 |
+
self.to_onehot_y = to_onehot_y
|
113 |
+
self.lambda_weak = lambda_weak
|
114 |
+
|
115 |
+
|
116 |
+
def compute_weakly_supervised_loss(self, input: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
|
117 |
+
# compute ratio of tumor/liver in the predicted mask
|
118 |
+
tumor_pixels = torch.sum(input[:, -1, ...], dim=(1, 2, 3))
|
119 |
+
liver_pixels = torch.sum(input[:, -2, ...], dim=(1, 2, 3)) + tumor_pixels
|
120 |
+
predicted_ratio = tumor_pixels / liver_pixels
|
121 |
+
loss = torch.mean((predicted_ratio - weaktarget) ** 2)
|
122 |
+
return loss
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
def forward(self, input: torch.Tensor, target: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
|
127 |
+
"""
|
128 |
+
Args:
|
129 |
+
input: the shape should be BNH[WD]. The input should be the original logits
|
130 |
+
due to the restriction of ``monai.losses.FocalLoss``.
|
131 |
+
target: the shape should be BNH[WD] or B1H[WD].
|
132 |
+
|
133 |
+
Raises:
|
134 |
+
ValueError: When number of dimensions for input and target are different.
|
135 |
+
ValueError: When number of channels for target is neither 1 nor the same as input.
|
136 |
+
|
137 |
+
"""
|
138 |
+
if len(input.shape) != len(target.shape):
|
139 |
+
raise ValueError(
|
140 |
+
"the number of dimensions for input and target should be the same, "
|
141 |
+
f"got shape {input.shape} and {target.shape}."
|
142 |
+
)
|
143 |
+
if self.to_onehot_y:
|
144 |
+
n_pred_ch = input.shape[1]
|
145 |
+
if n_pred_ch == 1:
|
146 |
+
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
|
147 |
+
else:
|
148 |
+
target = one_hot(target, num_classes=n_pred_ch)
|
149 |
+
dice_loss = self.dice(input, target)
|
150 |
+
focal_loss = self.focal(input, target)
|
151 |
+
weak_loss = self.compute_weakly_supervised_loss(input, weaktarget)
|
152 |
+
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss + self.lambda_weak * weak_loss
|
153 |
+
return total_loss
|
utils/models.py
ADDED
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
# 2D: net = UNet2D(1,2,pab_channels=64,use_batchnorm=True)
|
8 |
+
# 3D: net = UNet3D(1,2,pab_channels=32,use_batchnorm=True)
|
9 |
+
|
10 |
+
class _NonLocalBlockND(nn.Module):
    """Non-local (self-attention style) residual block for 1-D, 2-D or 3-D feature maps.

    Three 1x1 convolutions (``theta``, ``phi``, ``g``) project the input to
    ``inter_channels``; the pairwise product of ``theta`` and ``phi`` is scaled by the
    number of ``phi`` positions and used to aggregate ``g``. The result is projected back
    to ``in_channels`` by ``W`` and added to the input. ``W`` is zero-initialised, so a
    freshly constructed block returns its input unchanged.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        """
        :param in_channels: channels of the input (and output) feature map
        :param inter_channels: bottleneck width; defaults to ``in_channels // 2`` (at least 1)
        :param dimension: 1, 2 or 3 -- selects Conv/MaxPool/BatchNorm of that dimensionality
        :param sub_sample: if True, max-pool the ``g`` and ``phi`` projections
        :param bn_layer: if True, follow the output projection ``W`` with batch norm
        """
        super().__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels

        if inter_channels is None:
            # Halve the channels, falling back to 1 when that would give 0.
            inter_channels = in_channels // 2 or 1
        self.inter_channels = inter_channels

        if dimension == 3:
            conv_nd, bn = nn.Conv3d, nn.BatchNorm3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
        elif dimension == 2:
            conv_nd, bn = nn.Conv2d, nn.BatchNorm2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
        else:
            conv_nd, bn = nn.Conv1d, nn.BatchNorm1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))

        def _pointwise(c_in, c_out):
            # 1x1 convolution used for every projection in this block.
            return conv_nd(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=1, padding=0)

        # Creation order (g, W, theta, phi) is kept fixed: it determines both the
        # state_dict layout and the order in which random initialisation is drawn.
        self.g = _pointwise(self.in_channels, self.inter_channels)

        if bn_layer:
            self.W = nn.Sequential(
                _pointwise(self.inter_channels, self.in_channels),
                bn(self.in_channels),
            )
            zero_init_target = self.W[1]
        else:
            self.W = _pointwise(self.inter_channels, self.in_channels)
            zero_init_target = self.W
        # Zero the last layer of W so the residual branch starts as a no-op.
        nn.init.constant_(zero_init_target.weight, 0)
        nn.init.constant_(zero_init_target.bias, 0)

        self.theta = _pointwise(self.in_channels, self.inter_channels)
        self.phi = _pointwise(self.in_channels, self.inter_channels)

        if sub_sample:
            # The same pooling module instance is shared by both branches.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return: tensor of the same shape as ``x``
        '''
        n_batch = x.size(0)
        n_inter = self.inter_channels

        values = self.g(x).view(n_batch, n_inter, -1).permute(0, 2, 1)
        queries = self.theta(x).view(n_batch, n_inter, -1).permute(0, 2, 1)
        keys = self.phi(x).view(n_batch, n_inter, -1)

        affinity = torch.matmul(queries, keys)
        # Normalise by the number of key positions rather than with a softmax.
        affinity = affinity / affinity.size(-1)

        aggregated = torch.matmul(affinity, values).permute(0, 2, 1).contiguous()
        aggregated = aggregated.view(n_batch, n_inter, *x.size()[2:])

        return self.W(aggregated) + x
|
92 |
+
|
93 |
+
|
94 |
+
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block fixed to ``dimension=1`` (Conv1d / MaxPool1d / BatchNorm1d)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
|
100 |
+
|
101 |
+
|
102 |
+
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block fixed to ``dimension=2`` (Conv2d / MaxPool2d / BatchNorm2d)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
|
108 |
+
|
109 |
+
|
110 |
+
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block fixed to ``dimension=3`` (Conv3d / MaxPool3d / BatchNorm3d)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
class Conv2dReLU(nn.Sequential):
    """2-D convolution followed by an optional normalisation layer and an activation.

    The resulting ``nn.Sequential`` always has three entries ``(conv, bn, relu)``:

    - ``use_batchnorm`` truthy (not ``"inplace"``): ``Conv2d`` without bias, ``BatchNorm2d``, ``ReLU``.
    - ``use_batchnorm == "inplace"``: ``Conv2d`` without bias, ``InPlaceABN`` (which carries its
      own activation) and ``Identity``. Requires the optional ``inplace_abn`` package.
    - ``use_batchnorm`` falsy: ``Conv2d`` with bias, ``Identity``, ``ReLU``.

    Raises:
        RuntimeError: if ``use_batchnorm == "inplace"`` and ``inplace_abn`` cannot be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        if use_batchnorm == "inplace":
            # ``InPlaceABN`` is an optional dependency that this module does not import at
            # top level; resolve it lazily so a missing package surfaces as the intended
            # RuntimeError instead of a NameError.
            try:
                from inplace_abn import InPlaceABN
            except ImportError as err:
                raise RuntimeError(
                    "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                    + "To install see: https://github.com/mapillary/inplace_abn"
                ) from err

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            # A following normalisation layer has its own shift, so the conv bias is dropped.
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif use_batchnorm:
            bn = nn.BatchNorm2d(out_channels)
        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)
|
157 |
+
|
158 |
+
class Conv3dReLU(nn.Sequential):
    """3-D convolution followed by an optional normalisation layer and an activation.

    The resulting ``nn.Sequential`` always has three entries ``(conv, bn, relu)``:

    - ``use_batchnorm`` truthy (not ``"inplace"``): ``Conv3d`` without bias, ``BatchNorm3d``, ``ReLU``.
    - ``use_batchnorm == "inplace"``: ``Conv3d`` without bias, ``InPlaceABN`` (which carries its
      own activation) and ``Identity``. Requires the optional ``inplace_abn`` package.
    - ``use_batchnorm`` falsy: ``Conv3d`` with bias, ``Identity``, ``ReLU``.

    Raises:
        RuntimeError: if ``use_batchnorm == "inplace"`` and ``inplace_abn`` cannot be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        if use_batchnorm == "inplace":
            # ``InPlaceABN`` is an optional dependency that this module does not import at
            # top level; resolve it lazily so a missing package surfaces as the intended
            # RuntimeError instead of a NameError.
            try:
                from inplace_abn import InPlaceABN
            except ImportError as err:
                raise RuntimeError(
                    "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                    + "To install see: https://github.com/mapillary/inplace_abn"
                ) from err

        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            # A following normalisation layer has its own shift, so the conv bias is dropped.
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif use_batchnorm:
            bn = nn.BatchNorm3d(out_channels)
        else:
            bn = nn.Identity()

        super(Conv3dReLU, self).__init__(conv, bn, relu)
|
196 |
+
class PAB2D(nn.Module):
    """Position-wise attention block for 2-D feature maps.

    Two 1x1 convolutions produce ``pab_channels``-wide maps whose product gives an
    (H*W) x (H*W) position-affinity map. The softmax is taken over all entries of that
    map jointly (it is flattened per batch item first). The map then weights a 3x3-conv
    projection of the input; the result is added to the input and passed through a final
    3x3 convolution. ``out_channels`` is accepted but not used: the output keeps
    ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, pab_channels=64):
        super().__init__()
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        # 1x1 convolutions generating the two attention feature maps.
        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        n_batch, height, width = x.size()[0], x.size()[2], x.size()[3]
        n_positions = height * width

        top = self.top_conv(x).flatten(2)                         # (B, P, HW)
        center = self.center_conv(x).flatten(2).transpose(1, 2)   # (B, HW, P)
        bottom = self.bottom_conv(x).flatten(2).transpose(1, 2)   # (B, HW, C)

        affinity = torch.matmul(center, top)
        # Softmax across the whole flattened affinity map of each batch item.
        affinity = self.map_softmax(affinity.view(n_batch, -1))
        affinity = affinity.view(n_batch, n_positions, n_positions)

        # The (B, HW, C) product is reshaped directly to (B, C, H, W).
        attended = torch.matmul(affinity, bottom).reshape(n_batch, self.in_channels, height, width)

        return self.out_conv(x + attended)
|
228 |
+
|
229 |
+
class MFAB2D(nn.Module):
    """Multi-scale fusion attention block (2-D decoder stage).

    MFAB is a modified version of SE (squeeze-and-excitation) blocks: one gate is computed
    from the upsampled high-level input and one from the skip connection. The two gates
    are summed, applied to the high-level features, and the result is concatenated with
    the skip features before two 3x3 conv blocks.
    """

    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        super().__init__()
        squeezed_channels = skip_channels // reduction

        def _channel_gate():
            # Global average pool -> bottleneck 1x1 convs -> sigmoid channel weights.
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(skip_channels, squeezed_channels, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(squeezed_channels, skip_channels, 1),
                nn.Sigmoid(),
            )

        # Bring the high-level features to ``skip_channels`` before fusing.
        self.hl_conv = nn.Sequential(
            Conv2dReLU(in_channels, in_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            Conv2dReLU(in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm),
        )
        self.SE_ll = _channel_gate()
        self.SE_hl = _channel_gate()
        # Input is the concatenation of gated high-level and skip features, both
        # ``skip_channels`` wide.
        self.conv1 = Conv2dReLU(
            skip_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2, fuse it with ``skip`` via channel attention, then refine."""
        x = F.interpolate(self.hl_conv(x), scale_factor=2, mode="nearest")
        gate = self.SE_hl(x)
        if skip is not None:
            gate = gate + self.SE_ll(skip)
            x = torch.cat([x * gate, skip], dim=1)
        return self.conv2(self.conv1(x))
|
289 |
+
|
290 |
+
class PAB3D(nn.Module):
    """Position-wise attention block for 3-D feature maps.

    Two 1x1x1 convolutions produce ``pab_channels``-wide maps whose product gives an
    N x N position-affinity map, where N is the product of the three spatial sizes. The
    softmax is taken over all entries of that map jointly (it is flattened per batch item
    first). The map then weights a 3x3x3-conv projection of the input; the result is added
    to the input and passed through a final 3x3x3 convolution. ``out_channels`` is accepted
    but not used: the output keeps ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, pab_channels=64):
        super().__init__()
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        # 1x1x1 convolutions generating the two attention feature maps.
        self.top_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        n_batch = x.size()[0]
        h, w, d = x.size()[2], x.size()[3], x.size()[4]
        n_positions = h * w * d

        top = self.top_conv(x).flatten(2)                         # (B, P, N)
        center = self.center_conv(x).flatten(2).transpose(1, 2)   # (B, N, P)
        bottom = self.bottom_conv(x).flatten(2).transpose(1, 2)   # (B, N, C)

        affinity = torch.matmul(center, top)
        # Softmax across the whole flattened affinity map of each batch item.
        affinity = self.map_softmax(affinity.view(n_batch, -1))
        affinity = affinity.view(n_batch, n_positions, n_positions)

        # The (B, N, C) product is reshaped directly to (B, C, h, w, d).
        attended = torch.matmul(affinity, bottom).reshape(n_batch, self.in_channels, h, w, d)

        return self.out_conv(x + attended)
|
322 |
+
|
323 |
+
class MFAB3D(nn.Module):
    """Multi-scale fusion attention block (3-D decoder stage).

    MFAB is a modified version of SE (squeeze-and-excitation) blocks: one gate is computed
    from the upsampled high-level input and one from the skip connection. The two gates
    are summed, applied to the high-level features, and the result is concatenated with
    the skip features before two 3x3x3 conv blocks.
    """

    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        super().__init__()
        squeezed_channels = skip_channels // reduction

        def _channel_gate():
            # Global average pool -> bottleneck 1x1x1 convs -> sigmoid channel weights.
            return nn.Sequential(
                nn.AdaptiveAvgPool3d(1),
                nn.Conv3d(skip_channels, squeezed_channels, 1),
                nn.ReLU(inplace=True),
                nn.Conv3d(squeezed_channels, skip_channels, 1),
                nn.Sigmoid(),
            )

        # Bring the high-level features to ``skip_channels`` before fusing.
        self.hl_conv = nn.Sequential(
            Conv3dReLU(in_channels, in_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            Conv3dReLU(in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm),
        )
        self.SE_ll = _channel_gate()
        self.SE_hl = _channel_gate()
        # Input is the concatenation of gated high-level and skip features, both
        # ``skip_channels`` wide.
        self.conv1 = Conv3dReLU(
            skip_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2, fuse it with ``skip`` via channel attention, then refine."""
        x = F.interpolate(self.hl_conv(x), scale_factor=2, mode="nearest")
        gate = self.SE_hl(x)
        if skip is not None:
            gate = gate + self.SE_ll(skip)
            x = torch.cat([x * gate, skip], dim=1)
        return self.conv2(self.conv1(x))
|
383 |
+
|
384 |
+
class DoubleConv2D(nn.Module):
    """Two consecutive (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    ``mid_channels`` is the width between the two stages; when it is omitted (or falsy)
    it equals ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
|
402 |
+
|
403 |
+
class Down2D(nn.Module):
    """Encoder stage: 2x max-pool, then a non-local block, then a double conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = (
            nn.MaxPool2d(2),
            NONLocalBlock2D(in_channels),
            DoubleConv2D(in_channels, out_channels),
        )
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)
|
416 |
+
|
417 |
+
|
418 |
+
class Up2D(nn.Module):
    """Decoder stage: upsample ``x1`` by 2, pad it to ``x2``'s size, concatenate, double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            # The transposed convolution halves the channels while upsampling.
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv2D(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps the channel count, so the double conv reduces it instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv2D(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are (B, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.size()[2] - upsampled.size()[2]
        gap_w = x2.size()[3] - upsampled.size()[3]
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, upsampled], dim=1))
|
445 |
+
|
446 |
+
class OutConv2D(nn.Module):
    """Final 1x1 Conv2d mapping feature channels to ``out_channels`` output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected
|
453 |
+
|
454 |
+
class UNet2D(nn.Module):
    """2-D U-Net with non-local encoder blocks, a position attention bottleneck and MFAB decoder.

    The forward pass uses ``inc``/``down1-4`` as encoder, ``pab`` on the deepest features
    and ``mfab1-4`` as decoder. ``up1-4`` are constructed (and therefore part of the
    state dict) but are not called in ``forward``.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels (and of auxiliary classifier outputs).
        bilinear: forwarded to the ``Up2D`` blocks; also halves the width of ``down4``.
        pab_channels: base channel width of the network.
        use_batchnorm: forwarded to the MFAB blocks.
        aux_classifier: when not equal to ``False``, an auxiliary classification head is
            attached to the bottleneck and ``forward`` returns ``(output, aux)``.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        width = pab_channels
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv2D(n_channels, width)
        self.down1 = Down2D(width, 2 * width)
        self.down2 = Down2D(2 * width, 4 * width)
        self.down3 = Down2D(4 * width, 8 * width)
        self.down4 = Down2D(8 * width, 16 * width // factor)

        # Bottleneck attention.
        # NOTE(review): ``pab`` and ``mfab1`` expect 8 * width channels, which ``down4`` only
        # produces when ``bilinear`` is True -- confirm that bilinear=False is unsupported.
        self.pab = PAB2D(8 * width, 8 * width)

        # Plain U-Net decoder stages (not used by ``forward``).
        self.up1 = Up2D(16 * width, 8 * width // factor, bilinear)
        self.up2 = Up2D(8 * width, 4 * width // factor, bilinear)
        self.up3 = Up2D(4 * width, 2 * width // factor, bilinear)
        self.up4 = Up2D(2 * width, width, bilinear)

        # Attention decoder stages (used by ``forward``).
        self.mfab1 = MFAB2D(8 * width, 8 * width, 4 * width, use_batchnorm)
        self.mfab2 = MFAB2D(4 * width, 4 * width, 2 * width, use_batchnorm)
        self.mfab3 = MFAB2D(2 * width, 2 * width, width, use_batchnorm)
        self.mfab4 = MFAB2D(width, width, width, use_batchnorm)
        self.outc = OutConv2D(width, n_classes)

        # Equality (not truthiness) is kept on purpose so that existing callers see
        # exactly the same branch selection as before.
        if aux_classifier == False:
            self.aux = None
        else:
            # Auxiliary classification head on the bottleneck features.
            # An earlier variant (global average pool -> Linear(8*pab_channels, 16) ->
            # Linear(16, n_classes) with dropout and softmax) was replaced by the head below.
            # NOTE(review): Linear(24*24, ...) fixes the bottleneck spatial size to 24x24,
            # presumably a 384x384 input after four 2x poolings -- confirm.
            self.aux = nn.Sequential(
                NONLocalBlock2D(8 * width),
                nn.Conv2d(8 * width, 1, 1),
                nn.InstanceNorm2d(1),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(24*24, 16, bias=True),
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(16, n_classes, bias=True),
                nn.Softmax(1))

    def forward(self, x):
        """Return per-class softmax maps, plus the auxiliary output when the aux head exists."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.pab(self.down4(enc4))

        dec = self.mfab1(bottleneck, enc4)
        dec = self.mfab2(dec, enc3)
        dec = self.mfab3(dec, enc2)
        dec = self.mfab4(dec, enc1)

        # Softmax over the class dimension: the returned "logits" are probabilities.
        logits = F.softmax(self.outc(dec), 1)

        if self.aux is None:
            return logits
        return logits, self.aux(bottleneck)
|
525 |
+
|
526 |
+
|
527 |
+
|
528 |
+
|
529 |
+
class DoubleConv3D(nn.Module):
    """Two consecutive (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

    ``mid_channels`` is the width between the two stages; when it is omitted (or falsy)
    it equals ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
|
547 |
+
|
548 |
+
class Down3D(nn.Module):
    """Encoder stage: 2x max-pool followed by a double conv.

    Unlike ``Down2D`` there is no non-local block inside this stage (it is disabled here).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = (
            nn.MaxPool3d(2),
            # NONLocalBlock3D(in_channels),
            DoubleConv3D(in_channels, out_channels),
        )
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)
|
561 |
+
|
562 |
+
class Up3D(nn.Module):
    """Decoder stage for volumes: upsample ``x1`` by 2, pad to ``x2``'s size, concatenate, double conv.

    Args:
        in_channels: channels of the concatenated input to the double conv.
        out_channels: channels produced by the stage.
        bilinear: if True, upsample by interpolation (trilinear, since the tensors are 5-D)
            and let the double conv reduce the channels; otherwise use a transposed
            convolution that halves the channels. The parameter keeps its 2-D name for
            interface compatibility with ``Up2D``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            # 'bilinear' interpolation only accepts 4-D input; 5-D volumes need 'trilinear'.
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv3D(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv3D(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Args:
            x1: lower-resolution features to upsample, 5-D (batch, channel, 3 spatial dims).
            x2: skip-connection features whose spatial size is the target size.

        Returns:
            Output of the double conv applied to ``cat([x2, padded upsampled x1], dim=1)``.
        """
        x1 = self.up(x1)
        # Pad every spatial axis so x1 matches x2. F.pad expects (before, after) pairs
        # starting from the LAST dimension, hence the reversed axis order.
        pad = []
        for axis in (4, 3, 2):
            diff = x2.size()[axis] - x1.size()[axis]
            pad.extend([diff // 2, diff - diff // 2])
        x1 = F.pad(x1, pad)
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
|
589 |
+
|
590 |
+
class OutConv3D(nn.Module):
    """Final 1x1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        """Create the single ``Conv3d`` with kernel size 1."""
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected
|
597 |
+
|
598 |
+
class UNet3D(nn.Module):
    """3D U-Net-style network with a ``PAB3D`` bottleneck and ``MFAB3D`` decoder.

    ``forward`` returns softmax probabilities over the class dimension and,
    when the auxiliary head is enabled, a second output computed from the
    bottleneck features.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False):
        """Build the encoder, bottleneck, decoder and optional auxiliary head.

        Args:
            n_channels: number of input channels.
            n_classes: number of output classes.
            bilinear: forwarded to the ``Up3D`` blocks; also halves the width
                of the deepest encoder stage.
            pab_channels: base channel width; deeper stages use multiples.
            use_batchnorm: forwarded to the ``MFAB3D`` decoder blocks.
            aux_classifier: any truthy value adds the auxiliary head.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        c = pab_channels
        self.inc = DoubleConv3D(n_channels, c)
        self.down1 = Down3D(c, 2 * c)
        # nnblock2 and up1..up4 are not used by forward(); they stay registered
        # because removing them would change the module's state_dict keys.
        self.nnblock2 = NONLocalBlock3D(2 * c)
        self.down2 = Down3D(2 * c, 4 * c)
        self.down3 = Down3D(4 * c, 8 * c)
        factor = 2 if bilinear else 1
        self.down4 = Down3D(8 * c, 16 * c // factor)
        # NOTE(review): PAB3D is built for 8*c input channels, which matches
        # down4 only when bilinear is true (factor == 2) -- confirm.
        self.pab = PAB3D(8 * c, 8 * c)
        self.up1 = Up3D(16 * c, 8 * c // factor, bilinear)
        self.up2 = Up3D(8 * c, 4 * c // factor, bilinear)
        self.up3 = Up3D(4 * c, 2 * c // factor, bilinear)
        self.up4 = Up3D(2 * c, c, bilinear)

        self.mfab1 = MFAB3D(8 * c, 8 * c, 4 * c, use_batchnorm)
        self.mfab2 = MFAB3D(4 * c, 4 * c, 2 * c, use_batchnorm)
        self.mfab3 = MFAB3D(2 * c, 2 * c, c, use_batchnorm)
        self.mfab4 = MFAB3D(c, c, c, use_batchnorm)
        self.outc = OutConv3D(c, n_classes)

        if not aux_classifier:
            self.aux = None
        else:
            # Auxiliary classification head on the bottleneck features.
            # The first Linear layer hard-codes 16*16*2 = 512 flattened
            # elements per sample, so the head only fits inputs whose
            # bottleneck volume has exactly that many voxels.
            self.aux = nn.Sequential(nn.Conv3d(8 * c, 1, 1),
                                     nn.InstanceNorm3d(1),
                                     nn.ReLU(),
                                     nn.Flatten(),
                                     nn.Linear(16*16*2, 16, bias=True),
                                     nn.Dropout(p=0.2, inplace=True),
                                     nn.Linear(16, n_classes, bias=True),
                                     nn.Softmax(1))

    def forward(self, x):
        """Return class probabilities, plus the auxiliary output if enabled.

        Returns:
            ``probs`` when ``self.aux`` is ``None``; otherwise the tuple
            ``(probs, aux_output)``.
        """
        # Encoder: every stage is kept for the decoder's skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.pab(self.down4(x4))

        # Decoder: MFAB blocks replace the plain up1..up4 blocks.
        x = self.mfab1(x5, x4)
        x = self.mfab2(x, x3)
        x = self.mfab3(x, x2)
        x = self.mfab4(x, x1)

        # NOTE(review): the output is already softmax-normalised; confirm the
        # training loss does not apply a second softmax/sigmoid on top.
        probs = F.softmax(self.outc(x), 1)

        if self.aux is None:
            return probs
        return probs, self.aux(x5)
|
utils/pipeline.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
import tempfile
|
4 |
+
from glob import glob
|
5 |
+
from torchsummary import summary
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from tqdm import tqdm
|
9 |
+
import torch
|
10 |
+
from torch.utils.tensorboard import SummaryWriter
|
11 |
+
from torch.cuda.amp import autocast, GradScaler
|
12 |
+
import torch.nn as nn
|
13 |
+
import torchvision
|
14 |
+
import monai
|
15 |
+
from monai.metrics import DiceMetric, ConfusionMatrixMetric, MeanIoU
|
16 |
+
from monai.visualize import plot_2d_or_3d_image
|
17 |
+
from visualization import visualize_patient
|
18 |
+
from sliding_window import sw_inference
|
19 |
+
from data_preparation import build_dataset
|
20 |
+
from models import UNet2D, UNet3D
|
21 |
+
from loss import WeaklyDiceFocalLoss
|
22 |
+
from sklearn.linear_model import LinearRegression
|
23 |
+
from nrrd import write, read
|
24 |
+
import morphsnakes as ms
|
25 |
+
from monai.data import decollate_batch
|
26 |
+
|
27 |
+
|
28 |
+
def build_optimizer(model, config):
    """Create the loss, optimizer, LR scheduler and evaluation metrics.

    Args:
        model: network whose parameters are optimised.
        config: run configuration; reads ``LOSS``, ``KEEP_CLASSES``,
            ``EVAL_INCLUDE_BACKGROUND``, ``LEARNING_RATE``, ``MAX_EPOCHS`` and,
            for the ``wfdice`` loss, ``LAMBDA_WEAK``.

    Returns:
        ``(loss_function, optimizer, lr_scheduler, eval_metrics)`` where
        ``eval_metrics`` is a list of ``(name, metric)`` pairs.
    """

    def loss_kwargs(**extra):
        # Up to two kept classes: sigmoid with to_onehot_y=True;
        # more classes: softmax with to_onehot_y=False.
        if len(config['KEEP_CLASSES']) <= 2:
            kwargs = {"to_onehot_y": True, "sigmoid": True}
        else:
            kwargs = {"to_onehot_y": False, "softmax": True}
        kwargs["include_background"] = config['EVAL_INCLUDE_BACKGROUND']
        kwargs.update(extra)
        return kwargs

    loss_name = config['LOSS']
    if loss_name == "gdice":
        loss_function = monai.losses.GeneralizedDiceLoss(**loss_kwargs(reduction="mean"))
    elif loss_name == 'cdice':
        loss_function = monai.losses.DiceCELoss(**loss_kwargs(reduction="mean"))
    elif loss_name == 'mdice':
        loss_function = monai.losses.MaskedDiceLoss()
    elif loss_name == 'wdice':
        # Example with 3 classes (including the background: label 0).
        # The distance between the background class (label 0) and the other
        # classes is the maximum, equal to 1; between class 1 and 2 it is 0.5.
        dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
        loss_function = monai.losses.GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
    elif loss_name == "fdice":
        loss_function = monai.losses.DiceFocalLoss(**loss_kwargs())
    elif loss_name == "wfdice":
        loss_function = WeaklyDiceFocalLoss(**loss_kwargs(lambda_weak=config['LAMBDA_WEAK']))
    else:
        # Any other name falls back to plain Dice with squared predictions.
        loss_function = monai.losses.DiceLoss(**loss_kwargs(reduction="mean", squared_pred=True))

    include_bg = config['EVAL_INCLUDE_BACKGROUND']
    eval_metrics = [
        (name, ConfusionMatrixMetric(include_background=include_bg, metric_name=name, reduction="mean_batch"))
        for name in ("sensitivity", "specificity", "accuracy")
    ]
    eval_metrics.append(("dice", DiceMetric(include_background=include_bg, reduction="mean_batch")))
    eval_metrics.append(("IoU", MeanIoU(include_background=include_bg, reduction="mean_batch")))

    optimizer = torch.optim.Adam(model.parameters(), config['LEARNING_RATE'], weight_decay=1e-5, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['MAX_EPOCHS'])
    return loss_function, optimizer, lr_scheduler, eval_metrics
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
def load_weights(model, config):
    """Best-effort load of pretrained weights into ``model``.

    ``config['PRETRAINED_WEIGHTS']`` is first tried as a checkpoint name
    (``checkpoints/<name>.pth``) and then as a literal path. If both attempts
    fail, a warning with the second error is printed and the model is
    returned unchanged.

    Returns:
        The same ``model`` object.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    resolvers = (
        lambda: "checkpoints/" + config['PRETRAINED_WEIGHTS'] + ".pth",
        lambda: config['PRETRAINED_WEIGHTS'],
    )
    final_attempt = len(resolvers) - 1
    for attempt, resolve in enumerate(resolvers):
        try:
            state = torch.load(resolve(), map_location=torch.device(config['DEVICE']))
            model.load_state_dict(state)
            print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
            break
        except Exception as err:
            # Only the failure of the last candidate is reported.
            if attempt == final_attempt:
                print("WARNING: weights were not loaded. ", err)

    return model
|
87 |
+
|
88 |
+
|
89 |
+
def build_model(config):
    """Instantiate the network selected by ``config['MODEL_NAME']``.

    The name is matched by substring, so the order of the checks matters
    ("SegResNetVAE" before "SegResNet", "SwinUNETR" before "UNETR"). A name
    containing "3D" selects the volumetric variant. The model is moved to
    ``config['DEVICE']``, a summary is printed on a best-effort basis, and
    pretrained weights are loaded when ``config['PRETRAINED_WEIGHTS']`` is set.

    Returns:
        The model, or ``None`` when the name matches no known architecture.
    """
    config = get_defaults(config)

    name = config["MODEL_NAME"]
    device = config['DEVICE']
    dropout_prob = config['DROPOUT']
    n_classes = len(config['KEEP_CLASSES'])
    roi = config['ROI_SIZE']
    is_3d = "3D" in name
    spatial_dims = 3 if is_3d else 2
    image_size = roi if is_3d else (roi[0], roi[1])

    if "SegResNetVAE" in name:
        model = monai.networks.nets.SegResNetVAE(
            input_image_size=image_size,
            vae_estimate_std=False,
            vae_default_std=0.3,
            vae_nz=256,
            spatial_dims=spatial_dims,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            norm='instance',
            out_channels=n_classes,
            dropout_prob=dropout_prob,
        )
    elif "SegResNet" in name:
        model = monai.networks.nets.SegResNet(
            spatial_dims=spatial_dims,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            out_channels=n_classes,
            dropout_prob=dropout_prob,
            norm="instance",
        )
    elif "SwinUNETR" in name:
        # Always receives the full ROI, regardless of "3D" in the name.
        model = monai.networks.nets.SwinUNETR(
            img_size=roi,
            in_channels=1,
            out_channels=n_classes,
            feature_size=48,
            drop_rate=dropout_prob,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
        )
    elif "UNETR" in name:
        model = monai.networks.nets.UNETR(
            img_size=image_size,
            in_channels=1,
            out_channels=n_classes,
            feature_size=16,
            hidden_size=256,
            mlp_dim=3072,
            num_heads=8,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=dropout_prob,
        )
    elif "MANet" in name:
        if "2D" in name:
            model = UNet2D(1, n_classes, pab_channels=64, use_batchnorm=True)
        else:
            model = UNet3D(1, n_classes, pab_channels=32, use_batchnorm=True)
    elif "UNetPlusPlus" in name:
        model = monai.networks.nets.BasicUNetPlusPlus(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=n_classes,
            features=(32, 32, 64, 128, 256, 32),
            norm="instance",
            dropout=dropout_prob,
        )
    elif "UNet1" in name:
        model = monai.networks.nets.UNet(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=n_classes,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="instance",
        )
    elif "UNet2" in name:
        # NOTE(review): four strides for four channel levels, i.e. one more
        # than the transitions between levels -- confirm against MONAI's UNet.
        model = monai.networks.nets.UNet(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=n_classes,
            channels=(32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=4,
            norm="instance",
        )
    else:
        print(config["MODEL_NAME"], "is not a valid model name")
        return None

    model = model.to(device)

    # The summary is informational only; a failure must not abort the build.
    try:
        if is_3d:
            input_size = (1, config['ROI_SIZE'][0], config['ROI_SIZE'][1], config['ROI_SIZE'][2])
        else:
            input_size = (1, config['ROI_SIZE'][0], config['ROI_SIZE'][1])
        print(summary(model, input_size=input_size))
    except Exception as e:
        print("could not load model summary:", e)

    if config['PRETRAINED_WEIGHTS']:
        model = load_weights(model, config)
    return model
|
213 |
+
|
214 |
+
|
215 |
+
def _segmentation_loss(model, loss_function, inputs, labels, config, weak_label=None):
    """Run ``model`` on ``inputs`` and return the training loss against ``labels``."""
    outputs = model(inputs)
    # SegResNetVAE returns a pair; only the first element is the segmentation.
    if "SegResNetVAE" in config["MODEL_NAME"]:
        outputs = outputs[0]
    if isinstance(outputs, list):
        outputs = outputs[0]
    if config['LOSS'] == 'wfdice':
        return loss_function(outputs, labels, weak_label)
    return loss_function(outputs, labels)


def train(model, train_loader, val_loader, loss_function, eval_metrics, optimizer, config,
          scheduler=None, writer=None, postprocessing_transforms = None, weak_labels = None):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    Args:
        model: network to train.
        train_loader: iterable of batches with ``"image"`` and ``"mask"`` tensors.
        val_loader: loader handed to ``evaluate`` every ``VAL_INTERVAL`` epochs.
        loss_function: callable ``(outputs, labels)``, or
            ``(outputs, labels, weak_label)`` when ``config['LOSS'] == 'wfdice'``.
        eval_metrics: ``(name, metric)`` pairs forwarded to ``evaluate``.
        optimizer: optimizer stepping the model parameters.
        config: run configuration dictionary.
        scheduler: optional LR scheduler stepped once per epoch.
        writer: optional TensorBoard writer; one is created when omitted.
        postprocessing_transforms: forwarded to ``evaluate``.
        weak_labels: indexable weak labels, required for the ``wfdice`` loss.

    Returns:
        ``(model, writer)``; the writer has been closed.

    Raises:
        ValueError: if the ``wfdice`` loss is selected without ``weak_labels``.
    """
    import os  # local import: only needed to create the checkpoint directory

    use_weak = config['LOSS'] == 'wfdice'
    if use_weak and weak_labels is None:
        raise ValueError("config['LOSS'] == 'wfdice' requires weak_labels")

    if writer is None:
        writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
    best_metric, best_metric_epoch = -1, -1
    prev_metric, patience, patience_counter = 1, config['EARLY_STOPPING_PATIENCE'], 0
    use_amp = config['AUTOCAST']
    scaler = GradScaler() if use_amp else None  # mixed precision training

    for epoch in range(config['MAX_EPOCHS']):
        print("-" * 10)
        model.train()
        # `step` counts every batch drawn from the loader (it indexes
        # weak_labels); `num_updates` counts only batches that were trained
        # on, so skipped batches do not dilute the average loss.
        epoch_loss, step, num_updates = 0, 0, 0
        with tqdm(train_loader) as progress_bar:
            for batch_data in progress_bar:
                step += 1
                inputs, labels = batch_data["image"].to(config['DEVICE']), batch_data["mask"].to(config['DEVICE'])

                # only train with batches that have tumor; skip those without tumor
                if config['TYPE'] == "tumor" and torch.sum(labels[:, -1]) == 0:
                    continue

                # check input shapes
                if inputs.shape[-1] != labels.shape[-1] or inputs.shape[0] != labels.shape[0]:
                    print("WARNING: Batch skipped. Image and mask shape does not match:", inputs.shape[0], labels.shape[0])
                    continue

                weak_label = None
                if use_weak:
                    # NOTE(review): `step` is 1-based here, so weak_labels[0]
                    # is never used -- confirm the intended alignment between
                    # loader batches and weak labels.
                    weak_label = torch.tensor([weak_labels[step]]).to(config['DEVICE'])

                optimizer.zero_grad()
                if not use_amp:
                    loss = _segmentation_loss(model, loss_function, inputs, labels, config, weak_label)
                    loss.backward()
                    optimizer.step()
                else:
                    with autocast():
                        loss = _segmentation_loss(model, loss_function, inputs, labels, config, weak_label)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    if torch.isinf(loss).any():
                        print("Detected inf in gradients.")
                    else:
                        scaler.step(optimizer)
                    # update() runs on every iteration, including skipped
                    # steps: unscale_() may only be called once per optimizer
                    # between two update() calls.
                    scaler.update()

                epoch_loss += loss.item()
                num_updates += 1
                progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss/num_updates:.4f}')

        # max(..., 1) keeps an empty or fully skipped epoch from dividing by zero
        epoch_loss /= max(num_updates, 1)
        writer.add_scalar("train_loss_epoch", epoch_loss, epoch)

        # validation
        if (epoch + 1) % config['VAL_INTERVAL'] == 0:

            # get a list of validation measures, pick one to be the decision maker
            val_metrics, (val_images, val_labels, val_outputs) = evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms)
            if isinstance(config['EVAL_METRIC'], list):
                cur_metric = np.mean([val_metrics[m] for m in config['EVAL_METRIC']])
            else:
                cur_metric = val_metrics[config['EVAL_METRIC']]

            # determine if better than previous best validation metric
            if cur_metric > best_metric:
                best_metric, best_metric_epoch = cur_metric, epoch + 1
                os.makedirs("checkpoints", exist_ok=True)
                torch.save(model.state_dict(), "checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth")

            # early stopping: consecutive declines, or too long since the best epoch
            patience_counter = patience_counter + 1 if prev_metric > cur_metric else 0
            if patience_counter == patience or epoch - best_metric_epoch > patience:
                print("Early stopping at epoch", epoch + 1)
                break
            print(f'Current epoch: {epoch + 1} current avg {config["EVAL_METRIC"]}: {cur_metric :.4f} best avg {config["EVAL_METRIC"]}: {best_metric:.4f} at epoch {best_metric_epoch}')
            prev_metric = cur_metric

            # writer
            for key, value in val_metrics.items():
                writer.add_scalar("val_" + key, value, epoch)
            middle = len(val_outputs) // 2
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=middle, tag="image", frame_dim=-1)
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=middle, tag="label", frame_dim=-1)
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=middle, tag="output", frame_dim=-1)

        # update scheduler; a failing scheduler is reported but does not stop training
        if scheduler is not None:
            try:
                scheduler.step()
            except Exception as err:
                print("WARNING: learning-rate scheduler step failed:", err)

    print(f"Train completed, best {config['EVAL_METRIC']}: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return model, writer
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
def evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms=None, use_liver_seg=False, export_filenames=None, export_file_metadata=None):
    """Run ``model`` over ``val_loader`` and aggregate the evaluation metrics.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of batches with ``"image"`` and ``"mask"`` tensors
            (plus ``"pred_liver"`` when ``use_liver_seg`` is true).
        eval_metrics: ``(name, metric)`` pairs; each metric is called per
            batch, then aggregated and reset.
        config: run configuration dictionary.
        postprocessing_transforms: optional callable applied to every
            decollated prediction.
        use_liver_seg: replace channel 1 of each prediction with the provided
            liver prediction minus channel 2.
        export_filenames: optional output paths, one per exported prediction.
        export_file_metadata: headers matching ``export_filenames``.

    Returns:
        ``(val_metrics, (val_images, val_labels, val_outputs))`` where the
        tensors come from the last processed batch (``None`` if the loader
        yielded nothing).
    """
    if export_filenames is None:
        export_filenames = []
    if export_file_metadata is None:
        export_file_metadata = []

    val_metrics = {}
    val_images = val_labels = val_outputs = None
    model.eval()
    with torch.no_grad():

        step = 0  # index of the next prediction to export
        for val_data in val_loader:
            # 3D: val_images has shape (1,C,H,W,Z)
            # 2D: val_images has shape (B,C,H,W)
            val_images, val_labels = val_data["image"].to(config['DEVICE']), val_data["mask"].to(config['DEVICE'])
            if use_liver_seg:
                val_liver = val_data["pred_liver"].to(config['DEVICE'])

            if (val_images[0].shape[-1] != val_labels[0].shape[-1]) or (
                    "3D" not in config["MODEL_NAME"] and val_images.shape[0] != val_labels.shape[0]):
                print("WARNING: Batch skipped. Image and mask shape does not match:", val_images.shape, val_labels.shape)
                continue

            # convert outputs to probability
            if "3D" in config["MODEL_NAME"]:
                val_outputs = sw_inference(model, val_images, config['ROI_SIZE'], config['AUTOCAST'], discard_second_output='SegResNetVAE' in config['MODEL_NAME'])
            elif "SegResNetVAE" in config["MODEL_NAME"]:
                val_outputs, _ = model(val_images)
            else:
                val_outputs = model(val_images)

            # post-processing
            if postprocessing_transforms is not None:
                val_outputs = [postprocessing_transforms(i) for i in decollate_batch(val_outputs)]

            # remove predictions of the last channel where the image is (near) zero
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # source; this is assumed to run for every batch -- confirm.
            for i in range(len(val_outputs)):
                val_outputs[i][-1][torch.where(val_images[i][0] <= 1e-6)] = 0

            # apply morphological snakes algorithm
            if config['POSTPROCESSING_MORF']:
                for i in range(len(val_outputs)):
                    val_outputs[i][-1] = torch.from_numpy(ms.morphological_chan_vese(val_images[i][0].cpu(), iterations=2, init_level_set=val_outputs[i][-1].cpu())).to(config['DEVICE'])

            if use_liver_seg:
                for i in range(len(val_outputs)):
                    # use liver model outputs for liver channel
                    val_outputs[i][1] = val_liver[i]
                    # if region is tumor, assign liver prediction to 0
                    val_outputs[i][1] -= val_outputs[i][2]

            # compute metric for current iteration
            for metric_name, metric in eval_metrics:
                if isinstance(val_outputs[0], list):
                    val_outputs = val_outputs[0]
                metric(val_outputs, val_labels)

            # save prediction to local folder
            if len(export_filenames) > 0:
                for prediction in val_outputs:
                    numpy_array = prediction.cpu().detach().numpy()
                    write(export_filenames[step], numpy_array[-1], header=export_file_metadata[step])
                    print(" Segmentation exported to", export_filenames[step])
                    step += 1

        # aggregate the final mean metric
        for metric_name, metric in eval_metrics:
            if "dice" in metric_name or "IoU" in metric_name:
                metric_value = metric.aggregate().tolist()
            else:
                metric_value = metric.aggregate()[0].tolist()  # one value per class
            val_metrics[metric_name + "_avg"] = np.mean(metric_value)
            if config['TYPE'] != "liver":
                # class-wise values, numbered from 1
                for c, value in enumerate(metric_value, start=1):
                    val_metrics[metric_name + "_class" + str(c)] = value
            metric.reset()

    return val_metrics, (val_images, val_labels, val_outputs)
|
393 |
+
|
394 |
+
|
395 |
+
|
396 |
+
|
397 |
+
def get_defaults(config):
    """Fill in missing configuration values and derive the computed ones.

    ``config`` is modified in place and also returned. Keys already present
    are kept, except for ``KEEP_CLASSES``, ``DEVICE`` and ``EXPORT_FILE_NAME``,
    which are always recomputed. The caller must provide ``TYPE``,
    ``MODEL_NAME``, ``LOSS``, ``BATCH_SIZE``, ``HU_RANGE`` (two values) and
    ``ROI_SIZE`` (three values).
    """
    general_defaults = {
        'TRAIN': True,
        'VALID_PATIENT_RATIO': 0.2,
        'VAL_INTERVAL': 1,
        'EARLY_STOPPING_PATIENCE': 20,
        'AUTOCAST': False,
        'NUM_WORKERS': 0,
        'DROPOUT': 0.1,
        'ONESAMPLETESTRUN': False,
        'DATA_AUGMENTATION': False,
        'POSTPROCESSING_MORF': False,
        'PREPROCESSING': "",
        'PRETRAINED_WEIGHTS': "",
    }
    for key, value in general_defaults.items():
        config.setdefault(key, value)

    # Evaluation defaults depend on the task type.
    is_liver = config['TYPE'] == "liver" if ('EVAL_INCLUDE_BACKGROUND' not in config or 'EVAL_METRIC' not in config) else None
    if 'EVAL_INCLUDE_BACKGROUND' not in config:
        config['EVAL_INCLUDE_BACKGROUND'] = bool(is_liver)
    if 'EVAL_METRIC' not in config:
        config['EVAL_METRIC'] = ["dice_avg"] if is_liver else ["dice_class2"]

    clinical_defaults = {
        'CLINICAL_DATA_FILE': "Dataset/HCC-TACE-Seg_clinical_data-V2.xlsx",
        'CLINICAL_PREDICTORS': ['T_involvment', 'CLIP_Score','Personal history of cancer', 'TNM', 'Metastasis','fhx_can', 'Alcohol', 'Smoking', 'Evidence_of_cirh', 'AFP', 'age', 'Diabetes', 'Lymphnodes', 'Interval_BL', 'TTP'],
        'LAMBDA_WEAK': 0.5,
        'MASKNONLIVER': False,
    }
    for key, value in clinical_defaults.items():
        config.setdefault(key, value)

    # The kept classes are always dictated by the task type.
    if config['TYPE'] == "liver":
        config['KEEP_CLASSES'] = ["normal", "liver"]
    elif config['TYPE'] == "tumor":
        config['KEEP_CLASSES'] = ["normal", "liver", "tumor"]
    else:
        config['KEEP_CLASSES'] = ["normal", "liver", "tumor", "portal vein", "abdominal aorta"]

    config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The export name encodes the main settings; it is rebuilt from scratch on
    # every call, so calling get_defaults repeatedly does not stack suffixes.
    hu_low, hu_high = config['HU_RANGE'][0], config['HU_RANGE'][1]
    roi = config['ROI_SIZE']
    export_name = (
        config['TYPE'] + "_" + config['MODEL_NAME'] + "_" + config['LOSS']
        + "_batchsize" + str(config['BATCH_SIZE'])
        + "_DA" + str(config['DATA_AUGMENTATION'])
        + "_HU" + str(hu_low) + "-" + str(hu_high)
        + "_" + config['PREPROCESSING']
        + "_" + str(roi[0]) + "_" + str(roi[1]) + "_" + str(roi[2])
        + "_dropout" + str(config['DROPOUT'])
    )
    if config['MASKNONLIVER']:
        export_name += "_wobackground"
    if config['LOSS'] == "wfdice":
        export_name += "_weaklambda" + str(config['LAMBDA_WEAK'])
    if config['PRETRAINED_WEIGHTS'] != "" and config['PRETRAINED_WEIGHTS'] != export_name:
        export_name += "_pretraining"
    if config['POSTPROCESSING_MORF']:
        export_name += "_wpostmorf"
    if not config['EVAL_INCLUDE_BACKGROUND']:
        export_name += "_evalnobackground"
    config['EXPORT_FILE_NAME'] = export_name

    return config
|
439 |
+
|
440 |
+
|
441 |
+
def train_clinical(df_clinical):
    """Fit a linear regression of ``tumor_ratio`` on all other columns.

    The model is fitted and evaluated on the same rows; the correlation and
    mean absolute error that are printed are therefore training-set figures.

    Returns:
        The predicted ``tumor_ratio`` for every row of ``df_clinical``.
    """
    clinical_model = LinearRegression()

    # Every column except the target is used as a predictor.
    predictors = df_clinical.loc[:, df_clinical.columns != 'tumor_ratio']

    # train model
    print("Training model using", predictors.shape[1], "features")
    print(df_clinical.head())
    clinical_model.fit(predictors, df_clinical['tumor_ratio'])

    # obtain predicted ratios
    pred = clinical_model.predict(predictors)

    # evaluate
    observed = df_clinical['tumor_ratio']
    corr = np.corrcoef(pred, observed)[0][1]
    mae = np.mean(np.abs(pred - observed))
    print(f"The clinical model was fitted. Corr = {corr: .6f} MAE = {mae: .6f}")

    return pred
|
459 |
+
|
460 |
+
|
461 |
+
def model_pipeline(config=None, plot=True):
    """Build the data, model and optimizer, optionally train, then test.

    Args:
        config: run configuration dictionary (required despite the default).
        plot: when true, visualise ground truth and prediction for the last
            evaluated test batch.

    Returns:
        ``(test_images, test_labels, test_outputs)`` from the test evaluation,
        or ``(None, None, None)`` when ``config['ONESAMPLETESTRUN']`` is set.

    Raises:
        ValueError: if ``config`` is missing or names an unknown model.
    """
    if config is None:
        raise ValueError("model_pipeline requires a configuration dictionary")

    torch.cuda.empty_cache()
    config = get_defaults(config)
    print(f"You Are Running on a: {config['DEVICE']}")
    print("file name:", config['EXPORT_FILE_NAME'])

    writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])

    # prepare data
    use_weak_labels = config['LOSS'] == "wfdice"
    train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train = build_dataset(config, get_clinical=use_weak_labels)

    # train clinical model
    weak_labels = train_clinical(df_clinical_train) if use_weak_labels else None

    # train segmentation model
    model = build_model(config)
    if model is None:
        # build_model reports an unknown name by returning None
        writer.close()
        raise ValueError(f"{config['MODEL_NAME']} is not a valid model name")
    loss_function, optimizer, lr_scheduler, eval_metrics = build_optimizer(model, config)
    if config['TRAIN']:
        train(model, train_loader, valid_loader, loss_function, eval_metrics, optimizer, config, lr_scheduler, writer, postprocessing_transforms, weak_labels)
        # reload the best checkpoint saved during training
        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # source; the reload is assumed to belong to the TRAIN branch -- confirm.
        model.load_state_dict(torch.load("checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth", map_location=torch.device(config['DEVICE'])))
    else:
        # train() closes the writer itself; without training it would leak
        writer.close()
    if config['ONESAMPLETESTRUN']:
        return None, None, None

    # test segmentation model
    test_metrics, (test_images, test_labels, test_outputs) = evaluate(model, test_loader, eval_metrics, config, postprocessing_transforms)
    print("Test metrics")
    for key, value in test_metrics.items():
        print(f" {key}: {value:.4f}")

    # visualize
    if plot:
        is_3d = "3D" in config['MODEL_NAME']
        if is_3d:
            # 3D: show the first volume of the batch
            visualize_patient(test_images[0].cpu(), mask=test_labels[0].cpu(), n_slices=9, title="ground truth", z_dim_last=is_3d, mask_channel=-1)
            visualize_patient(test_images[0].cpu(), mask=test_outputs[0].cpu(), n_slices=9, title="predicted", z_dim_last=is_3d, mask_channel=-1)
        else:
            # 2D: the batch itself is shown; outputs are stacked back into one tensor
            visualize_patient(test_images.cpu(), mask=test_labels.cpu(), n_slices=9, title="ground truth", z_dim_last=is_3d, mask_channel=-1)
            visualize_patient(test_images.cpu(), mask=torch.stack(test_outputs).cpu(), n_slices=9, title="predicted", z_dim_last=is_3d, mask_channel=-1)

    return (test_images, test_labels, test_outputs)
|
utils/sliding_window.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections.abc import Callable, Sequence
|
2 |
+
from typing import Any, Iterable
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from monai.data.meta_tensor import MetaTensor
|
7 |
+
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
|
8 |
+
from monai.inferers.utils import _create_buffered_slices, _compute_coords, _get_scan_interval, _flatten_struct, _pack_struct
|
9 |
+
from monai.utils import (
|
10 |
+
BlendMode,
|
11 |
+
PytorchPadMode,
|
12 |
+
convert_data_type,
|
13 |
+
convert_to_dst_type,
|
14 |
+
ensure_tuple,
|
15 |
+
ensure_tuple_rep,
|
16 |
+
fall_back_tuple,
|
17 |
+
look_up_option,
|
18 |
+
optional_import,
|
19 |
+
pytorch_after,
|
20 |
+
)
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
# Adapted from monai
|
24 |
+
def sliding_window_inference(
    inputs: torch.Tensor | MetaTensor,
    roi_size: Sequence[int] | int,
    sw_batch_size: int,
    predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
    overlap: Sequence[float] | float = 0.25,
    mode: BlendMode | str = BlendMode.CONSTANT,
    sigma_scale: Sequence[float] | float = 0.125,
    padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: torch.device | str | None = None,
    device: torch.device | str | None = None,
    progress: bool = False,
    roi_weight_map: torch.Tensor | None = None,
    process_fn: Callable | None = None,
    buffer_steps: int | None = None,
    buffer_dim: int = -1,
    with_coord: bool = False,
    discard_second_output: bool = False,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial
            dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        process_fn: process inference output and adjust the importance map per window
        buffer_steps: the number of sliding window iterations along the ``buffer_dim``
            to be buffered on ``sw_device`` before writing to ``device``.
            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)
            default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,
            (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
        buffer_dim: the spatial dimension along which the buffers are created.
            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
        with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
            If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
        discard_second_output: when True and ``predictor`` returns something other than ``None``, only
            element ``[0]`` of the returned value is kept before blending. Intended for predictors that
            return a sequence whose first entry is the prediction to stitch; note that on a plain tensor
            ``[0]`` would select the first batch item instead. Default is False.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Raises:
        ValueError: if ``buffer_dim`` or ``overlap`` is out of range.
        RuntimeError: if the importance map cannot be computed.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    buffered = buffer_steps is not None and buffer_steps > 0
    num_spatial_dims = len(inputs.shape) - 2
    if buffered:
        if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
            raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
        if buffer_dim < 0:
            buffer_dim += num_spatial_dims
    overlap = ensure_tuple_rep(overlap, num_spatial_dims)
    for o in overlap:
        if o < 0 or o >= 1:
            raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
    compute_dtype = inputs.dtype

    # Interpolation mode used to rescale the weight map when the predictor's output size
    # differs from roi_size. The module never defined the `_nearest_mode` name the loop
    # below used to reference (NameError), so the value is built here instead.
    nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape
    device = device or inputs.device
    sw_device = sw_device or inputs.device

    temp_meta = None
    if isinstance(inputs, MetaTensor):
        # keep the metadata so that it can be re-attached to the stitched output
        temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
    inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]
    roi_size = fall_back_tuple(roi_size, image_size_)

    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    # F.pad expects the padding of the last dimension first, hence the reversed iteration
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    if any(pad_size):
        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    # Store all slices
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
    slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)

    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows
    windows_range: Iterable
    if not buffered:
        non_blocking = False
        windows_range = range(0, total_slices, sw_batch_size)
    else:
        slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
            slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
        )
        non_blocking, _ss = torch.cuda.is_available(), -1
        for x in b_slices[:n_per_batch]:
            if x[1] < _ss:  # detect overlapping slices
                non_blocking = False
                break
            _ss = x[2]

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map_ = roi_weight_map
    else:
        try:
            valid_p_size = ensure_tuple(valid_patch_size)
            importance_map_ = compute_importance_map(
                valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype
            )
            if len(importance_map_.shape) == num_spatial_dims and not process_fn:
                importance_map_ = importance_map_[None, None]  # adds batch, channel dimensions
        except Exception as e:
            raise RuntimeError(
                f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]

    # stores output and count map
    output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0  # type: ignore
    # for each patch
    for slice_g in tqdm(windows_range) if progress else windows_range:
        slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))
        unravel_slice = [
            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        if sw_batch_size > 1:
            win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        else:
            win_data = inputs[unravel_slice[0]].to(sw_device)
        if with_coord:
            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
        else:
            seg_prob_out = predictor(win_data, *args, **kwargs)
        # keep only the first element of the predictor's return value when requested
        if discard_second_output and seg_prob_out is not None:
            seg_prob_out = seg_prob_out[0]

        # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
        dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
        if process_fn:
            seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)
        else:
            w_t = importance_map_
        if len(w_t.shape) == num_spatial_dims:
            w_t = w_t[None, None]
        w_t = w_t.to(dtype=compute_dtype, device=sw_device)
        if buffered:
            c_start, c_end = b_slices[b_s][1:]
            if not sw_device_buffer:
                k = seg_tuple[0].shape[1]  # len(seg_tuple) > 1 is currently ignored
                sp_size = list(image_size)
                sp_size[buffer_dim] = c_end - c_start
                sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
            for p, s in zip(seg_tuple[0], unravel_slice):
                offset = s[buffer_dim + 2].start - c_start
                s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
                s[0] = slice(0, 1)
                sw_device_buffer[0][s] += p * w_t
            b_i += len(unravel_slice)
            if b_i < b_slices[b_s][0]:
                # the current buffer is not complete yet, keep accumulating windows
                continue
        else:
            sw_device_buffer = list(seg_tuple)

        for ss in range(len(sw_device_buffer)):
            b_shape = sw_device_buffer[ss].shape
            seg_chns, seg_shape = b_shape[1], b_shape[2:]
            z_scale = None
            if not buffered and seg_shape != roi_size:
                z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
                w_t = F.interpolate(w_t, seg_shape, mode=nearest_mode)
            if len(output_image_list) <= ss:
                output_shape = [batch_size, seg_chns]
                output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
                # allocate memory to store the full output and the count for overlapping parts
                new_tensor: Callable = torch.empty if non_blocking else torch.zeros  # type: ignore
                output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                w_t_ = w_t.to(device)
                for __s in slices:
                    if z_scale is not None:
                        __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
                    count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
            if buffered:
                o_slice = [slice(None)] * len(inputs.shape)
                o_slice[buffer_dim + 2] = slice(c_start, c_end)
                img_b = b_s // n_per_batch  # image batch index
                o_slice[0] = slice(img_b, img_b + 1)
                if non_blocking:
                    output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
                else:
                    output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
            else:
                sw_device_buffer[ss] *= w_t
                sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
                _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
        sw_device_buffer = []
        if buffered:
            b_s += 1

    if non_blocking:
        torch.cuda.current_stream().synchronize()

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] /= count_map_list.pop(0)

    # remove padding if image_size smaller than roi_size
    if any(pad_size):
        for ss, output_i in enumerate(output_image_list):
            zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
            final_slicing: list[slice] = []
            for sp in range(num_spatial_dims):
                si = num_spatial_dims - sp - 1
                slice_dim = slice(
                    int(round(pad_size[sp * 2] * zoom_scale[si])),
                    int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),
                )
                final_slicing.insert(0, slice_dim)
            output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]

    final_output = _pack_struct(output_image_list, dict_keys)
    if temp_meta is not None:
        final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]
    else:
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]

    return final_output  # type: ignore
|
306 |
+
|
307 |
+
|
308 |
+
def sw_inference(model, input, roi_size, autocast_on, discard_second_output, overlap=0.8):
    """
    Run `sliding_window_inference` over `input` with `model` as the predictor.

    Windows are processed one at a time (``sw_batch_size=1``), blended with
    ``mode="constant"`` and without a progress bar.

    Args:
        model: callable used as the ``predictor``.
        input: tensor handed to `sliding_window_inference` as ``inputs``.
        roi_size: spatial window size.
        autocast_on: if truthy, the inference runs inside ``torch.cuda.amp.autocast()``.
        discard_second_output: forwarded unchanged to `sliding_window_inference`.
        overlap: overlap between neighbouring windows. Default is 0.8.

    Returns:
        Whatever `sliding_window_inference` returns for these settings.
    """
    window_settings = dict(
        roi_size=roi_size,
        sw_batch_size=1,
        predictor=model,
        overlap=overlap,
        progress=False,
        mode="constant",
        discard_second_output=discard_second_output,
    )

    # plain path first; the mixed-precision path differs only by the autocast context
    if not autocast_on:
        return sliding_window_inference(inputs=input, **window_settings)

    with torch.cuda.amp.autocast():
        return sliding_window_inference(inputs=input, **window_settings)
|
326 |
+
|
327 |
+
|
328 |
+
|
utils/tumor_features.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.ndimage import label, find_objects
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
# Default voxel spacing applied to array axes 0, 1 and 2 respectively.
# NOTE(review): presumably millimetres, since the diameter is divided by 10 to report cm.
IMAGE_SPACING_X = 0.7031
IMAGE_SPACING_Y = 0.7031
IMAGE_SPACING_Z = 2.5


def compute_largest_diameter(binary_mask, spacing=None):
    """
    Return the largest bounding-box extent among the connected components of a mask.

    Each connected component of `binary_mask` is located with ``scipy.ndimage.label`` /
    ``find_objects``; its extent along every axis is multiplied by the spacing of that
    axis and the largest product over all components is returned, divided by 10.

    Args:
        binary_mask: array whose non-zero entries mark the objects of interest.
        spacing: per-axis voxel spacing, paired with the array axes in order.
            Defaults to ``(IMAGE_SPACING_X, IMAGE_SPACING_Y, IMAGE_SPACING_Z)``.
            Axes beyond ``len(spacing)`` are ignored.

    Returns:
        The largest extent divided by 10 (cm for a spacing in mm); ``0.0`` when the
        mask contains no object.
    """
    if spacing is None:
        spacing = (IMAGE_SPACING_X, IMAGE_SPACING_Y, IMAGE_SPACING_Z)

    # Label connected components, then get one tuple of slices (a bounding box) per component
    labeled_array, _ = label(binary_mask)

    largest_diameter = 0
    for obj in find_objects(labeled_array):
        # physical length of the bounding box along each axis that has a spacing
        lengths = [(axis.stop - axis.start) * step for axis, step in zip(obj, spacing)]
        diameter = max(lengths)
        if diameter > largest_diameter:
            largest_diameter = diameter

    return largest_diameter / 10  # IN CM
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def generate_features(img, liver, tumor):
    """
    Summarise a tumor segmentation as a dictionary of descriptive features.

    Args:
        img: image array, indexed with the boolean mask ``tumor == 1``.
        liver: liver mask; entries equal to 0 are treated as outside the liver.
        tumor: tumor mask; entries equal to 1 are treated as tumor.

    Returns:
        dict with the keys ``"lesion size (cm)"``, ``"lesion shape"`` (always
        ``"irregular"``), ``"lesion density (HU)"`` (mean of `img` over the tumor)
        and ``"involvement of adjacent organs:"`` (``"Yes"`` if any tumor entry
        lies where `liver` is 0, otherwise ``"No"``).
    """
    # The former cv2.findContours(mask_image, ...) call was removed: `mask_image` was never
    # defined (NameError on every call) and the resulting contours were not used.
    features = {
        "lesion size (cm)": compute_largest_diameter(tumor),
        "lesion shape": "irregular",
        "lesion density (HU)": np.mean(img[tumor==1]),
        "involvement of adjacent organs:": "Yes" if np.sum(np.multiply(liver==0, tumor)) > 0 else "No"
    }

    return features
|
utils/visualization.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from matplotlib import pyplot as plt
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
    """
    Plot image, ground truth and prediction side by side, one row per slice.

    img: tensor [C, H, W, Z]
    mask: tensor [C, H, W, Z]
    pred: tensor [C, H, W, Z]
    n_slices: number of evenly spaced slices to visualize (ignored when `slices` is given)
    slices: list of slice indices to visualize
    title: title of the plot
    """
    if slices is not None:
        n_slices = len(slices)

    # squeeze=False keeps `ax` two-dimensional, so ax[i, c] also works for a single row
    fig, ax = plt.subplots(n_slices, 3, figsize=(14, 5*n_slices), squeeze=False)
    inc = img.shape[-1] // n_slices
    # hide zero entries so that only labelled voxels are drawn on top of the image
    mask_masked = np.ma.masked_where(mask == 0, mask)
    pred_masked = np.ma.masked_where(pred == 0, pred)

    for i in range(n_slices):
        slice_num = i*inc if slices is None else slices[i]

        # image (channel 0) as background of all three columns
        for c in range(3):
            ax[i,c].imshow(img[0,:,:,slice_num], cmap="gray")
            ax[i,c].axis("off")
            ax[i,c].set_title('image')

        # ground truth: channels 1 and 2 of the mask overlaid on the image
        ax[i,1].imshow(mask_masked[1,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
        ax[i,1].imshow(mask_masked[2,:,:,slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
        ax[i,1].set_title('ground truth')

        # predicted: channels 1 and 2 of the prediction overlaid on the image
        ax[i,2].imshow(pred_masked[1,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
        ax[i,2].imshow(pred_masked[2,:,:,slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
        ax[i,2].set_title('predicted')

    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()
|
46 |
+
|
47 |
+
|
48 |
+
def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
    """
    Show slices of a volume in a grid of three columns, twice: first with the
    mask overlaid (when a mask is given), then the image alone.

    img: tensor indexed as [C, H, W, Z] if `z_dim_last`, otherwise as [Z, C, H, W]
    mask: optional tensor with the same layout as `img`; zero entries are not drawn
    n_slices: number of evenly spaced slices to visualize (ignored when `slices` is given)
    slices: list of slice indices to visualize
    z_dim_last: whether the slice dimension is the last one (True) or the first one (False)
    mask_channel: channel of `mask` that is overlaid
    title: title of both figures

    Raises ValueError if there is no slice to show.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError("n_slices must be a positive integer")

    n_rows = math.ceil(n_slices/3)
    depth = img.shape[-1] if z_dim_last else img.shape[0]
    inc = depth // n_slices
    masked = None if mask is None else np.ma.masked_where(mask == 0, mask)

    def _plane(volume, channel, slice_num):
        # one 2D slice of one channel, for either axis layout
        if z_dim_last:
            return volume[channel,:,:,slice_num]
        return volume[slice_num,channel,:,:]

    def _draw_grid(overlay):
        # squeeze=False keeps `ax` two-dimensional, so one indexing scheme serves any row count
        fig, ax = plt.subplots(n_rows, 3, figsize=(14, 5*n_rows), squeeze=False)
        for i in range(n_slices):
            r, c = divmod(i, 3)
            slice_num = i*inc if slices is None else slices[i]
            panel = ax[r][c]
            panel.imshow(_plane(img, 0, slice_num), cmap="gray")
            panel.axis("off")
            panel.set_title(f'slice {slice_num}')
            if overlay is not None:
                panel.imshow(_plane(overlay, mask_channel, slice_num), cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
        plt.suptitle(title, size=14)
        plt.tight_layout()
        plt.show()

    # first figure: image with mask overlay; second figure: image only
    _draw_grid(masked)
    _draw_grid(None)
|