import os |
from glob import glob |
import torch |
from PIL import Image |
from torch.utils.tensorboard import SummaryWriter |
import monai |
from monai.data import ArrayDataset, decollate_batch, DataLoader |
from monai.inferers import sliding_window_inference |
from monai.metrics import DiceMetric |
from monai.transforms import ( |
Activations, |
AsDiscrete, |
Compose, |
LoadImage, |
RandRotate90, |
ScaleIntensity, |
) |
from monai.visualize import plot_2d_or_3d_image |
from PIL import Image |
import cv2 |
import tifffile |
import os |
def convert_to_png(img_dir):
    """Write a PNG copy of every JPEG/TIFF image found directly in ``img_dir``.

    Each ``name.<ext>`` is saved next to the original as ``name.png`` so that a
    later ``*.png`` glob sees every image. Extension matching is
    case-insensitive and covers ``.jpg``, ``.jpeg``, ``.tif`` and ``.tiff``.
    Files that are already ``.png`` are left untouched: re-saving a PNG onto
    itself is pure overhead.

    Args:
        img_dir: Directory to scan (non-recursive).

    Returns:
        List of paths of the PNG files that were written, in sorted filename
        order.
    """
    tiff_exts = (".tif", ".tiff")
    convertible_exts = (".jpg", ".jpeg") + tiff_exts
    converted = []
    # Sorted so that name collisions (e.g. a.jpg and a.tif) resolve
    # deterministically: the later file in sorted order wins.
    for img_file in sorted(os.listdir(img_dir)):
        stem, ext = os.path.splitext(img_file)
        ext = ext.lower()
        if ext not in convertible_exts:
            continue
        img_path = os.path.join(img_dir, img_file)
        png_path = os.path.join(img_dir, stem + ".png")
        if ext in tiff_exts:
            # NOTE(review): Image.fromarray may reject some TIFF dtypes/channel
            # layouts (e.g. float or >4 channels) -- verify against the dataset.
            with tifffile.TiffFile(img_path) as tif:
                Image.fromarray(tif.asarray()).save(png_path)
        else:
            # Context manager closes the underlying file handle.
            with Image.open(img_path) as img:
                # PNG cannot store CMYK (used by some JPEGs); convert first so
                # save() does not fail.
                out = img.convert("RGB") if img.mode == "CMYK" else img
                out.save(png_path)
        converted.append(png_path)
    return converted
# Materialize PNG copies of any JPEG/TIFF inputs so that the "*.png" globs
# used later in the script see every image. The image/mask path lists are
# built after this point, so no globbing is needed here.
convert_to_png("./tamp500/imgs")
def resize_images_and_masks(image_paths, mask_paths, output_dir, target_width, target_height):
    """
    Resize images and corresponding segmentation masks to the specified dimensions.

    Images use OpenCV's default (bilinear) interpolation. Masks use
    nearest-neighbour interpolation so that no new label values appear at
    region boundaries. Outputs keep their original basenames and are written
    to ``<output_dir>/resized_images`` and ``<output_dir>/resized_masks``.

    Args:
    - image_paths (list): List of paths to the input images.
    - mask_paths (list): List of paths to the segmentation masks, paired
      positionally with ``image_paths``.
    - output_dir (str): Directory to save the resized images and masks.
    - target_width (int): Target width for resizing.
    - target_height (int): Target height for resizing.
    Returns:
    - resized_image_paths (list): List of paths to the resized images.
    - resized_mask_paths (list): List of paths to the resized segmentation masks.
    Raises:
    - ValueError: if the image and mask lists have different lengths.
    - OSError: if OpenCV cannot read an input file or write an output file.
    """
    # zip() would silently drop the surplus entries and leave the pairs
    # misaligned, so reject mismatched inputs up front.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"got {len(image_paths)} images but {len(mask_paths)} masks; "
            "they must pair up one-to-one"
        )
    resized_image_dir = os.path.join(output_dir, "resized_images")
    resized_mask_dir = os.path.join(output_dir, "resized_masks")
    # makedirs also creates output_dir itself as a parent directory.
    os.makedirs(resized_image_dir, exist_ok=True)
    os.makedirs(resized_mask_dir, exist_ok=True)
    dsize = (target_width, target_height)  # OpenCV expects (width, height)
    resized_image_paths = []
    resized_mask_paths = []
    for img_path, mask_path in zip(image_paths, mask_paths):
        # cv2.imread returns None instead of raising on unreadable files.
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"could not read mask: {mask_path}")

        resized_img = cv2.resize(img, dsize)
        resized_mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

        output_img_path = os.path.join(resized_image_dir, os.path.basename(img_path))
        if not cv2.imwrite(output_img_path, resized_img):
            raise OSError(f"could not write image: {output_img_path}")
        resized_image_paths.append(output_img_path)

        output_mask_path = os.path.join(resized_mask_dir, os.path.basename(mask_path))
        if not cv2.imwrite(output_mask_path, resized_mask):
            raise OSError(f"could not write mask: {output_mask_path}")
        resized_mask_paths.append(output_mask_path)
    return resized_image_paths, resized_mask_paths
images = sorted(glob(os.path.join("./tamp500/imgs", "*.png")))
masks = sorted(glob(os.path.join("./tamp500/masks", "*.png")))
output_directory = 'resized_'
# The model's img_size and the validation roi_size elsewhere in the script
# are also hard-coded to 448x448; keep them in sync with these values.
target_width = 448
target_height = 448
resized_image_paths, resized_mask_paths = resize_images_and_masks(
    images, masks, output_directory, target_width, target_height
)
# Use the paths actually written by this run rather than re-globbing the
# output directories. A glob would also pick up stale files left by earlier
# runs and silently shift the image/mask pairing. The returned lists preserve
# the sorted input order, which matches the order a sorted glob would give.
images = resized_image_paths
segs = resized_mask_paths
from sklearn.model_selection import train_test_split
# Fixed random_state keeps the 80/20 train/test split reproducible.
train_images, test_images, train_segs, test_segs = train_test_split(
    images, segs, test_size=0.2, random_state=42
)
# Image pipelines: load channel-first, min-max scale each image to [0, 1],
# and (training only) randomly rotate by a multiple of 90 degrees.
train_imtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        ScaleIntensity(),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
# Mask pipelines: binarize with a fixed threshold instead of ScaleIntensity.
# ScaleIntensity min-max rescales each mask individually. That maps a constant
# mask (e.g. an all-255, fully tampered image) to all zeros and leaves any
# intermediate grey values fractional. Thresholding at 0.5 maps every non-zero
# pixel to 1, so masks stored as 0/1 or 0/255 both become {0, 1}.
# NOTE(review): assumes background pixels are exactly 0 in the mask PNGs.
# RandRotate90 stays the only randomizable transform, in the same position as
# in train_imtrans. ArrayDataset seeds the image and mask pipelines together,
# so this keeps their rotations in sync.
train_segtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        AsDiscrete(threshold=0.5),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
val_segtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), AsDiscrete(threshold=0.5)])
train_ds = ArrayDataset(train_images, train_imtrans, train_segs, train_segtrans)
train_loader = DataLoader(
    train_ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()
)
val_ds = ArrayDataset(test_images, val_imtrans, test_segs, val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation helpers. DiceMetric averages Dice over the validation set.
# post_trans turns raw logits into a {0, 1} prediction: sigmoid, then
# threshold at 0.5. include_background=True because the model emits a single
# channel, so there is no separate background channel to drop.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# 2-D UNETR: 3 input channels, one output logit channel, fixed 448x448 input.
model = monai.networks.nets.UNETR(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    img_size=(448, 448),
).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)  # applies sigmoid to logits itself
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training-loop bookkeeping.
val_interval = 2  # run validation every `val_interval` epochs
best_metric, best_metric_epoch = -1, -1
epoch_loss_values, metric_values = [], []
writer = SummaryWriter()
max_epochs = 500
# Optimizer steps per epoch, including a final partial batch.
# len(train_ds) // batch_size undercounts when the dataset size is not a
# multiple of the batch size. The TensorBoard global step of an epoch's last
# batch then collides with the first batch of the next epoch.
epoch_len = len(train_loader)
# The model takes 448x448 inputs, so sliding-window inference with this ROI
# amounts to a single window per (448x448) validation image.
roi_size = (448, 448)
sw_batch_size = 4
try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0.0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() synchronizes with the device; call it once per step.
            loss_value = loss.item()
            epoch_loss += loss_value
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
        if step == 0:
            raise RuntimeError("train_loader yielded no batches; check the training image/mask lists")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Keep the last validation batch around for TensorBoard plots.
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # Per-sample sigmoid + threshold to binary predictions.
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Close the TensorBoard writer even if training is interrupted (e.g.
    # Ctrl-C or an out-of-memory error), so pending events are flushed.
    writer.close()