"""Predict the contrast phase of a scan from its segmentation label map.

The script loads a segmentation NIfTI and the matching scan NIfTI, extracts
intensity statistics for several labelled regions, feeds them to a pickled
model and writes the predicted phase plus two sample slices.
"""

import argparse
import os
import pickle
import sys
from typing import Optional

import nibabel as nib
import numpy as np
import scipy
import scipy.spatial
import SimpleITK as sitk
from scipy import ndimage as ndi

# Label ids of the regions read from the segmentation array (names follow the
# mask variables used in getFeatures).
_LABEL_KIDNEY_R = 2
_LABEL_KIDNEY_L = 3
_LABEL_AORTA = 7
_LABEL_IVC = 8
_LABEL_PORTAL = 9
_LABEL_ADRENAL_R = 11
_LABEL_ATRIUM = 45

# Below this voxel count the eroded portal mask is replaced by the original.
_MIN_PORTAL_VOXELS = 500

# Model output class -> phase name written to disk / printed.
_PHASE_NAMES = {
    0: "non-contrast",
    1: "arterial",
    2: "venous",
    3: "delayed",
}


def loadNiiToArray(path: str) -> np.ndarray:
    """Load a NIfTI file with nibabel and return its data as a numpy array."""
    NiImg = nib.load(path)
    return np.array(NiImg.dataobj)


def _read_nifti_with_sitk(path: str) -> sitk.Image:
    """Read a NIfTI file into a SimpleITK image using the NIfTI image IO."""
    reader = sitk.ImageFileReader()
    reader.SetImageIO("NiftiImageIO")
    reader.SetFileName(path)
    return reader.Execute()


def loadNiiWithSitk(path: str) -> np.ndarray:
    """Load a NIfTI file with SimpleITK and return it as a numpy array."""
    return sitk.GetArrayFromImage(_read_nifti_with_sitk(path))


def loadNiiImageWithSitk(path: str) -> sitk.Image:
    """Load a NIfTI file with SimpleITK and return the image flipped on axis 1.

    The flip makes the image compatible with the arrays loaded through nibabel.
    """
    image = _read_nifti_with_sitk(path)
    # invert the image to be compatible with Nibabel
    return sitk.Flip(image, [False, True, False])


def keep_masked_values(arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return the elements of ``arr`` located where ``mask`` is non-zero.

    The result is a new 1-D array; ``arr`` and ``mask`` are not modified.
    """
    return arr[np.nonzero(mask)]


def get_stats(arr):
    """Return ``(max, min, mean, median, std, variance)`` of ``arr``.

    Raises:
        ValueError: if ``arr`` is empty (max/min are undefined in that case).
    """
    values = np.asarray(arr)
    if values.size == 0:
        # np.max would raise ValueError as well, only with a cryptic message.
        raise ValueError("cannot compute statistics of an empty array")
    return (
        np.max(values),
        np.min(values),
        np.mean(values),
        np.median(values),
        np.std(values),
        np.var(values),
    )


def getMaskAnteriorAtrium(mask: np.ndarray) -> np.ndarray:
    """Return a copy of ``mask`` with the rows before the atrium filled in.

    For every slice along the last axis that contains at least one voxel equal
    to 1, all rows (axis 0) before the first such voxel's row are set to 1.
    Slices without any voxel equal to 1 are left untouched.
    """
    erasePreAtriumMask = mask.copy()
    for sliceNum in range(mask.shape[-1]):
        rows = np.where(mask[:, :, sliceNum] == 1)[0]
        if rows.size > 0:
            # np.where yields indices in row-major order, so rows[0] is the
            # smallest row index holding a 1 in this slice.
            erasePreAtriumMask[: rows[0], :, sliceNum] = 1
    return erasePreAtriumMask


def fill_hull(image: np.ndarray) -> np.ndarray:
    """Return a float array that is 1 inside the convex hull of ``image``.

    The hull is built from the coordinates of the non-zero elements of
    ``image``; every grid point lying inside it is set to 1, all others to 0.

    Function from
    https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays/46314485#46314485
    """
    points = np.transpose(np.where(image))
    hull = scipy.spatial.ConvexHull(points)
    deln = scipy.spatial.Delaunay(points[hull.vertices])
    idx = np.stack(np.indices(image.shape), axis=-1)
    # find_simplex returns -1 for points outside the triangulation, so adding
    # one makes exactly the outside points zero.
    out_idx = np.nonzero(deln.find_simplex(idx) + 1)
    out_img = np.zeros(image.shape)
    out_img[out_idx] = 1
    return out_img


def getClassBinaryMask(TSOutArray: np.ndarray, classNum: int) -> np.ndarray:
    """Return a float mask that is 1 where ``TSOutArray == classNum``, else 0."""
    binaryMask = np.zeros(TSOutArray.shape)
    binaryMask[TSOutArray == classNum] = 1
    return binaryMask


def loadNiftis(TSNiftiPath: str, imageNiftiPath: str):
    """Load the segmentation and scan NIfTI files as ``(TSArray, scanArray)``."""
    TSArray = loadNiiToArray(TSNiftiPath)
    scanArray = loadNiiToArray(imageNiftiPath)
    return TSArray, scanArray


def selectSlice(scanImage: sitk.Image, zslice: int) -> sitk.Image:
    """Extract the slice at index ``zslice`` of the third axis of a 3D image."""
    size = list(scanImage.GetSize())
    # A size of 0 along the third axis asks the filter to collapse that axis.
    size[2] = 0
    index = [0, 0, zslice]
    Extractor = sitk.ExtractImageFilter()
    Extractor.SetSize(size)
    Extractor.SetIndex(index)
    return Extractor.Execute(scanImage)


def windowing(sliceImage: sitk.Image, center=400, width=400) -> sitk.Image:
    """Apply intensity windowing and return the result as an 8-bit image.

    Intensities are mapped to the range 0-255 and cast to ``sitkUInt8``.

    NOTE(review): the lower bound handed to SimpleITK is the *negated*
    ``center - width / 2``, so the defaults give a window of [-200, 600]
    rather than [200, 600]. Kept as is because changing it would alter the
    generated images - confirm whether this is intended.
    """
    windowMinimum = center - (width / 2)
    windowMaximum = center + (width / 2)
    img_255 = sitk.Cast(
        sitk.IntensityWindowing(
            sliceImage,
            windowMinimum=-windowMinimum,
            windowMaximum=windowMaximum,
            outputMinimum=0.0,
            outputMaximum=255.0,
        ),
        sitk.sitkUInt8,
    )
    return img_255


def _middle_slice_index(mask: np.ndarray) -> int:
    """Return the index midway between the first and last occupied slice.

    A slice (position along the last axis) is occupied when the mask summed
    over the first two axes is positive there. Raises ``IndexError`` when no
    slice is occupied.
    """
    occupied = np.where(mask.sum(axis=(0, 1)) > 0)[0]
    first, last = occupied[0], occupied[-1]
    return int(first + int((last - first) / 2))


def selectSampleSlice(kidneyLMask, adRMask, scanImage):
    """Return windowed sample slices through the kidney and adrenal masks.

    Args:
        kidneyLMask: 3D left kidney mask.
        adRMask: 3D right adrenal mask.
        scanImage: SimpleITK image the slices are extracted from.

    Returns:
        ``(sliceImageK, sliceImageA)``: the slices at the middle of the
        occupied extent of the kidney and adrenal mask, after ``windowing``.

    Raises:
        IndexError: if either mask has no occupied slice.
    """
    sliceImageK = selectSlice(scanImage, _middle_slice_index(kidneyLMask))
    sliceImageA = selectSlice(scanImage, _middle_slice_index(adRMask))
    return windowing(sliceImageK), windowing(sliceImageA)


def _renal_pelvis_values(scanArray: np.ndarray, kidneyMask: np.ndarray) -> np.ndarray:
    """Return scan values inside the kidney's convex hull but outside the kidney."""
    # create the convex hull of the kidney and exclude the kidney itself
    kidneyHull = fill_hull(kidneyMask) * (kidneyMask == 0)
    # erode the remaining region to remove its edges
    kidneyHull = ndi.binary_erosion(kidneyHull, structure=np.ones((3, 3, 3))).astype(
        kidneyHull.dtype
    )
    return keep_masked_values(scanArray, kidneyHull)


def getFeatures(TSArray, scanArray):
    """Compute the feature vector used by the phase model.

    Args:
        TSArray: segmentation label array.
        scanArray: scan intensities, indexed like ``TSArray``.

    Returns:
        ``(stats, kidneyLMask, adRMask)``. ``stats`` is a list of 48 values:
        ``(max, min, mean, median, std, variance)`` for aorta, IVC, portal
        vein, left kidney, right kidney, left pelvis and right pelvis (in that
        order), followed by the aorta-minus-portal and aorta-minus-IVC
        differences of max, min and mean.

    Raises:
        ValueError: if one of the regions contains no voxels.
    """
    aortaMask = getClassBinaryMask(TSArray, _LABEL_AORTA)
    IVCMask = getClassBinaryMask(TSArray, _LABEL_IVC)
    portalMask = getClassBinaryMask(TSArray, _LABEL_PORTAL)
    atriumMask = getClassBinaryMask(TSArray, _LABEL_ATRIUM)
    kidneyLMask = getClassBinaryMask(TSArray, _LABEL_KIDNEY_L)
    kidneyRMask = getClassBinaryMask(TSArray, _LABEL_KIDNEY_R)
    adRMask = getClassBinaryMask(TSArray, _LABEL_ADRENAL_R)

    # Remove thoracic aorta and IVC from the aorta and IVC masks
    outsideAtriumRegion = getMaskAnteriorAtrium(atriumMask) == 0
    aortaMask = aortaMask * outsideAtriumRegion
    IVCMask = IVCMask * outsideAtriumRegion

    # Erode vessels to get only the center of the vessels
    vesselStruct = np.ones((3, 3, 3))
    aortaMaskEroded = ndi.binary_erosion(aortaMask, structure=vesselStruct).astype(
        aortaMask.dtype
    )
    IVCMaskEroded = ndi.binary_erosion(IVCMask, structure=vesselStruct).astype(
        IVCMask.dtype
    )
    # NOTE(review): a 1x1x1 structuring element should leave the mask
    # unchanged, so this erosion looks like a no-op - confirm the intent.
    portalStruct = np.ones((1, 1, 1))
    portalMaskEroded = ndi.binary_erosion(portalMask, structure=portalStruct).astype(
        portalMask.dtype
    )
    # If the eroded portal mask is too small, use the original portal mask
    if np.count_nonzero(portalMaskEroded) < _MIN_PORTAL_VOXELS:
        portalMaskEroded = portalMask

    # Scan values per region, in the order they appear in the feature vector.
    regions = (
        ("aorta", keep_masked_values(scanArray, aortaMaskEroded)),
        ("IVC", keep_masked_values(scanArray, IVCMaskEroded)),
        ("portal", keep_masked_values(scanArray, portalMaskEroded)),
        ("kidneyL", keep_masked_values(scanArray, kidneyLMask)),
        ("kidneyR", keep_masked_values(scanArray, kidneyRMask)),
        ("pelvisL", _renal_pelvis_values(scanArray, kidneyLMask)),
        ("pelvisR", _renal_pelvis_values(scanArray, kidneyRMask)),
    )

    regionStats = []
    for regionName, values in regions:
        if values.size == 0:
            raise ValueError(
                "no voxels found for region '{}'; cannot compute its "
                "statistics".format(regionName)
            )
        regionStats.append(get_stats(values))

    stats = []
    for regionStat in regionStats:
        stats.extend(regionStat)

    # Extra columns for the decision tree: differences of max, min and mean
    # (the first three entries of each stats tuple).
    aortaStats, IVCStats, portalStats = regionStats[0], regionStats[1], regionStats[2]
    stats.extend(aortaStats[i] - portalStats[i] for i in range(3))
    stats.extend(aortaStats[i] - IVCStats[i] for i in range(3))

    return stats, kidneyLMask, adRMask


def loadModel():
    """Load the pickled model from ``comp2comp/contrast_phase/xgboost.pkl``.

    The path is resolved relative to the parent directory of ``sys.path[0]``.
    """
    c2cPath = os.path.dirname(sys.path[0])
    filename = os.path.join(c2cPath, "comp2comp", "contrast_phase", "xgboost.pkl")
    # SECURITY: unpickling executes arbitrary code; this file must be trusted.
    with open(filename, "rb") as model_file:
        return pickle.load(model_file)


def predict_phase(
    TS_path: str,
    scan_path: str,
    outputPath: Optional[str] = None,
    save_sample: bool = False,
) -> str:
    """Predict the contrast phase of a scan and write the results.

    Args:
        TS_path: path to the segmentation NIfTI file.
        scan_path: path to the scan NIfTI file.
        outputPath: directory receiving ``metrics/phase_prediction.txt`` and
            the sample slices under ``images/``. When ``None`` nothing is
            written and the prediction is only printed.
        save_sample: accepted for compatibility but currently ignored.
            NOTE(review): the sample images are always written when
            ``outputPath`` is given - confirm whether this flag should gate
            them.

    Returns:
        The predicted phase name.

    Raises:
        ValueError: if the model predicts a class without a known phase name.
    """
    TS_array, image_array = loadNiftis(TS_path, scan_path)
    model = loadModel()
    featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array)

    y_pred = model.predict([featureArray])
    predicted_class = int(np.asarray(y_pred).ravel()[0])
    try:
        pred_phase = _PHASE_NAMES[predicted_class]
    except KeyError:
        raise ValueError(
            "model predicted unknown phase class {!r}".format(predicted_class)
        ) from None

    if outputPath is None:
        print(pred_phase)
        return pred_phase

    output_path_metrics = os.path.join(outputPath, "metrics")
    os.makedirs(output_path_metrics, exist_ok=True)
    outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt")
    with open(outputTxt, "w") as text_file:
        text_file.write(pred_phase)
    print(pred_phase)

    output_path_images = os.path.join(outputPath, "images")
    os.makedirs(output_path_images, exist_ok=True)
    scanImage = loadNiiImageWithSitk(scan_path)
    sliceImageK, sliceImageA = selectSampleSlice(kidneyLMask, adRMask, scanImage)
    outJpgK = os.path.join(output_path_images, "sampleSliceKidney.png")
    sitk.WriteImage(sliceImageK, outJpgK)
    outJpgA = os.path.join(output_path_images, "sampleSliceAdrenal.png")
    sitk.WriteImage(sliceImageA, outJpgA)
    return pred_phase


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` would treat every non-empty string, including ``"False"``,
    as ``True``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


def main():
    """Parse the command-line arguments and run the phase prediction."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--TS_path", type=str, required=True, help="Input image")
    parser.add_argument("--scan_path", type=str, required=True, help="Input image")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=False,
        help="Output .txt prediction",
        default=None,
    )
    parser.add_argument(
        "--save_sample",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare ``--save_sample`` means True
        required=False,
        help="Save jpeg sample ",
        default=False,
    )
    args = parser.parse_args()
    predict_phase(args.TS_path, args.scan_path, args.output_dir, args.save_sample)


if __name__ == "__main__":
    main()