# DuyTa's picture
# Upload folder using huggingface_hub
# 70e068a verified
import os
from glob import glob
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import monai
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
Activations,
AsDiscrete,
Compose,
LoadImage,
RandRotate90,
ScaleIntensity,
#AsChannelFirst
)
from monai.visualize import plot_2d_or_3d_image
from PIL import Image
import cv2
import tifffile
import os
# Preprocess
def convert_to_png(img_dir):
    """Write a ``.png`` copy of every JPEG/TIFF image found in ``img_dir``.

    Files ending in ``.jpg``, ``.jpeg``, ``.tif`` or ``.tiff`` (matched
    case-insensitively) are decoded and saved next to the original as
    ``<stem>.png``; the originals are not deleted. Files that are already
    ``.png`` are left untouched, and all other files are ignored. The
    directory is not searched recursively.

    TIFF files are read with ``tifffile`` and wrapped with
    ``PIL.Image.fromarray``; JPEG files are opened with ``PIL.Image.open``.

    Args:
        img_dir: Directory containing the images.
    """
    convertible_exts = (".jpg", ".jpeg", ".tif", ".tiff")
    tiff_exts = (".tif", ".tiff")
    for img_file in os.listdir(img_dir):
        stem, ext = os.path.splitext(img_file)
        ext = ext.lower()
        if ext not in convertible_exts:
            # Includes ".png": re-encoding a PNG onto itself is wasted I/O.
            continue
        img_path = os.path.join(img_dir, img_file)
        png_path = os.path.join(img_dir, stem + ".png")
        if ext in tiff_exts:
            # Read the TIFF pixel data and wrap it as a PIL image.
            with tifffile.TiffFile(img_path) as tif:
                img = Image.fromarray(tif.asarray())
            img.save(png_path)
        else:
            # The context manager closes the file handle Image.open keeps open.
            with Image.open(img_path) as img:
                img.save(png_path)
# Ensure every source image in ./tamp500/imgs also exists as a .png.
# The image/mask path lists are globbed later, right before resizing, so
# no file lists are collected here (they would be overwritten unused).
convert_to_png("./tamp500/imgs")
def resize_images_and_masks(image_paths, mask_paths, output_dir, target_width, target_height):
    """
    Resize images and corresponding segmentation masks to the specified dimensions.

    Images are read in colour and resized with OpenCV's default interpolation.
    Masks are read as single-channel grayscale and resized with
    nearest-neighbour interpolation, so no new label values are introduced.
    Results are written under their original basenames to
    ``<output_dir>/resized_images`` and ``<output_dir>/resized_masks``.
    Existing files with the same name are overwritten.

    Images and masks are paired by position, so both lists must be in the
    same order (e.g. both sorted) and have the same length.

    Args:
    - image_paths (list): List of paths to the input images.
    - mask_paths (list): List of paths to the segmentation masks, aligned with image_paths.
    - output_dir (str): Directory to save the resized images and masks (created if missing).
    - target_width (int): Target width for resizing.
    - target_height (int): Target height for resizing.
    Returns:
    - resized_image_paths (list): List of paths to the resized images, in input order.
    - resized_mask_paths (list): List of paths to the resized segmentation masks, in input order.
    Raises:
    - ValueError: if image_paths and mask_paths differ in length.
    - FileNotFoundError: if OpenCV cannot read an image or mask.
    - OSError: if OpenCV fails to write a resized file.
    """
    if len(image_paths) != len(mask_paths):
        # zip() would silently drop the unmatched tail and hide a broken pairing.
        raise ValueError(
            f"got {len(image_paths)} images but {len(mask_paths)} masks; "
            "they must pair one-to-one"
        )
    resized_image_dir = os.path.join(output_dir, 'resized_images')
    resized_mask_dir = os.path.join(output_dir, 'resized_masks')
    # makedirs also creates output_dir; exist_ok avoids a check-then-create race.
    os.makedirs(resized_image_dir, exist_ok=True)
    os.makedirs(resized_mask_dir, exist_ok=True)

    dsize = (target_width, target_height)  # OpenCV expects (width, height)
    resized_image_paths = []
    resized_mask_paths = []
    for img_path, mask_path in zip(image_paths, mask_paths):
        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"cannot read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"cannot read mask: {mask_path}")

        resized_img = cv2.resize(img, dsize)
        # Nearest-neighbour keeps mask values restricted to the original labels.
        resized_mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

        output_img_path = os.path.join(resized_image_dir, os.path.basename(img_path))
        if not cv2.imwrite(output_img_path, resized_img):
            raise OSError(f"failed to write resized image: {output_img_path}")
        resized_image_paths.append(output_img_path)

        output_mask_path = os.path.join(resized_mask_dir, os.path.basename(mask_path))
        if not cv2.imwrite(output_mask_path, resized_mask):
            raise OSError(f"failed to write resized mask: {output_mask_path}")
        resized_mask_paths.append(output_mask_path)
    return resized_image_paths, resized_mask_paths
# Collect the source PNGs. Images and masks are paired by sorted position,
# so both folders must hold matching files in the same sort order.
images = sorted(glob(os.path.join("./tamp500/imgs", "*.png")))
masks = sorted(glob(os.path.join("./tamp500/masks", "*.png")))
output_directory = 'resized_'
target_width = 448
target_height = 448
resized_image_paths, resized_mask_paths = resize_images_and_masks(images, masks, output_directory, target_width, target_height)
# Use exactly the files written by this run. Re-globbing output_directory
# would also pick up stale files left by earlier runs and break the pairing.
# The returned lists keep the (sorted) input order.
images = list(resized_image_paths)
segs = list(resized_mask_paths)
from sklearn.model_selection import train_test_split
# 80/20 train/validation split; the fixed random_state makes it reproducible.
train_images, test_images, train_segs, test_segs = train_test_split(images, segs, test_size=0.2, random_state=42)
# --- Data pipeline ---------------------------------------------------------
# Every pipeline loads a file channel-first and min-max rescales its
# intensities. The two training pipelines also apply a random 90-degree
# rotation.
# NOTE(review): the image and mask pipelines are separate objects. This
# assumes MONAI's ArrayDataset gives them the same random seed per item, so
# each pair rotates together — confirm against the MONAI version in use.
# NOTE(review): ScaleIntensity rescales every mask by its own min/max. Verify
# that masks are binary and that a constant (e.g. all-foreground) mask cannot
# occur.
train_imtrans, train_segtrans = (
    Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )
    for _ in range(2)
)
val_imtrans, val_segtrans = (
    Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
        ]
    )
    for _ in range(2)
)

# Shuffled batches of 4 for training, and single items in fixed order for
# validation. Memory is pinned only when a CUDA device is present.
# NOTE(review): num_workers=2 with no `if __name__ == "__main__":` guard may
# fail where workers are spawned (Windows/macOS) — verify on those platforms.
train_ds = ArrayDataset(train_images, train_imtrans, train_segs, train_segtrans)
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
val_ds = ArrayDataset(test_images, val_imtrans, test_segs, val_segtrans)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

# Mean Dice over all items (background channel included). Predictions are
# turned into binary masks by a sigmoid followed by a 0.5 threshold.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)
# --- Model, loss and optimiser ---------------------------------------------
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2-D UNETR taking 3-channel 448x448 inputs and producing one output channel.
# Alternative kept for reference: monai.networks.nets.UNet(spatial_dims=2,
# in_channels=3, out_channels=1, channels=(16, 32, 64, 128, 256),
# strides=(2, 2, 2, 2), num_res_units=2).
model = monai.networks.nets.UNETR(
    in_channels=3,
    out_channels=1,
    img_size=(448, 448),
    spatial_dims=2,
).to(device)

# DiceLoss applies a sigmoid to the raw model outputs before scoring.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# --- Training loop ---------------------------------------------------------
# Trains for max_epochs epochs. Every val_interval epochs it computes the mean
# Dice on the validation loader, saves the best-scoring weights, and logs
# scalars plus the last validation sample to TensorBoard.
max_epochs = 500
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
# Batches per epoch, counting a final partial batch. Using
# len(train_ds) // batch_size undercounts that batch, which makes TensorBoard
# global steps of consecutive epochs collide.
epoch_len = len(train_loader)
# Sliding-window settings for validation. The inputs are 448x448, so one
# window covers the whole image.
roi_size = (448, 448)
sw_batch_size = 4
writer = SummaryWriter()
try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()  # one device sync per step
            epoch_loss += loss_value
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
        # Avoid ZeroDivisionError if the training loader yields no batches.
        epoch_loss /= max(step, 1)
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # Split the batch and binarise each prediction (sigmoid + threshold).
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Flush and close TensorBoard event files even if training is interrupted.
    writer.close()