import numpy as np
from tqdm import tqdm
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from torchsummary import summary
from sklearn.linear_model import LinearRegression
from nrrd import write
import morphsnakes as ms
import monai
from monai.data import decollate_batch
from monai.metrics import DiceMetric, ConfusionMatrixMetric, MeanIoU
from monai.visualize import plot_2d_or_3d_image

from visualization import visualize_patient
from sliding_window import sw_inference
from data_preparation import build_dataset
from models import UNet2D, UNet3D
from loss import WeaklyDiceFocalLoss


def build_optimizer(model, config):
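    """Build the loss function, evaluation metrics, optimizer, and LR scheduler.

    config['LOSS'] selects the loss. Binary setups (<= 2 classes in
    config['KEEP_CLASSES']) one-hot-encode the target and use a sigmoid;
    multi-class setups assume already one-hot masks and use a softmax.
    Returns (loss_function, optimizer, lr_scheduler, eval_metrics).
    """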
    binary = len(config['KEEP_CLASSES']) <= 2
    if config['LOSS'] == "gdice":
        if binary:
            loss_function = monai.losses.GeneralizedDiceLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=True, sigmoid=True)
        else:
            loss_function = monai.losses.GeneralizedDiceLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=False, softmax=True)
    elif config['LOSS'] == 'cdice':
        if binary:
            loss_function = monai.losses.DiceCELoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=True, sigmoid=True)
        else:
            loss_function = monai.losses.DiceCELoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=False, softmax=True)
    elif config['LOSS'] == 'mdice':
        loss_function = monai.losses.MaskedDiceLoss()
    elif config['LOSS'] == 'wdice':
        # Example with 3 classes (including the background, label 0).
        # The distance between the background (label 0) and the other classes
        # is the maximum, equal to 1; the distance between classes 1 and 2 is 0.5.
        dist_mat = np.array([[0.0, 1.0, 1.0],
                             [1.0, 0.0, 0.5],
                             [1.0, 0.5, 0.0]], dtype=np.float32)
        loss_function = monai.losses.GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
    elif config['LOSS'] == "fdice":
        if binary:
            loss_function = monai.losses.DiceFocalLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                to_onehot_y=True, sigmoid=True)
        else:
            loss_function = monai.losses.DiceFocalLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                to_onehot_y=False, softmax=True)
    elif config['LOSS'] == "wfdice":
        if binary:
            loss_function = WeaklyDiceFocalLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                to_onehot_y=True, sigmoid=True, lambda_weak=config['LAMBDA_WEAK'])
        else:
            loss_function = WeaklyDiceFocalLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                to_onehot_y=False, softmax=True, lambda_weak=config['LAMBDA_WEAK'])
    else:
        if binary:
            loss_function = monai.losses.DiceLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=True, sigmoid=True, squared_pred=True)
        else:
            loss_function = monai.losses.DiceLoss(
                include_background=config['EVAL_INCLUDE_BACKGROUND'],
                reduction="mean", to_onehot_y=False, softmax=True, squared_pred=True)

    eval_metrics = [
        ("sensitivity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'],
                                              metric_name='sensitivity', reduction="mean_batch")),
        ("specificity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'],
                                              metric_name='specificity', reduction="mean_batch")),
        ("accuracy", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'],
                                           metric_name='accuracy', reduction="mean_batch")),
        ("dice", DiceMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
        ("IoU", MeanIoU(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
    ]

    optimizer = torch.optim.Adam(model.parameters(), config['LEARNING_RATE'],
                                 weight_decay=1e-5, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['MAX_EPOCHS'])
    return loss_function, optimizer, lr_scheduler, eval_metrics


def load_weights(model, config):
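    """Load pretrained weights into `model`, first trying
    checkpoints/<PRETRAINED_WEIGHTS>.pth and then treating
    config['PRETRAINED_WEIGHTS'] as a literal path. If both attempts fail,
    a warning is printed and the model is returned unchanged."""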
    try:
        model.load_state_dict(torch.load("checkpoints/" + config['PRETRAINED_WEIGHTS'] + ".pth",
                                         map_location=torch.device(config['DEVICE'])))
        print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
    except Exception:
        try:
            model.load_state_dict(torch.load(config['PRETRAINED_WEIGHTS'],
                                             map_location=torch.device(config['DEVICE'])))
            print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
        except Exception as e:
            print("WARNING: weights were not loaded.", e)
    return model


def build_model(config):
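    """Instantiate the network named by config['MODEL_NAME'] on config['DEVICE'].

    The name selects the architecture (SegResNetVAE, SegResNet, SwinUNETR,
    UNETR, MANet, UNetPlusPlus, UNet1, UNet2) and, via a "2D"/"3D" substring,
    its dimensionality. A parameter summary is printed when torchsummary
    supports the model, and pretrained weights are loaded when
    config['PRETRAINED_WEIGHTS'] is set. Returns None for unknown names.
    """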
    config = get_defaults(config)
    dropout_prob = config['DROPOUT']
    if "SegResNetVAE" in config["MODEL_NAME"]:
        model = monai.networks.nets.SegResNetVAE(
            input_image_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
            vae_estimate_std=False,
            vae_default_std=0.3,
            vae_nz=256,
            spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            norm='instance',
            out_channels=len(config['KEEP_CLASSES']),
            dropout_prob=dropout_prob,
        ).to(config['DEVICE'])
    elif "SegResNet" in config["MODEL_NAME"]:
        model = monai.networks.nets.SegResNet(
            spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            dropout_prob=dropout_prob,
            norm="instance"
        ).to(config['DEVICE'])
    elif "SwinUNETR" in config["MODEL_NAME"]:
        model = monai.networks.nets.SwinUNETR(
            img_size=config['ROI_SIZE'],
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            feature_size=48,
            drop_rate=dropout_prob,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True
        ).to(config['DEVICE'])
    elif "UNETR" in config["MODEL_NAME"]:
        model = monai.networks.nets.UNETR(
            img_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            feature_size=16,
            hidden_size=256,
            mlp_dim=3072,
            num_heads=8,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=dropout_prob,
        ).to(config['DEVICE'])
    elif "MANet" in config["MODEL_NAME"]:
        if "2D" in config["MODEL_NAME"]:
            model = UNet2D(1, len(config['KEEP_CLASSES']), pab_channels=64, use_batchnorm=True).to(config['DEVICE'])
        else:
            model = UNet3D(1, len(config['KEEP_CLASSES']), pab_channels=32, use_batchnorm=True).to(config['DEVICE'])
    elif "UNetPlusPlus" in config["MODEL_NAME"]:
        model = monai.networks.nets.BasicUNetPlusPlus(
            spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            features=(32, 32, 64, 128, 256, 32),
            norm="instance",
            dropout=dropout_prob,
        ).to(config['DEVICE'])
    elif "UNet1" in config['MODEL_NAME']:
        model = monai.networks.nets.UNet(
            spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="instance"
        ).to(config['DEVICE'])
    elif "UNet2" in config['MODEL_NAME']:
        model = monai.networks.nets.UNet(
            spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
            in_channels=1,
            out_channels=len(config['KEEP_CLASSES']),
            channels=(32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=4,
            norm="instance"
        ).to(config['DEVICE'])
    else:
        print(config["MODEL_NAME"], "is not a valid model name")
        return None
    # torchsummary prints the summary itself; failures (e.g. unsupported layers) are non-fatal
    try:
        if "3D" in config['MODEL_NAME']:
            summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1], config['ROI_SIZE'][2]))
        else:
            summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1]))
    except Exception as e:
        print("could not load model summary:", e)
    if config['PRETRAINED_WEIGHTS']:
        model = load_weights(model, config)
    return model


def train(model, train_loader, val_loader, loss_function, eval_metrics, optimizer, config,
          scheduler=None, writer=None, postprocessing_transforms=None, weak_labels=None):
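    """Train `model`, validating every config['VAL_INTERVAL'] epochs.

    Logs to TensorBoard, checkpoints the best model according to
    config['EVAL_METRIC'], and stops early after
    config['EARLY_STOPPING_PATIENCE'] epochs without improvement. With
    config['AUTOCAST'], mixed precision is used via GradScaler. For the
    'wfdice' loss, per-batch weak labels (predicted tumor ratios) are passed
    to the loss function alongside the masks.
    """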
    if writer is None:
        writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
    best_metric, best_metric_epoch = -1, -1
    prev_metric, patience, patience_counter = 1, config['EARLY_STOPPING_PATIENCE'], 0
    if config['AUTOCAST']:
        scaler = GradScaler()  # gradient scaler for mixed-precision training
    for epoch in range(config['MAX_EPOCHS']):
        print("-" * 10)
        model.train()
        epoch_loss, step = 0, 0
        with tqdm(train_loader) as progress_bar:
            for batch_data in progress_bar:
                step += 1
                inputs, labels = batch_data["image"].to(config['DEVICE']), batch_data["mask"].to(config['DEVICE'])
                # only train on batches that contain tumor; skip those without
                if config['TYPE'] == "tumor" and torch.sum(labels[:, -1]) == 0:
                    continue
                # check input shapes
                if inputs is None or labels is None:
                    continue
                if inputs.shape[-1] != labels.shape[-1] or inputs.shape[0] != labels.shape[0]:
                    print("WARNING: Batch skipped. Image and mask shapes do not match:", inputs.shape[0], labels.shape[0])
                    continue
                optimizer.zero_grad()
                if not config['AUTOCAST']:
                    # segmentation output
                    outputs = model(inputs)
                    if "SegResNetVAE" in config["MODEL_NAME"]:
                        outputs = outputs[0]
                    if isinstance(outputs, list):
                        outputs = outputs[0]
                    # loss (with the weak label for the weakly-supervised variant)
                    if config['LOSS'] == 'wfdice':
                        weak_label = torch.tensor([weak_labels[step]]).to(config['DEVICE'])
                        loss = loss_function(outputs, labels, weak_label)
                    else:
                        loss = loss_function(outputs, labels)
                    loss.backward()
                    optimizer.step()
                else:
                    with autocast():
                        outputs = model(inputs)
                        if "SegResNetVAE" in config["MODEL_NAME"]:
                            outputs = outputs[0]
                        if isinstance(outputs, list):
                            outputs = outputs[0]
                        if config['LOSS'] == 'wfdice':
                            weak_label = torch.tensor([weak_labels[step]]).to(config['DEVICE'])
                            loss = loss_function(outputs, labels, weak_label)
                        else:
                            loss = loss_function(outputs, labels)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    if torch.isinf(loss).any():
                        print("Detected inf in loss; skipping optimizer step.")
                    else:
                        scaler.step(optimizer)
                        scaler.update()
                epoch_loss += loss.item()
                progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss/step:.4f}')
        epoch_loss /= step
        writer.add_scalar("train_loss_epoch", epoch_loss, epoch)
        print(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss:.4f}')

        # validation
        if (epoch + 1) % config['VAL_INTERVAL'] == 0:
            # compute the validation metrics; one of them (or the mean of several) drives checkpointing
            val_metrics, (val_images, val_labels, val_outputs) = evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms)
            if isinstance(config['EVAL_METRIC'], list):
                cur_metric = np.mean([val_metrics[m] for m in config['EVAL_METRIC']])
            else:
                cur_metric = val_metrics[config['EVAL_METRIC']]
            # checkpoint if better than the previous best validation metric
            if cur_metric > best_metric:
                best_metric, best_metric_epoch = cur_metric, epoch + 1
                torch.save(model.state_dict(), "checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth")
            # early stopping
            patience_counter = patience_counter + 1 if prev_metric > cur_metric else 0
            if patience_counter == patience or epoch - best_metric_epoch > patience:
                print("Early stopping at epoch", epoch + 1)
                break
            print(f'Current epoch: {epoch + 1} current avg {config["EVAL_METRIC"]}: {cur_metric:.4f} best avg {config["EVAL_METRIC"]}: {best_metric:.4f} at epoch {best_metric_epoch}')
            prev_metric = cur_metric
            # log validation metrics and example slices to TensorBoard
            for key, value in val_metrics.items():
                writer.add_scalar("val_" + key, value, epoch)
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=len(val_outputs)//2, tag="image", frame_dim=-1)
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=len(val_outputs)//2, tag="label", frame_dim=-1)
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=len(val_outputs)//2, tag="output", frame_dim=-1)
        # update the learning-rate scheduler
        try:
            if scheduler is not None:
                scheduler.step()
        except Exception:
            pass
    print(f"Train completed, best {config['EVAL_METRIC']}: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return model, writer


def evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms=None,
             use_liver_seg=False, export_filenames=[], export_file_metadata=[]):
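    """Evaluate `model` on `val_loader`; return (metrics dict, last batch).

    3D models are evaluated with sliding-window inference, 2D models with a
    direct forward pass. Post-processing zeroes tumor predictions where the
    image intensity is ~0 (outside the masked liver region), optionally
    refines contours with morphological snakes, and can substitute a
    dedicated liver segmentation for the liver channel. Predictions can be
    exported as NRRD files via `export_filenames`.
    """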
    val_metrics = {}
    model.eval()
    with torch.no_grad():
        step = 0
        for val_data in val_loader:
            # 3D: val_images has shape (1,C,H,W,Z)
            # 2D: val_images has shape (B,C,H,W)
            val_images, val_labels = val_data["image"].to(config['DEVICE']), val_data["mask"].to(config['DEVICE'])
            if use_liver_seg:
                val_liver = val_data["pred_liver"].to(config['DEVICE'])
            if (val_images[0].shape[-1] != val_labels[0].shape[-1]) or (
                    "3D" not in config["MODEL_NAME"] and val_images.shape[0] != val_labels.shape[0]):
                print("WARNING: Batch skipped. Image and mask shapes do not match:", val_images.shape, val_labels.shape)
                continue
            # convert outputs to probabilities
            if "3D" in config["MODEL_NAME"]:
                val_outputs = sw_inference(model, val_images, config['ROI_SIZE'], config['AUTOCAST'],
                                           discard_second_output='SegResNetVAE' in config['MODEL_NAME'])
            else:
                if "SegResNetVAE" in config["MODEL_NAME"]:
                    val_outputs, _ = model(val_images)
                else:
                    val_outputs = model(val_images)
            # post-processing
            if postprocessing_transforms is not None:
                val_outputs = [postprocessing_transforms(i) for i in decollate_batch(val_outputs)]
            # remove tumor predictions outside the liver
            for i in range(len(val_outputs)):
                val_outputs[i][-1][torch.where(val_images[i][0] <= 1e-6)] = 0
            # refine contours with the morphological snakes algorithm
            if config['POSTPROCESSING_MORF']:
                for i in range(len(val_outputs)):
                    val_outputs[i][-1] = torch.from_numpy(ms.morphological_chan_vese(
                        val_images[i][0].cpu(), iterations=2,
                        init_level_set=val_outputs[i][-1].cpu())).to(config['DEVICE'])
            for i in range(len(val_outputs)):
                if use_liver_seg:
                    # use the liver model's output for the liver channel
                    val_outputs[i][1] = val_liver[i]
                    # where a region is predicted tumor, zero out the liver prediction
                    val_outputs[i][1] -= val_outputs[i][2]
            # accumulate metrics for the current batch
            for metric_name, metric in eval_metrics:
                if isinstance(val_outputs[0], list):
                    val_outputs = val_outputs[0]
                metric(val_outputs, val_labels)
            # export predictions to local NRRD files; step indexes export_filenames,
            # advancing once per exported prediction
            if len(export_filenames) > 0:
                for i in range(len(val_outputs)):
                    numpy_array = val_outputs[i].cpu().detach().numpy()
                    write(export_filenames[step], numpy_array[-1], header=export_file_metadata[step])
                    print("  Segmentation exported to", export_filenames[step])
                    step += 1
    # aggregate the final mean metrics
    for metric_name, metric in eval_metrics:
        if "dice" in metric_name or "IoU" in metric_name:
            metric_value = metric.aggregate().tolist()
        else:
            metric_value = metric.aggregate()[0].tolist()  # one value per class
        val_metrics[metric_name + "_avg"] = np.mean(metric_value)
        if config['TYPE'] != "liver":
            for c in range(1, len(metric_value) + 1):  # class-wise values
                val_metrics[metric_name + "_class" + str(c)] = metric_value[c - 1]
        metric.reset()
    return val_metrics, (val_images, val_labels, val_outputs)


def get_defaults(config):
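    """Fill in defaults for missing config keys and derive dependent settings:
    the class list for config['TYPE'], the torch device, and the export file
    name used for checkpoints, TensorBoard runs, and exported predictions."""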
    if 'TRAIN' not in config: config['TRAIN'] = True
    if 'VALID_PATIENT_RATIO' not in config: config['VALID_PATIENT_RATIO'] = 0.2
    if 'VAL_INTERVAL' not in config: config['VAL_INTERVAL'] = 1
    if 'EARLY_STOPPING_PATIENCE' not in config: config['EARLY_STOPPING_PATIENCE'] = 20
    if 'AUTOCAST' not in config: config['AUTOCAST'] = False
    if 'NUM_WORKERS' not in config: config['NUM_WORKERS'] = 0
    if 'DROPOUT' not in config: config['DROPOUT'] = 0.1
    if 'ONESAMPLETESTRUN' not in config: config['ONESAMPLETESTRUN'] = False
    if 'DATA_AUGMENTATION' not in config: config['DATA_AUGMENTATION'] = False
    if 'POSTPROCESSING_MORF' not in config: config['POSTPROCESSING_MORF'] = False
    if 'PREPROCESSING' not in config: config['PREPROCESSING'] = ""
    if 'PRETRAINED_WEIGHTS' not in config: config['PRETRAINED_WEIGHTS'] = ""
    if 'EVAL_INCLUDE_BACKGROUND' not in config:
        config['EVAL_INCLUDE_BACKGROUND'] = config['TYPE'] == "liver"
    if 'EVAL_METRIC' not in config:
        config['EVAL_METRIC'] = ["dice_avg"] if config['TYPE'] == "liver" else ["dice_class2"]
    if 'CLINICAL_DATA_FILE' not in config: config['CLINICAL_DATA_FILE'] = "Dataset/HCC-TACE-Seg_clinical_data-V2.xlsx"
    if 'CLINICAL_PREDICTORS' not in config:
        config['CLINICAL_PREDICTORS'] = ['T_involvment', 'CLIP_Score', 'Personal history of cancer', 'TNM',
                                         'Metastasis', 'fhx_can', 'Alcohol', 'Smoking', 'Evidence_of_cirh',
                                         'AFP', 'age', 'Diabetes', 'Lymphnodes', 'Interval_BL', 'TTP']
    if 'LAMBDA_WEAK' not in config: config['LAMBDA_WEAK'] = 0.5
    if 'MASKNONLIVER' not in config: config['MASKNONLIVER'] = False
    if config['TYPE'] == "liver":
        config['KEEP_CLASSES'] = ["normal", "liver"]
    elif config['TYPE'] == "tumor":
        config['KEEP_CLASSES'] = ["normal", "liver", "tumor"]
    else:
        config['KEEP_CLASSES'] = ["normal", "liver", "tumor", "portal vein", "abdominal aorta"]
    config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['EXPORT_FILE_NAME'] = (config['TYPE'] + "_" + config['MODEL_NAME'] + "_" + config['LOSS']
                                  + "_batchsize" + str(config['BATCH_SIZE'])
                                  + "_DA" + str(config['DATA_AUGMENTATION'])
                                  + "_HU" + str(config['HU_RANGE'][0]) + "-" + str(config['HU_RANGE'][1])
                                  + "_" + config['PREPROCESSING']
                                  + "_" + str(config['ROI_SIZE'][0]) + "_" + str(config['ROI_SIZE'][1]) + "_" + str(config['ROI_SIZE'][2])
                                  + "_dropout" + str(config['DROPOUT']))
    if config['MASKNONLIVER']: config['EXPORT_FILE_NAME'] += "_wobackground"
    if config['LOSS'] == "wfdice": config['EXPORT_FILE_NAME'] += "_weaklambda" + str(config['LAMBDA_WEAK'])
    if config['PRETRAINED_WEIGHTS'] != "" and config['PRETRAINED_WEIGHTS'] != config['EXPORT_FILE_NAME']:
        config['EXPORT_FILE_NAME'] += "_pretraining"
    if config['POSTPROCESSING_MORF']: config['EXPORT_FILE_NAME'] += "_wpostmorf"
    if not config['EVAL_INCLUDE_BACKGROUND']: config['EXPORT_FILE_NAME'] += "_evalnobackground"
    return config


def train_clinical(df_clinical):
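    """Fit a linear regression predicting the tumor ratio from clinical
    predictors and return its predictions, which serve as weak labels for the
    weakly-supervised 'wfdice' loss. Correlation and MAE are reported
    in-sample, on the same data the model was fitted on."""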
    clinical_model = LinearRegression()
    # fit the tumor ratio on all remaining clinical columns
    print("Training model using", df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'].shape[1], "features")
    print(df_clinical.head())
    clinical_model.fit(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'], df_clinical['tumor_ratio'])
    # obtain predicted ratios (in-sample)
    pred = clinical_model.predict(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'])
    # evaluate the in-sample fit
    corr = np.corrcoef(pred, df_clinical['tumor_ratio'])[0][1]
    mae = np.mean(np.abs(pred - df_clinical['tumor_ratio']))
    print(f"The clinical model was fitted. Corr = {corr:.6f} MAE = {mae:.6f}")
    return pred


def model_pipeline(config=None, plot=True):
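    """Run the full pipeline: build the data loaders, optionally fit the
    clinical weak-label model, build and train the segmentation model, reload
    the best checkpoint, evaluate on the test set, and visualize predictions."""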
    torch.cuda.empty_cache()
    config = get_defaults(config)
    print(f"Running on device: {config['DEVICE']}")
    print("file name:", config['EXPORT_FILE_NAME'])
    writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
    # prepare data
    train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train = build_dataset(
        config, get_clinical=config['LOSS'] == "wfdice")
    # train the clinical model for weak labels
    weak_labels = train_clinical(df_clinical_train) if config['LOSS'] == "wfdice" else None
    # train the segmentation model
    model = build_model(config)
    loss_function, optimizer, lr_scheduler, eval_metrics = build_optimizer(model, config)
    if config['TRAIN']:
        train(model, train_loader, valid_loader, loss_function, eval_metrics, optimizer, config,
              lr_scheduler, writer, postprocessing_transforms, weak_labels)
        # reload the best checkpoint saved during training
        model.load_state_dict(torch.load("checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth",
                                         map_location=torch.device(config['DEVICE'])))
    if config['ONESAMPLETESTRUN']:
        return None, None, None
    # test the segmentation model
    test_metrics, (test_images, test_labels, test_outputs) = evaluate(model, test_loader, eval_metrics, config, postprocessing_transforms)
    print("Test metrics")
    for key, value in test_metrics.items():
        print(f"  {key}: {value:.4f}")
    # visualize ground truth vs. predictions
    if plot:
        if "3D" in config['MODEL_NAME']:
            visualize_patient(test_images[0].cpu(), mask=test_labels[0].cpu(), n_slices=9, title="ground truth",
                              z_dim_last=True, mask_channel=-1)
            visualize_patient(test_images[0].cpu(), mask=test_outputs[0].cpu(), n_slices=9, title="predicted",
                              z_dim_last=True, mask_channel=-1)
        else:
            visualize_patient(test_images.cpu(), mask=test_labels.cpu(), n_slices=9, title="ground truth",
                              z_dim_last=False, mask_channel=-1)
            visualize_patient(test_images.cpu(), mask=torch.stack(test_outputs).cpu(), n_slices=9, title="predicted",
                              z_dim_last=False, mask_channel=-1)
    return (test_images, test_labels, test_outputs)
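

# Minimal usage sketch: the keys below are the ones referenced throughout this
# file, but the specific values (model name, ROI size, HU window, batch size,
# learning rate) are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    example_config = {
        "TYPE": "tumor",               # get_defaults derives KEEP_CLASSES from this
        "MODEL_NAME": "SegResNet3D",   # architecture name plus "3D" dimensionality flag
        "LOSS": "dice",
        "BATCH_SIZE": 2,
        "LEARNING_RATE": 1e-4,
        "MAX_EPOCHS": 100,
        "ROI_SIZE": (128, 128, 64),
        "HU_RANGE": (-100, 250),
    }
    model_pipeline(example_config, plot=False)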