|
import os |
|
from glob import glob |
|
import torch |
|
from PIL import Image |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
import monai |
|
from monai.data import ArrayDataset, decollate_batch, DataLoader |
|
from monai.inferers import sliding_window_inference |
|
from monai.metrics import DiceMetric |
|
from monai.transforms import ( |
|
Activations, |
|
AsDiscrete, |
|
Compose, |
|
LoadImage, |
|
RandRotate90, |
|
ScaleIntensity, |
|
|
|
|
|
) |
|
from monai.visualize import plot_2d_or_3d_image |
|
from PIL import Image |
|
import cv2 |
|
import tifffile |
|
import os |
|
|
|
|
|
def convert_to_png(img_dir):
    """Write a PNG copy of every JPEG/TIFF image found directly in ``img_dir``.

    Files are selected by extension, compared case-insensitively:
    ``.jpg``, ``.jpeg``, ``.tif`` and ``.tiff``. TIFF files are read with
    ``tifffile`` and wrapped with ``Image.fromarray``; everything else is
    opened with PIL. Each result is saved next to its source using the same
    file stem and a ``.png`` extension, overwriting any file of that name.

    Files that are already ``.png`` are left untouched: re-encoding them onto
    themselves produces no new file and only costs a decode/encode round trip.

    Args:
        img_dir (str): Directory to scan (not recursive).

    Returns:
        None. The PNG files are written as a side effect.
    """
    tiff_exts = ('.tif', '.tiff')
    pil_exts = ('.jpg', '.jpeg')

    # Sorted so that the result is deterministic when two sources share a stem.
    for img_file in sorted(os.listdir(img_dir)):
        stem, ext = os.path.splitext(img_file)
        ext = ext.lower()
        if ext not in tiff_exts and ext not in pil_exts:
            continue

        img_path = os.path.join(img_dir, img_file)
        png_path = os.path.join(img_dir, stem + '.png')

        if ext in tiff_exts:
            # tifffile handles TIFF variants that PIL may not open directly.
            with tifffile.TiffFile(img_path) as tif:
                img = Image.fromarray(tif.asarray())
            img.save(png_path)
        else:
            # Context manager releases the underlying file handle promptly.
            with Image.open(img_path) as img:
                img.save(png_path)
|
|
|
# Make sure every source image has a PNG version before the PNG-only globs run.
convert_to_png("./tamp500/imgs")

# Initial listing of the raw PNG images and masks, each in sorted order.
# Note: both names are assigned again further down in this script.
images = glob(os.path.join("./tamp500/imgs", "*.png"))
images.sort()

segs = glob(os.path.join("./tamp500/masks", "*.png"))
segs.sort()
|
|
|
def resize_images_and_masks(image_paths, mask_paths, output_dir, target_width, target_height):
    """Resize images and their segmentation masks to a fixed size and save them.

    Images and masks are paired positionally (first image with first mask,
    and so on). Resized images are written to ``<output_dir>/resized_images``
    and resized masks to ``<output_dir>/resized_masks``, keeping each file's
    original base name. Masks are read as grayscale and resized with
    nearest-neighbour interpolation so that no new label values are
    introduced by interpolation.

    Args:
        image_paths (list): Paths to the input images.
        mask_paths (list): Paths to the segmentation masks, in the same order
            as ``image_paths``.
        output_dir (str): Directory under which the two output folders are
            created (missing parent directories are created too).
        target_width (int): Width of the resized output, in pixels.
        target_height (int): Height of the resized output, in pixels.

    Returns:
        tuple: ``(resized_image_paths, resized_mask_paths)``, two lists of
        output paths in the same order as the inputs.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
            Pairing is positional, so a silent truncation would hide
            mismatched image/mask pairs.
        OSError: If OpenCV cannot read an input file or write an output file.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"image/mask count mismatch: {len(image_paths)} images "
            f"vs {len(mask_paths)} masks"
        )

    resized_image_dir = os.path.join(output_dir, 'resized_images')
    resized_mask_dir = os.path.join(output_dir, 'resized_masks')
    # makedirs creates output_dir itself as an intermediate directory.
    os.makedirs(resized_image_dir, exist_ok=True)
    os.makedirs(resized_mask_dir, exist_ok=True)

    target_size = (target_width, target_height)  # cv2.resize takes (width, height)
    resized_image_paths = []
    resized_mask_paths = []

    for img_path, mask_path in zip(image_paths, mask_paths):
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image file: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"could not read mask file: {mask_path}")

        resized_img = cv2.resize(img, target_size)
        resized_mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        output_img_path = os.path.join(resized_image_dir, os.path.basename(img_path))
        if not cv2.imwrite(output_img_path, resized_img):
            raise OSError(f"could not write resized image: {output_img_path}")
        resized_image_paths.append(output_img_path)

        output_mask_path = os.path.join(resized_mask_dir, os.path.basename(mask_path))
        if not cv2.imwrite(output_mask_path, resized_mask):
            raise OSError(f"could not write resized mask: {output_mask_path}")
        resized_mask_paths.append(output_mask_path)

    return resized_image_paths, resized_mask_paths
|
# Raw PNG images and masks; the two sorted lists are paired by position.
images = sorted(glob(os.path.join("./tamp500/imgs", "*.png")))
masks = sorted(glob(os.path.join("./tamp500/masks", "*.png")))

output_directory = 'resized_'
target_width = 448
target_height = 448

resized_image_paths, resized_mask_paths = resize_images_and_masks(
    images, masks, output_directory, target_width, target_height
)

# Train on exactly the files written by this run. Re-globbing the output
# folders would also pick up stale files left by earlier runs and would
# silently go out of sync if output_directory were changed.
images = list(resized_image_paths)
segs = list(resized_mask_paths)

from sklearn.model_selection import train_test_split

# Fixed random_state keeps the 80/20 train/test partition reproducible.
train_images, test_images, train_segs, test_segs = train_test_split(
    images, segs, test_size=0.2, random_state=42
)
|
|
|
|
|
def _loading_steps():
    """Return fresh load + intensity-scaling transforms shared by every pipeline."""
    return [LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()]


def _rotation_step():
    """Return a fresh random 90-degree rotation used for training augmentation."""
    return RandRotate90(prob=0.5, spatial_axes=(0, 1))


# Training pipelines: load, scale, then randomly rotate.
# NOTE(review): image and mask pipelines each own a separate RandRotate90;
# this presumably relies on ArrayDataset keeping their random state in sync —
# confirm against the MONAI documentation.
train_imtrans = Compose(_loading_steps() + [_rotation_step()])
train_segtrans = Compose(_loading_steps() + [_rotation_step()])

# Validation pipelines: load and scale only, no augmentation.
val_imtrans = Compose(_loading_steps())
val_segtrans = Compose(_loading_steps())

# Options common to both loaders.
_loader_options = dict(num_workers=2, pin_memory=torch.cuda.is_available())

train_ds = ArrayDataset(train_images, train_imtrans, train_segs, train_segtrans)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, **_loader_options)

val_ds = ArrayDataset(test_images, val_imtrans, test_segs, val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, **_loader_options)

# Validation metric and the post-processing applied to raw model outputs
# (sigmoid activation followed by thresholding at 0.5).
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# 2D UNETR: 3-channel input, single-channel output, fixed 448x448 input size
# (the same size the images were resized to earlier in this script).
_unetr_options = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    img_size=(448, 448),
)
model = monai.networks.nets.UNETR(**_unetr_options)
model = model.to(device)

# The loss applies the sigmoid itself (sigmoid=True), so it is fed raw model outputs.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
|
|
|
|
# ---- Training configuration and bookkeeping ----
max_epochs = 500
val_interval = 2  # run validation every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

# Number of batches per epoch. len(train_loader) also counts a final partial
# batch, so the progress denominator is right and the TensorBoard global step
# (epoch_len * epoch + step) never overlaps between consecutive epochs.
epoch_len = len(train_loader)

# Sliding-window settings for validation; the window matches the image size.
roi_size = (448, 448)
sw_batch_size = 4

writer = SummaryWriter()
try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        # ---- Training pass ----
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for accumulation and logging.
            loss_value = loss.item()
            epoch_loss += loss_value
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- Validation pass ----
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

                    # Accumulate the Dice score for this batch.
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # Aggregate over all validation batches, then reset the
                # metric so the next validation round starts clean.
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)

                # Keep the weights of the best-scoring epoch so far.
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)

                # Log the last validation sample (image, label, prediction) to TensorBoard.
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Close the event file even if training is interrupted or fails.
    writer.close()