{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 1.4.dev2409\n",
"Numpy version: 1.26.2\n",
"Pytorch version: 1.13.0+cu116\n",
"MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
"MONAI rev id: 46c1b228091283fba829280a5d747f4237f76ed0\n",
"MONAI __file__: /usr/local/lib/python3.9/site-packages/monai/__init__.py\n",
"\n",
"Optional dependencies:\n",
"Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.\n",
"ITK version: NOT INSTALLED or UNKNOWN VERSION.\n",
"Nibabel version: 5.2.1\n",
"scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n",
"scipy version: 1.11.4\n",
"Pillow version: 10.1.0\n",
"Tensorboard version: 2.16.2\n",
"gdown version: NOT INSTALLED or UNKNOWN VERSION.\n",
"TorchVision version: 0.14.0+cu116\n",
"tqdm version: 4.66.1\n",
"lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n",
"psutil version: 5.9.8\n",
"pandas version: 2.2.1\n",
"einops version: 0.7.0\n",
"transformers version: 4.35.2\n",
"mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n",
"pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n",
"clearml version: NOT INSTALLED or UNKNOWN VERSION.\n",
"\n",
"For details about installing the optional dependencies, please visit:\n",
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
"\n"
]
}
],
"source": [
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from monai.config import print_config\n",
"from monai.losses import DiceLoss\n",
"from monai.inferers import sliding_window_inference\n",
"from monai.transforms import MapTransform\n",
"from monai.data import DataLoader, Dataset\n",
"from monai.utils import set_determinism\n",
"from monai import transforms\n",
"import torch\n",
"\n",
"print_config()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"set_determinism(seed=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Số lượng mẫu trong '/app/brats_2021_task1/BraTS2021_Training_Data' là: 1251\n"
]
}
],
"source": [
"import os\n",
"\n",
"parent_folder_path = '/app/brats_2021_task1/BraTS2021_Training_Data'\n",
"subfolders = [f for f in os.listdir(parent_folder_path) if os.path.isdir(os.path.join(parent_folder_path, f))]\n",
"num_folders = len(subfolders)\n",
"print(f\"Số lượng mẫu trong '{parent_folder_path}' là: {num_folders}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"\n",
"folder_data = []\n",
"\n",
"for fold_number in os.listdir(parent_folder_path):\n",
" fold_path = os.path.join(parent_folder_path, fold_number)\n",
"\n",
" if os.path.isdir(fold_path):\n",
" entry = {\"fold\": 0, \"image\": [], \"label\": \"\"}\n",
"\n",
" for file_type in ['flair', 't1ce', 't1', 't2']:\n",
" file_name = f\"{fold_number}_{file_type}.nii.gz\"\n",
" file_path = os.path.join(fold_path, file_name)\n",
"\n",
" if os.path.exists(file_path):\n",
"\n",
" entry[\"image\"].append(os.path.abspath(file_path))\n",
"\n",
" label_name = f\"{fold_number}_seg.nii.gz\"\n",
" label_path = os.path.join(fold_path, label_name)\n",
" if os.path.exists(label_path):\n",
" entry[\"label\"] = os.path.abspath(label_path)\n",
"\n",
" folder_data.append(entry)\n",
"\n",
"\n",
"json_data = {\"training\": folder_data}\n",
"\n",
"json_file_path = '/app/info.json'\n",
"with open(json_file_path, 'w') as json_file:\n",
" json.dump(json_data, json_file, indent=2)\n",
"\n",
"print(f\"Thông tin đã được ghi vào {json_file_path}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n",
" \"\"\"\n",
" Convert labels to multi channels based on brats classes:\n",
" label 1 is the necrotic and non-enhancing tumor core\n",
" label 2 is the peritumoral edema\n",
" label 4 is the GD-enhancing tumor\n",
" The possible classes are TC (Tumor core), WT (Whole tumor)\n",
" and ET (Enhancing tumor).\n",
"\n",
" \"\"\"\n",
"\n",
" def __call__(self, data):\n",
" d = dict(data)\n",
" for key in self.keys:\n",
" result = []\n",
" # merge label 1 and label 4 to construct TC\n",
" result.append(np.logical_or(d[key] == 1, d[key] == 4))\n",
" # merge labels 1, 2 and 4 to construct WT\n",
" result.append(\n",
" np.logical_or(\n",
" np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2\n",
" )\n",
" )\n",
" # label 4 is ET\n",
" result.append(d[key] == 4)\n",
" d[key] = np.stack(result, axis=0).astype(np.float32)\n",
" return d"
]
},
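{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the transform above on a synthetic label map (hypothetical data, not BraTS): labels 1 and 4 merge into TC, labels 1, 2 and 4 into WT, and label 4 alone is ET, stacked as three binary channels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic 8x8x8 label volume using the BraTS label values 0, 1, 2, 4\n",
"dummy = {\"label\": np.random.choice([0, 1, 2, 4], size=(8, 8, 8))}\n",
"converted = ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\")(dummy)\n",
"print(converted[\"label\"].shape)  # (3, 8, 8, 8): TC, WT, ET channels"
]
},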
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def datafold_read(datalist, basedir, fold=0, key=\"training\"):\n",
" with open(datalist) as f:\n",
" json_data = json.load(f)\n",
"\n",
" json_data = json_data[key]\n",
"\n",
" for d in json_data:\n",
" for k in d:\n",
" if isinstance(d[k], list):\n",
" d[k] = [os.path.join(basedir, iv) for iv in d[k]]\n",
" elif isinstance(d[k], str):\n",
" d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]\n",
"\n",
" tr = []\n",
" val = []\n",
" for d in json_data:\n",
" if \"fold\" in d and d[\"fold\"] == fold:\n",
" val.append(d)\n",
" else:\n",
" tr.append(d)\n",
"\n",
" return tr, val"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :\n",
" train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)\n",
" from sklearn.model_selection import train_test_split\n",
" if volume != None :\n",
" train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)\n",
" \n",
" train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)\n",
" \n",
" validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)\n",
" return train_files, validation_files, test_files"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):\n",
" train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)\n",
" \n",
" train_transform = transforms.Compose(\n",
" [\n",
" transforms.LoadImaged(keys=[\"image\", \"label\"]),\n",
" transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
" transforms.CropForegroundd(\n",
" keys=[\"image\", \"label\"],\n",
" source_key=\"image\",\n",
" k_divisible=[roi[0], roi[1], roi[2]],\n",
" ),\n",
" transforms.RandSpatialCropd(\n",
" keys=[\"image\", \"label\"],\n",
" roi_size=[roi[0], roi[1], roi[2]],\n",
" random_size=False,\n",
" ),\n",
" transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n",
" transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n",
" transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n",
" transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
" transforms.RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n",
" transforms.RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n",
" ]\n",
" )\n",
" val_transform = transforms.Compose(\n",
" [\n",
" transforms.LoadImaged(keys=[\"image\", \"label\"]),\n",
" transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n",
" transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n",
" ]\n",
" )\n",
"\n",
" train_ds = Dataset(data=train_files, transform=train_transform)\n",
" train_loader = DataLoader(\n",
" train_ds,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=2,\n",
" pin_memory=True,\n",
" )\n",
" val_ds = Dataset(data=validation_files, transform=val_transform)\n",
" val_loader = DataLoader(\n",
" val_ds,\n",
" batch_size=1,\n",
" shuffle=False,\n",
" num_workers=2,\n",
" pin_memory=True,\n",
" )\n",
" test_ds = Dataset(data=test_files, transform=val_transform)\n",
" test_loader = DataLoader(\n",
" test_ds,\n",
" batch_size=1,\n",
" shuffle=False,\n",
" num_workers=2,\n",
" pin_memory=True,\n",
" )\n",
" return train_loader, val_loader,test_loader"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.9/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.\n",
" warn_deprecated(argname, msg, warning_category)\n"
]
}
],
"source": [
"import json\n",
"data_dir = \"/app/brats_2021_task1\"\n",
"json_list = \"/app/info.json\"\n",
"roi = (128, 128, 128)\n",
"batch_size = 1\n",
"sw_batch_size = 2\n",
"fold = 1\n",
"infer_overlap = 0.5\n",
"max_epochs = 100\n",
"val_every = 10\n",
"train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=0.5, test_size=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"100"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(val_loader)"
]
},
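{
"cell_type": "markdown",
"metadata": {},
"source": [
"Where the 100 above comes from: `fold = 1` while every entry in `info.json` carries `\"fold\": 0`, so `datafold_read` returns all 1251 cases as training files; `volume=0.5` keeps 625 of them, the first `test_size=0.2` split leaves 500 train / 125 held out, and splitting those 125 again with 0.2 yields 100 validation and 25 test cases. With `batch_size=1`, `len(val_loader)` is therefore 100."
]
},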
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Model design, base on SegResNet, VAE and TransBTS"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"#Re-use from encoder block\n",
"def normalization(planes, norm = 'instance'):\n",
" if norm == 'bn':\n",
" m = nn.BatchNorm3d(planes)\n",
" elif norm == 'gn':\n",
" m = nn.GroupNorm(8, planes)\n",
" elif norm == 'instance':\n",
" m = nn.InstanceNorm3d(planes)\n",
" else:\n",
" raise ValueError(\"Does not support this kind of norm.\")\n",
" return m\n",
"class ResNetBlock(nn.Module):\n",
" def __init__(self, in_channels, norm = 'instance'):\n",
" super().__init__()\n",
" self.resnetblock = nn.Sequential(\n",
" normalization(in_channels, norm = norm),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),\n",
" normalization(in_channels, norm = norm),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" y = self.resnetblock(x)\n",
" return y + x"
]
},
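{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal shape check (illustrative sizes only): the pre-activation residual block keeps both the channel count and the spatial size, so the skip connection `y + x` is well defined."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"block = ResNetBlock(in_channels=16)\n",
"print(block(torch.randn(1, 16, 8, 8, 8)).shape)  # torch.Size([1, 16, 8, 8, 8])"
]
},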
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"from torch.nn import functional as F\n",
"\n",
"def calculate_total_dimension(a):\n",
" res = 1\n",
" for x in a:\n",
" res *= x\n",
" return res\n",
"\n",
"class VAE(nn.Module):\n",
" def __init__(self, input_shape, latent_dim, num_channels):\n",
" super().__init__()\n",
" self.input_shape = input_shape\n",
" self.in_channels = input_shape[1] #input_shape[0] is batch size\n",
" self.latent_dim = latent_dim\n",
" self.encoder_channels = self.in_channels // 16\n",
"\n",
" #Encoder\n",
" self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,\n",
" kernel_size = 3, stride = 2, padding=1)\n",
" # self.VAE_reshape = nn.Sequential(\n",
" # nn.GroupNorm(8, self.in_channels),\n",
" # nn.ReLU(),\n",
" # nn.Conv3d(self.in_channels, self.encoder_channels,\n",
" # kernel_size = 3, stride = 2, padding=1),\n",
" # )\n",
"\n",
" flatten_input_shape = calculate_total_dimension(input_shape)\n",
" flatten_input_shape_after_vae_reshape = \\\n",
" flatten_input_shape * self.encoder_channels // (8 * self.in_channels)\n",
"\n",
" #Convert from total dimension to latent space\n",
" self.to_latent_space = nn.Linear(\n",
" flatten_input_shape_after_vae_reshape // self.in_channels, 1)\n",
"\n",
" self.mean = nn.Linear(self.in_channels, self.latent_dim)\n",
" self.logvar = nn.Linear(self.in_channels, self.latent_dim)\n",
"# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))\n",
"\n",
" #Decoder\n",
" self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)\n",
" self.Reconstruct = nn.Sequential(\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(\n",
" self.encoder_channels, self.in_channels,\n",
" stride = 1, kernel_size = 1),\n",
" nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
"\n",
" nn.Conv3d(\n",
" self.in_channels, self.in_channels // 2,\n",
" stride = 1, kernel_size = 1),\n",
" nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
" ResNetBlock(self.in_channels // 2),\n",
"\n",
" nn.Conv3d(\n",
" self.in_channels // 2, self.in_channels // 4,\n",
" stride = 1, kernel_size = 1),\n",
" nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
" ResNetBlock(self.in_channels // 4),\n",
"\n",
" nn.Conv3d(\n",
" self.in_channels // 4, self.in_channels // 8,\n",
" stride = 1, kernel_size = 1),\n",
" nn.Upsample(scale_factor=2, mode = 'nearest'),\n",
" ResNetBlock(self.in_channels // 8),\n",
"\n",
" nn.InstanceNorm3d(self.in_channels // 8),\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(\n",
" self.in_channels // 8, num_channels,\n",
" kernel_size = 3, padding = 1),\n",
"# nn.Sigmoid()\n",
" )\n",
"\n",
"\n",
" def forward(self, x): #x has shape = input_shape\n",
" #Encoder\n",
" # print(x.shape)\n",
" x = self.VAE_reshape(x)\n",
" shape = x.shape\n",
"\n",
" x = x.view(self.in_channels, -1)\n",
" x = self.to_latent_space(x)\n",
" x = x.view(1, self.in_channels)\n",
"\n",
" mean = self.mean(x)\n",
" logvar = self.logvar(x)\n",
"# sigma = torch.exp(0.5 * logvar)\n",
" # Reparameter\n",
" epsilon = torch.randn_like(logvar)\n",
" sample = mean + epsilon * torch.exp(0.5*logvar)\n",
"\n",
" #Decoder\n",
" y = self.to_original_dimension(sample)\n",
" y = y.view(*shape)\n",
" return self.Reconstruct(y), mean, logvar\n",
" def total_params(self):\n",
" total = sum(p.numel() for p in self.parameters())\n",
" return format(total, ',')\n",
"\n",
" def total_trainable_params(self):\n",
" total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
" return format(total_trainable, ',')\n",
"\n",
"\n",
"# x = torch.rand((1, 256, 16, 16, 16))\n",
"# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)\n",
"# y = vae(x)\n",
"# print(y[0].shape, y[1].shape, y[2].shape)\n",
"# print(vae.total_trainable_params())\n"
]
},
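{
"cell_type": "markdown",
"metadata": {},
"source": [
"A runnable version of the commented shape check above (random input; batch size 1, which the reshapes in `VAE.forward` assume): the reconstruction branch upsamples back to a 4-channel `128^3` volume, and `mu`/`logvar` come out as `(1, 256)` vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.rand((1, 256, 16, 16, 16))\n",
"vae = VAE(input_shape=x.shape, latent_dim=256, num_channels=4)\n",
"recon, mu, logvar = vae(x)\n",
"print(recon.shape, mu.shape, logvar.shape)\n",
"# torch.Size([1, 4, 128, 128, 128]) torch.Size([1, 256]) torch.Size([1, 256])"
]
},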
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"from einops import rearrange\n",
"from einops.layers.torch import Rearrange\n",
"\n",
"def pair(t):\n",
" return t if isinstance(t, tuple) else (t, t)\n",
"\n",
"\n",
"class PreNorm(nn.Module):\n",
" def __init__(self, dim, function):\n",
" super().__init__()\n",
" self.norm = nn.LayerNorm(dim)\n",
" self.function = function\n",
"\n",
" def forward(self, x):\n",
" return self.function(self.norm(x))\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, dim, hidden_dim, dropout = 0.0):\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Linear(dim, hidden_dim),\n",
" nn.GELU(),\n",
" nn.Dropout(dropout),\n",
" nn.Linear(hidden_dim, dim),\n",
" nn.Dropout(dropout)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"class Attention(nn.Module):\n",
" def __init__(self, dim, heads, dim_head, dropout = 0.0):\n",
" super().__init__()\n",
" all_head_size = heads * dim_head\n",
" project_out = not (heads == 1 and dim_head == dim)\n",
"\n",
" self.heads = heads\n",
" self.scale = dim_head ** -0.5\n",
"\n",
" self.softmax = nn.Softmax(dim = -1)\n",
" self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)\n",
"\n",
" self.to_out = nn.Sequential(\n",
" nn.Linear(all_head_size, dim),\n",
" nn.Dropout(dropout)\n",
" ) if project_out else nn.Identity()\n",
"\n",
" def forward(self, x):\n",
" qkv = self.to_qkv(x).chunk(3, dim = -1)\n",
" #(batch, heads * dim_head) -> (batch, all_head_size)\n",
" q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)\n",
"\n",
" dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale\n",
"\n",
" atten = self.softmax(dots)\n",
"\n",
" out = torch.matmul(atten, v)\n",
" out = rearrange(out, 'b h n d -> b n (h d)')\n",
" return self.to_out(out)\n",
"\n",
"class Transformer(nn.Module):\n",
" def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):\n",
" super().__init__()\n",
" self.layers = nn.ModuleList([])\n",
" for _ in range(depth):\n",
" self.layers.append(nn.ModuleList([\n",
" PreNorm(dim, Attention(dim, heads, dim_head, dropout)),\n",
" PreNorm(dim, FeedForward(dim, mlp_dim, dropout))\n",
" ]))\n",
" def forward(self, x):\n",
" for attention, feedforward in self.layers:\n",
" x = attention(x) + x\n",
" x = feedforward(x) + x\n",
" return x\n",
"\n",
"class FixedPositionalEncoding(nn.Module):\n",
" def __init__(self, embedding_dim, max_length=768):\n",
" super(FixedPositionalEncoding, self).__init__()\n",
"\n",
" pe = torch.zeros(max_length, embedding_dim)\n",
" position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)\n",
" div_term = torch.exp(\n",
" torch.arange(0, embedding_dim, 2).float()\n",
" * (-torch.log(torch.tensor(10000.0)) / embedding_dim)\n",
" )\n",
" pe[:, 0::2] = torch.sin(position * div_term)\n",
" pe[:, 1::2] = torch.cos(position * div_term)\n",
" pe = pe.unsqueeze(0).transpose(0, 1)\n",
" self.register_buffer('pe', pe)\n",
"\n",
" def forward(self, x):\n",
" x = x + self.pe[: x.size(0), :]\n",
" return x\n",
"\n",
"\n",
"class LearnedPositionalEncoding(nn.Module):\n",
" def __init__(self, embedding_dim, seq_length):\n",
" super(LearnedPositionalEncoding, self).__init__()\n",
" self.seq_length = seq_length\n",
" self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x\n",
"\n",
" def forward(self, x, position_ids=None):\n",
" position_embeddings = self.position_embeddings\n",
"# print(x.shape, self.position_embeddings.shape)\n",
" return x + position_embeddings"
]
},
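{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small smoke test of the transformer stack (hypothetical sizes): pre-norm attention and feed-forward layers with residual connections preserve the `(batch, sequence, dim)` shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t = Transformer(dim=64, depth=2, heads=4, dim_head=16, mlp_dim=128)\n",
"print(t(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])"
]
},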
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"### Encoder ####\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class InitConv(nn.Module):\n",
" def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):\n",
" super().__init__()\n",
" self.layer = nn.Sequential(\n",
" nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),\n",
" nn.Dropout3d(dropout)\n",
" )\n",
" def forward(self, x):\n",
" y = self.layer(x)\n",
" return y\n",
"\n",
"\n",
"class DownSample(nn.Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super().__init__()\n",
" self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)\n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"class Encoder(nn.Module):\n",
" def __init__(self, in_channels, base_channels, dropout = 0.2):\n",
" super().__init__()\n",
"\n",
" self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)\n",
" self.encoder_block1 = ResNetBlock(in_channels = base_channels)\n",
" self.encoder_down1 = DownSample(base_channels, base_channels * 2)\n",
"\n",
" self.encoder_block2_1 = ResNetBlock(base_channels * 2)\n",
" self.encoder_block2_2 = ResNetBlock(base_channels * 2)\n",
" self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)\n",
"\n",
" self.encoder_block3_1 = ResNetBlock(base_channels * 4)\n",
" self.encoder_block3_2 = ResNetBlock(base_channels * 4)\n",
" self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)\n",
"\n",
" self.encoder_block4_1 = ResNetBlock(base_channels * 8)\n",
" self.encoder_block4_2 = ResNetBlock(base_channels * 8)\n",
" self.encoder_block4_3 = ResNetBlock(base_channels * 8)\n",
" self.encoder_block4_4 = ResNetBlock(base_channels * 8)\n",
" # self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)\n",
" def forward(self, x):\n",
" x = self.init_conv(x) #(1, 16, 128, 128, 128)\n",
"\n",
" x1 = self.encoder_block1(x)\n",
" x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)\n",
"\n",
" x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))\n",
" x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)\n",
"\n",
" x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))\n",
" x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)\n",
"\n",
" output = self.encoder_block4_4(\n",
" self.encoder_block4_3(\n",
" self.encoder_block4_2(\n",
" self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)\n",
" return x1, x2, x3, output\n",
"\n",
"# x = torch.rand((1, 4, 128, 128, 128))\n",
"# Enc = Encoder(4, 32)\n",
"# _, _, _, y = Enc(x)\n",
"# print(y.shape) (1,256,16,16)"
]
},
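{
"cell_type": "markdown",
"metadata": {},
"source": [
"A runnable version of the commented check above, at a reduced `64^3` input to keep memory modest (sizes are illustrative): each `DownSample` halves the spatial size and doubles the channels, so with `base_channels=16` the bottleneck is `(1, 128, 8, 8, 8)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc = Encoder(in_channels=4, base_channels=16)\n",
"x1, x2, x3, bottom = enc(torch.randn(1, 4, 64, 64, 64))\n",
"print(x1.shape, x2.shape, x3.shape, bottom.shape)\n",
"# torch.Size([1, 16, 64, 64, 64]) torch.Size([1, 32, 32, 32, 32]) torch.Size([1, 64, 16, 16, 16]) torch.Size([1, 128, 8, 8, 8])"
]
},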
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"### Decoder ####\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class Upsample(nn.Module):\n",
" def __init__(self, in_channel, out_channel):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)\n",
" self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)\n",
" self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)\n",
"\n",
" def forward(self, prev, x):\n",
" x = self.deconv(self.conv1(x))\n",
" y = torch.cat((prev, x), dim = 1)\n",
" return self.conv2(y)\n",
"\n",
"class FinalConv(nn.Module): # Input channels are equal to output channels\n",
" def __init__(self, in_channels, out_channels=32, norm=\"instance\"):\n",
" super(FinalConv, self).__init__()\n",
" if norm == \"batch\":\n",
" norm_layer = nn.BatchNorm3d(num_features=in_channels)\n",
" elif norm == \"group\":\n",
" norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)\n",
" elif norm == 'instance':\n",
" norm_layer = nn.InstanceNorm3d(in_channels)\n",
"\n",
" self.layer = nn.Sequential(\n",
" norm_layer,\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n",
" )\n",
" def forward(self, x):\n",
" return self.layer(x)\n",
"\n",
"class Decoder(nn.Module):\n",
" def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):\n",
" super().__init__()\n",
" self.img_dim = img_dim\n",
" self.patch_dim = patch_dim\n",
" self.embedding_dim = embedding_dim\n",
"\n",
" self.decoder_upsample_1 = Upsample(128, 64)\n",
" self.decoder_block_1 = ResNetBlock(64)\n",
"\n",
" self.decoder_upsample_2 = Upsample(64, 32)\n",
" self.decoder_block_2 = ResNetBlock(32)\n",
"\n",
" self.decoder_upsample_3 = Upsample(32, 16)\n",
" self.decoder_block_3 = ResNetBlock(16)\n",
"\n",
" self.endconv = FinalConv(16, num_classes)\n",
" # self.normalize = nn.Sigmoid()\n",
"\n",
" def forward(self, x1, x2, x3, x):\n",
" x = self.decoder_upsample_1(x3, x)\n",
" x = self.decoder_block_1(x)\n",
"\n",
" x = self.decoder_upsample_2(x2, x)\n",
" x = self.decoder_block_2(x)\n",
"\n",
" x = self.decoder_upsample_3(x1, x)\n",
" x = self.decoder_block_3(x)\n",
"\n",
" y = self.endconv(x)\n",
" return y"
]
},
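{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting encoder and decoder together (hypothetical `64^3` input): each `Upsample` step transposed-convolves by a factor of 2 and fuses the matching skip connection, so the segmentation head returns `num_classes` channels at full resolution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc = Encoder(in_channels=4, base_channels=16)\n",
"dec = Decoder(img_dim=(64, 64, 64), patch_dim=8, embedding_dim=768)\n",
"x1, x2, x3, bottom = enc(torch.randn(1, 4, 64, 64, 64))\n",
"print(dec(x1, x2, x3, bottom).shape)  # torch.Size([1, 3, 64, 64, 64])"
]
},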
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class FeatureMapping(nn.Module):\n",
" def __init__(self, in_channel, out_channel, norm = 'instance'):\n",
" super().__init__()\n",
" if norm == 'bn':\n",
" norm_layer_1 = nn.BatchNorm3d(out_channel)\n",
" norm_layer_2 = nn.BatchNorm3d(out_channel)\n",
" elif norm == 'gn':\n",
" norm_layer_1 = nn.GroupNorm(8, out_channel)\n",
" norm_layer_2 = nn.GroupNorm(8, out_channel)\n",
" elif norm == 'instance':\n",
" norm_layer_1 = nn.InstanceNorm3d(out_channel)\n",
" norm_layer_2 = nn.InstanceNorm3d(out_channel)\n",
" self.feature_mapping = nn.Sequential(\n",
" nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),\n",
" norm_layer_1,\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),\n",
" norm_layer_2,\n",
" nn.LeakyReLU(0.2, inplace=True)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.feature_mapping(x)\n",
"\n",
"\n",
"class FeatureMapping1(nn.Module):\n",
" def __init__(self, in_channel, norm = 'instance'):\n",
" super().__init__()\n",
" if norm == 'bn':\n",
" norm_layer_1 = nn.BatchNorm3d(in_channel)\n",
" norm_layer_2 = nn.BatchNorm3d(in_channel)\n",
" elif norm == 'gn':\n",
" norm_layer_1 = nn.GroupNorm(8, in_channel)\n",
" norm_layer_2 = nn.GroupNorm(8, in_channel)\n",
" elif norm == 'instance':\n",
" norm_layer_1 = nn.InstanceNorm3d(in_channel)\n",
" norm_layer_2 = nn.InstanceNorm3d(in_channel)\n",
" self.feature_mapping1 = nn.Sequential(\n",
" nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),\n",
" norm_layer_1,\n",
" nn.LeakyReLU(0.2, inplace=True),\n",
" nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),\n",
" norm_layer_2,\n",
" nn.LeakyReLU(0.2, inplace=True)\n",
" )\n",
" def forward(self, x):\n",
" y = self.feature_mapping1(x)\n",
" return x + y #Resnet Like"
]
},
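{
"cell_type": "markdown",
"metadata": {},
"source": [
"Shape check (illustrative sizes): `FeatureMapping` projects the transformer's embedding channels down to the VAE/decoder width, and `FeatureMapping1` refines the result with a residual block."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fm = FeatureMapping(in_channel=768, out_channel=128)\n",
"fm1 = FeatureMapping1(in_channel=128)\n",
"print(fm1(fm(torch.randn(1, 768, 4, 4, 4))).shape)  # torch.Size([1, 128, 4, 4, 4])"
]
},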
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class SegTransVAE(nn.Module):\n",
" def __init__(self, img_dim, patch_dim, num_channels, num_classes,\n",
" embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,\n",
" dropout = 0.0, attention_dropout = 0.0,\n",
" conv_patch_representation = True, positional_encoding = 'learned',\n",
" use_VAE = False):\n",
"\n",
" super().__init__()\n",
" assert embedding_dim % num_heads == 0\n",
" assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0\n",
"\n",
" self.img_dim = img_dim\n",
" self.embedding_dim = embedding_dim\n",
" self.num_heads = num_heads\n",
" self.num_classes = num_classes\n",
" self.patch_dim = patch_dim\n",
" self.num_channels = num_channels\n",
" self.in_channels_vae = in_channels_vae\n",
" self.dropout = dropout\n",
" self.attention_dropout = attention_dropout\n",
" self.conv_patch_representation = conv_patch_representation\n",
" self.use_VAE = use_VAE\n",
"\n",
" self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))\n",
" self.seq_length = self.num_patches\n",
" self.flatten_dim = 128 * num_channels\n",
"\n",
" self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)\n",
" if positional_encoding == \"learned\":\n",
" self.position_encoding = LearnedPositionalEncoding(\n",
" self.embedding_dim, self.seq_length\n",
" )\n",
" elif positional_encoding == \"fixed\":\n",
" self.position_encoding = FixedPositionalEncoding(\n",
" self.embedding_dim,\n",
" )\n",
" self.pe_dropout = nn.Dropout(self.dropout)\n",
"\n",
" self.transformer = Transformer(\n",
" embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout\n",
" )\n",
" self.pre_head_ln = nn.LayerNorm(embedding_dim)\n",
"\n",
" if self.conv_patch_representation:\n",
" self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)\n",
" self.encoder = Encoder(self.num_channels, 16)\n",
" self.bn = nn.InstanceNorm3d(128)\n",
" self.relu = nn.LeakyReLU(0.2, inplace=True)\n",
" self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)\n",
" self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)\n",
" self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)\n",
"\n",
" self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)\n",
" if use_VAE:\n",
" self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)\n",
" def encode(self, x):\n",
" if self.conv_patch_representation:\n",
" x1, x2, x3, x = self.encoder(x)\n",
" x = self.bn(x)\n",
" x = self.relu(x)\n",
" x = self.conv_x(x)\n",
" x = x.permute(0, 2, 3, 4, 1).contiguous()\n",
" x = x.view(x.size(0), -1, self.embedding_dim)\n",
" x = self.position_encoding(x)\n",
" x = self.pe_dropout(x)\n",
" x = self.transformer(x)\n",
" x = self.pre_head_ln(x)\n",
"\n",
" return x1, x2, x3, x\n",
"\n",
" def decode(self, x1, x2, x3, x):\n",
" #x: (1, 4096, 512) -> (1, 16, 16, 16, 512)\n",
"# print(\"In decode...\")\n",
"# print(\" x1: {} \\n x2: {} \\n x3: {} \\n x: {}\".format( x1.shape, x2.shape, x3.shape, x.shape))\n",
"# break\n",
" return self.decoder(x1, x2, x3, x)\n",
"\n",
" def forward(self, x, is_validation = True):\n",
" x1, x2, x3, x = self.encode(x)\n",
" x = x.view( x.size(0),\n",
" self.img_dim[0] // self.patch_dim,\n",
" self.img_dim[1] // self.patch_dim,\n",
" self.img_dim[2] // self.patch_dim,\n",
" self.embedding_dim)\n",
" x = x.permute(0, 4, 1, 2, 3).contiguous()\n",
" x = self.FeatureMapping(x)\n",
" x = self.FeatureMapping1(x)\n",
" if self.use_VAE and not is_validation:\n",
" vae_out, mu, sigma = self.vae(x)\n",
" y = self.decode(x1, x2, x3, x)\n",
" if self.use_VAE and not is_validation:\n",
" return y, vae_out, mu, sigma\n",
" else:\n",
" return y\n",
"\n",
"\n"
]
},
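{
"cell_type": "markdown",
"metadata": {},
"source": [
"A CPU-friendly smoke test of the full model at a reduced `64^3` input with a single transformer layer (these sizes are illustrative, not the training configuration): with `use_VAE=True` and `is_validation=False`, the forward pass returns the segmentation, the VAE reconstruction, and the latent statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = SegTransVAE(img_dim=(64, 64, 64), patch_dim=8, num_channels=4, num_classes=3,\n",
"                  embedding_dim=768, num_heads=8, num_layers=1, hidden_dim=3072,\n",
"                  in_channels_vae=128, use_VAE=True)\n",
"seg, recon, mu, logvar = net(torch.randn(1, 4, 64, 64, 64), is_validation=False)\n",
"print(seg.shape, recon.shape, mu.shape, logvar.shape)\n",
"# torch.Size([1, 3, 64, 64, 64]) torch.Size([1, 4, 64, 64, 64]) torch.Size([1, 256]) torch.Size([1, 256])"
]
},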
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA (GPU) is available. Using GPU.\n"
]
}
],
"source": [
"import torch\n",
"\n",
"# Check if CUDA (GPU support) is available\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" print(\"CUDA (GPU) is available. Using GPU.\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" print(\"CUDA (GPU) is not available. Using CPU.\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"model = SegTransVAE(img_dim = (128, 128, 128),patch_dim= 8,num_channels =4,num_classes= 3,embedding_dim= 768,num_heads= 8,num_layers= 4, hidden_dim= 3072,in_channels_vae=128 , use_VAE = True)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tổng số tham số của mô hình là: 44727120\n",
"Tổng số tham số cần tính gradient của mô hình là: 44727120\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f'Tổng số tham số của mô hình là: {total_params}')\n",
"\n",
"total_params_requires_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f'Tổng số tham số cần tính gradient của mô hình là: {total_params_requires_grad}')\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class Loss_VAE(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.mse = nn.MSELoss(reduction='sum')\n",
"\n",
" def forward(self, recon_x, x, mu, log_var):\n",
" mse = self.mse(recon_x, x)\n",
" kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
" loss = mse + kld\n",
" return loss"
]
},
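{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny worked example (random tensors, illustrative shapes): with `mu = 0` and `log_var = 0` the KL term `-0.5 * sum(1 + log_var - mu^2 - exp(log_var))` vanishes, so the loss reduces to the summed MSE."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"crit = Loss_VAE()\n",
"recon = torch.randn(1, 4, 8, 8, 8)\n",
"target = torch.randn(1, 4, 8, 8, 8)\n",
"mu, log_var = torch.zeros(1, 256), torch.zeros(1, 256)\n",
"print(crit(recon, target, mu, log_var))  # equals nn.MSELoss(reduction='sum')(recon, target)"
]
},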
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def DiceScore(\n",
" y_pred: torch.Tensor,\n",
" y: torch.Tensor,\n",
" include_background: bool = True,\n",
") -> torch.Tensor:\n",
" \"\"\"Computes Dice score metric from full size Tensor and collects average.\n",
" Args:\n",
" y_pred: input data to compute, typical segmentation model output.\n",
" It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n",
" should be binarized.\n",
" y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.\n",
" The values should be binarized.\n",
" include_background: whether to skip Dice computation on the first channel of\n",
" the predicted output. Defaults to True.\n",
" Returns:\n",
" Dice scores per batch and per class, (shape [batch_size, num_classes]).\n",
" Raises:\n",
" ValueError: when `y_pred` and `y` have different shapes.\n",
" \"\"\"\n",
"\n",
" y = y.float()\n",
" y_pred = y_pred.float()\n",
"\n",
" if y.shape != y_pred.shape:\n",
" raise ValueError(\"y_pred and y should have same shapes.\")\n",
"\n",
" # reducing only spatial dimensions (not batch nor channels)\n",
" n_len = len(y_pred.shape)\n",
" reduce_axis = list(range(2, n_len))\n",
" intersection = torch.sum(y * y_pred, dim=reduce_axis)\n",
"\n",
" y_o = torch.sum(y, reduce_axis)\n",
" y_pred_o = torch.sum(y_pred, dim=reduce_axis)\n",
" denominator = y_o + y_pred_o\n",
"\n",
" return torch.where(\n",
" denominator > 0,\n",
" (2.0 * intersection) / denominator,\n",
" torch.tensor(float(\"1\"), device=y_o.device),\n",
" )\n"
]
},
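{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example on tiny binarized masks: `DiceScore` reduces only the spatial axes, returning one score per batch element and channel. Identical masks give 1.0; the pair below has intersection 2 against foreground sums 3 and 2, giving 2*2/(3+2) = 0.8."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = torch.tensor([[[[0., 1.], [1., 1.]]]])  # shape [1, 1, 2, 2]\n",
"b = torch.tensor([[[[0., 1.], [1., 0.]]]])\n",
"print(DiceScore(a, a))  # tensor([[1.]])\n",
"print(DiceScore(a, b))  # tensor([[0.8000]])"
]
},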
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Pytorch Lightning\n",
"import pytorch_lightning as pl\n",
"import matplotlib.pyplot as plt\n",
"import csv\n",
"from monai.transforms import AsDiscrete, Activations, Compose, EnsureType"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"class BRATS(pl.LightningModule):\n",
" def __init__(self, use_VAE = True, lr = 1e-4, ):\n",
" super().__init__()\n",
" \n",
" self.use_vae = use_VAE\n",
" self.lr = lr\n",
" self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)\n",
"\n",
" self.loss_vae = Loss_VAE()\n",
" self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)\n",
" self.post_trans_images = Compose(\n",
" [EnsureType(),\n",
" Activations(sigmoid=True), \n",
" AsDiscrete(threshold_values=True), \n",
" ]\n",
" )\n",
"\n",
" self.best_val_dice = 0\n",
" \n",
" self.training_step_outputs = [] \n",
" self.val_step_loss = [] \n",
" self.val_step_dice = []\n",
" self.val_step_dice_tc = [] \n",
" self.val_step_dice_wt = []\n",
" self.val_step_dice_et = [] \n",
" self.test_step_loss = [] \n",
" self.test_step_dice = []\n",
" self.test_step_dice_tc = [] \n",
" self.test_step_dice_wt = []\n",
" self.test_step_dice_et = [] \n",
"\n",
" def forward(self, x, is_validation = True):\n",
" return self.model(x, is_validation) \n",
" def training_step(self, batch, batch_index):\n",
" inputs, labels = (batch['image'], batch['label'])\n",
" \n",
" if not self.use_vae:\n",
" outputs = self.forward(inputs, is_validation=False)\n",
" loss = self.dice_loss(outputs, labels)\n",
" else:\n",
" outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)\n",
" \n",
" vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)\n",
" dice_loss = self.dice_loss(outputs, labels)\n",
" loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss\n",
" self.training_step_outputs.append(loss)\n",
" self.log('train/vae_loss', vae_loss)\n",
" self.log('train/dice_loss', dice_loss)\n",
" if batch_index == 10:\n",
"\n",
" tensorboard = self.logger.experiment \n",
" fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))\n",
" \n",
"\n",
" ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')\n",
" ax[0].set_title(\"Input\")\n",
"\n",
" ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
" ax[1].set_title(\"Reconstruction\")\n",
" \n",
" ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
" ax[2].set_title(\"Labels TC\")\n",
" \n",
" ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')\n",
" ax[3].set_title(\"TC\")\n",
" \n",
" ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')\n",
" ax[4].set_title(\"Labels ET\")\n",
" \n",
" ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')\n",
" ax[5].set_title(\"ET\")\n",
"\n",
" \n",
" tensorboard.add_figure('train_visualize', fig, self.current_epoch)\n",
"\n",
" self.log('train/loss', loss)\n",
" \n",
" return loss\n",
" \n",
" def on_train_epoch_end(self):\n",
" ## F1 Macro all epoch saving outputs and target per batch\n",
"\n",
" # free up the memory\n",
" # --> HERE STEP 3 <--\n",
" epoch_average = torch.stack(self.training_step_outputs).mean()\n",
" self.log(\"training_epoch_average\", epoch_average)\n",
" self.training_step_outputs.clear() # free memory\n",
"\n",
" def validation_step(self, batch, batch_index):\n",
" inputs, labels = (batch['image'], batch['label'])\n",
" roi_size = (128, 128, 128)\n",
" sw_batch_size = 1\n",
" outputs = sliding_window_inference(\n",
" inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)\n",
" loss = self.dice_loss(outputs, labels)\n",
" \n",
" \n",
" val_outputs = self.post_trans_images(outputs)\n",
" \n",
" \n",
" metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)\n",
" metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)\n",
" metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)\n",
" mean_val_dice = (metric_tc + metric_wt + metric_et)/3\n",
" self.val_step_loss.append(loss) \n",
" self.val_step_dice.append(mean_val_dice)\n",
" self.val_step_dice_tc.append(metric_tc) \n",
" self.val_step_dice_wt.append(metric_wt)\n",
" self.val_step_dice_et.append(metric_et) \n",
" return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,\n",
" 'val_dice_wt': metric_wt, 'val_dice_et': metric_et}\n",
" \n",
" def on_validation_epoch_end(self):\n",
"\n",
" loss = torch.stack(self.val_step_loss).mean()\n",
" mean_val_dice = torch.stack(self.val_step_dice).mean()\n",
" metric_tc = torch.stack(self.val_step_dice_tc).mean()\n",
" metric_wt = torch.stack(self.val_step_dice_wt).mean()\n",
" metric_et = torch.stack(self.val_step_dice_et).mean()\n",
" self.log('val/Loss', loss)\n",
" self.log('val/MeanDiceScore', mean_val_dice)\n",
" self.log('val/DiceTC', metric_tc)\n",
" self.log('val/DiceWT', metric_wt)\n",
" self.log('val/DiceET', metric_et)\n",
" os.makedirs(self.logger.log_dir, exist_ok=True)\n",
" if self.current_epoch == 0:\n",
" with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:\n",
" writer = csv.writer(f)\n",
" writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])\n",
" with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:\n",
" writer = csv.writer(f)\n",
" writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])\n",
"\n",
" if mean_val_dice > self.best_val_dice:\n",
" self.best_val_dice = mean_val_dice\n",
" self.best_val_epoch = self.current_epoch\n",
" print(\n",
" f\"\\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}\"\n",
" f\" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}\"\n",
" f\"\\n Best mean dice: {self.best_val_dice}\"\n",
" f\" at epoch: {self.best_val_epoch}\"\n",
" )\n",
" \n",
" self.val_step_loss.clear() \n",
" self.val_step_dice.clear()\n",
" self.val_step_dice_tc.clear() \n",
" self.val_step_dice_wt.clear()\n",
" self.val_step_dice_et.clear()\n",
" return {'val_MeanDiceScore': mean_val_dice}\n",
" def test_step(self, batch, batch_index):\n",
" inputs, labels = (batch['image'], batch['label'])\n",
" \n",
" roi_size = (128, 128, 128)\n",
" sw_batch_size = 1\n",
" test_outputs = sliding_window_inference(\n",
" inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)\n",
" loss = self.dice_loss(test_outputs, labels)\n",
" test_outputs = self.post_trans_images(test_outputs)\n",
" metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)\n",
" metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)\n",
" metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)\n",
" mean_test_dice = (metric_tc + metric_wt + metric_et)/3\n",
" \n",
" self.test_step_loss.append(loss) \n",
" self.test_step_dice.append(mean_test_dice)\n",
" self.test_step_dice_tc.append(metric_tc) \n",
" self.test_step_dice_wt.append(metric_wt)\n",
" self.test_step_dice_et.append(metric_et) \n",
" \n",
" return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,\n",
" 'test_dice_wt': metric_wt, 'test_dice_et': metric_et}\n",
" \n",
" def test_epoch_end(self):\n",
" loss = torch.stack(self.test_step_loss).mean()\n",
" mean_test_dice = torch.stack(self.test_step_dice).mean()\n",
" metric_tc = torch.stack(self.test_step_dice_tc).mean()\n",
" metric_wt = torch.stack(self.test_step_dice_wt).mean()\n",
" metric_et = torch.stack(self.test_step_dice_et).mean()\n",
" self.log('test/Loss', loss)\n",
" self.log('test/MeanDiceScore', mean_test_dice)\n",
" self.log('test/DiceTC', metric_tc)\n",
" self.log('test/DiceWT', metric_wt)\n",
" self.log('test/DiceET', metric_et)\n",
"\n",
" with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:\n",
" writer = csv.writer(f)\n",
" writer.writerow([\"Mean Test Dice\", \"Dice TC\", \"Dice WT\", \"Dice ET\"])\n",
" writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])\n",
"\n",
" self.test_step_loss.clear() \n",
" self.test_step_dice.clear()\n",
" self.test_step_dice_tc.clear() \n",
" self.test_step_dice_wt.clear()\n",
" self.test_step_dice_et.clear()\n",
" return {'test_MeanDiceScore': mean_test_dice}\n",
" \n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(\n",
" self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True\n",
" )\n",
"# optimizer = AdaBelief(self.model.parameters(), \n",
"# lr=self.lr, eps=1e-16, \n",
"# betas=(0.9,0.999), weight_decouple = True, \n",
"# rectify = False)\n",
" scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)\n",
" return [optimizer], [scheduler]\n",
" \n",
" def train_dataloader(self):\n",
" return train_loader\n",
" def val_dataloader(self):\n",
" return val_loader\n",
" \n",
" def test_dataloader(self):\n",
" return test_loader"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
"import os \n",
"from pytorch_lightning.loggers import TensorBoardLogger"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sh: 1: cls: not found\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[H\u001b[2JTraining ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.9/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n",
"Using 16bit Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"------------------------------------------\n",
"0 | model | SegTransVAE | 44.7 M\n",
"1 | loss_vae | Loss_VAE | 0 \n",
"2 | dice_loss | DiceLoss | 0 \n",
"------------------------------------------\n",
"44.7 M Trainable params\n",
"0 Non-trainable params\n",
"44.7 M Total params\n",
"178.908 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:05<00:00, 0.37it/s]\n",
" Current epoch: 0 Current mean dice: 0.0097 tc: 0.0029 wt: 0.0234 et: 0.0028\n",
" Best mean dice: 0.009687595069408417 at epoch: 0\n",
"Epoch 0: 100%|██████████| 500/500 [05:38<00:00, 1.48it/s, v_num=6] \n",
" Current epoch: 0 Current mean dice: 0.1927 tc: 0.1647 wt: 0.2843 et: 0.1290\n",
" Best mean dice: 0.1926589012145996 at epoch: 0\n",
"Epoch 1: 100%|██████████| 500/500 [07:35<00:00, 1.10it/s, v_num=6]\n",
" Current epoch: 1 Current mean dice: 0.3212 tc: 0.2691 wt: 0.4253 et: 0.2692\n",
" Best mean dice: 0.32120221853256226 at epoch: 1\n",
"Epoch 2: 100%|██████████| 500/500 [08:11<00:00, 1.02it/s, v_num=6]\n",
" Current epoch: 2 Current mean dice: 0.3912 tc: 0.3510 wt: 0.5087 et: 0.3137\n",
" Best mean dice: 0.39115065336227417 at epoch: 2\n",
"Epoch 3: 100%|██████████| 500/500 [08:58<00:00, 0.93it/s, v_num=6]\n",
" Current epoch: 3 Current mean dice: 0.4268 tc: 0.3828 wt: 0.5424 et: 0.3553\n",
" Best mean dice: 0.42682838439941406 at epoch: 3\n",
"Epoch 4: 41%|████▏ | 207/500 [02:51<04:03, 1.21it/s, v_num=6]"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick here for more info. \n",
"\u001b[1;31mView Jupyter log for further details."
]
}
],
"source": [
"os.system('cls||clear')\n",
"print(\"Training ...\")\n",
"model = BRATS(use_VAE = True)\n",
"checkpoint_callback = ModelCheckpoint(\n",
" monitor='val/MeanDiceScore',\n",
" dirpath='./app/checkpoints/{}'.format(1),\n",
" filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',\n",
" save_top_k=3,\n",
" mode='max',\n",
" save_last= True,\n",
" auto_insert_metric_name=False\n",
")\n",
"early_stop_callback = EarlyStopping(\n",
" monitor='val/MeanDiceScore',\n",
" min_delta=0.0001,\n",
" patience=15,\n",
" verbose=False,\n",
" mode='max'\n",
")\n",
"tensorboardlogger = TensorBoardLogger(\n",
" 'logs', \n",
" name = \"1\", \n",
" default_hp_metric = None \n",
")\n",
"trainer = pl.Trainer(#fast_dev_run = 10, \n",
"# accelerator='ddp',\n",
" #overfit_batches=5,\n",
" devices = [0], \n",
" precision=16,\n",
" max_epochs = 200, \n",
" enable_progress_bar=True, \n",
" callbacks=[checkpoint_callback, early_stop_callback], \n",
"# auto_lr_find=True,\n",
" num_sanity_val_steps=2,\n",
" logger = tensorboardlogger,\n",
"# limit_train_batches=0.01, \n",
"# limit_val_batches=0.01\n",
" )\n",
"# trainer.tune(model)\n",
"trainer.fit(model)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"from trainer import BRATS\n",
"import os \n",
"import torch\n",
"os.system('cls||clear')\n",
"print(\"Testing ...\")\n",
"\n",
"CKPT = ''\n",
"model = BRATS(use_VAE=True).load_from_checkpoint(CKPT).eval()\n",
"val_dataloader = get_val_dataloader()\n",
"test_dataloader = get_test_dataloader()\n",
"trainer = pl.Trainer(gpus = [0], precision=32, progress_bar_refresh_rate=10)\n",
"\n",
"trainer.test(model, dataloaders = val_dataloader)\n",
"trainer.test(model, dataloaders = test_dataloader)\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}