# Spaces: Sleeping (page-header residue from the web file viewer; not part of the module)
import argparse
import os
import pickle
import sys

import nibabel as nib
import numpy as np
import scipy
import scipy.spatial
import SimpleITK as sitk
from scipy import ndimage as ndi
def loadNiiToArray(path):
    """Load a NIfTI file with nibabel and return its voxel data as a NumPy array.

    Args:
        path: Path of the NIfTI file to read.

    Returns:
        ``np.ndarray`` built from the image's ``dataobj``.
    """
    return np.array(nib.load(path).dataobj)
def loadNiiWithSitk(path):
    """Read a NIfTI file with SimpleITK and return its contents as an array.

    Args:
        path: Path of the NIfTI file to read.

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` for the loaded image.
    """
    niftiReader = sitk.ImageFileReader()
    niftiReader.SetFileName(path)
    # Force the NIfTI reader instead of letting SimpleITK pick the ImageIO.
    niftiReader.SetImageIO("NiftiImageIO")
    return sitk.GetArrayFromImage(niftiReader.Execute())
def loadNiiImageWithSitk(path):
    """Read a NIfTI file with SimpleITK and return it as a flipped image object.

    Args:
        path: Path of the NIfTI file to read.

    Returns:
        The loaded ``sitk.Image`` with its second axis flipped.
    """
    niftiReader = sitk.ImageFileReader()
    niftiReader.SetFileName(path)
    niftiReader.SetImageIO("NiftiImageIO")
    loaded = niftiReader.Execute()
    # Flip the second axis so the image is compatible with the arrays read
    # through nibabel.
    flipAxes = [False, True, False]
    return sitk.Flip(loaded, flipAxes)
def keep_masked_values(arr, mask):
    """Return the elements of ``arr`` located where ``mask`` is non-zero.

    Args:
        arr: Array to pick values from.
        mask: Array whose non-zero positions select the elements of ``arr``.

    Returns:
        1-D array with the selected values (a new array, not a view).
    """
    # Indexing with the tuple of non-zero coordinates gathers the selected
    # elements into a fresh flat array.
    return arr[np.nonzero(mask)]
def get_stats(arr):
    """Return summary statistics computed over every value in ``arr``.

    All elements are used as given; zeros are NOT filtered out.

    Args:
        arr: Array-like of numeric values.

    Returns:
        Tuple ``(max, min, mean, median, std, variance)``.

    Raises:
        ValueError: If ``arr`` contains no elements (for example because the
            mask used to select the values was empty).
    """
    values = np.asarray(arr)
    if values.size == 0:
        # np.max/np.min would raise an opaque "zero-size array" ValueError;
        # fail with the same exception type but an actionable message.
        raise ValueError(
            "get_stats() received an empty array; the region it was "
            "extracted from contains no voxels"
        )
    return (
        np.max(values),
        np.min(values),
        np.mean(values),
        np.median(values),
        np.std(values),
        np.var(values),
    )
def getMaskAnteriorAtrium(mask):
    """Extend a binary mask towards row 0 in every slice of the last axis.

    For each index along the last axis, the first row (axis 0) that contains
    a value equal to 1 is located and every row before it is set to 1 in the
    returned copy. Slices without any 1 are left untouched.

    Args:
        mask: 3-D binary array; not modified.

    Returns:
        Copy of ``mask`` with the rows preceding the first foreground row
        filled with 1, slice by slice.
    """
    filled = mask.copy()
    for sliceIdx in range(mask.shape[-1]):
        # np.where lists coordinates in row-major order, so the first entry
        # of the row coordinates is the smallest row holding a 1.
        foregroundRows = np.where(mask[:, :, sliceIdx] == 1)[0]
        if foregroundRows.size == 0:
            continue
        filled[: foregroundRows[0], :, sliceIdx] = 1
    return filled


# Attribution for fill_hull (defined below); function from
# https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays/46314485#46314485
def fill_hull(image):
    """Return a filled convex-hull mask of the non-zero voxels of ``image``.

    Adapted from the Stack Overflow answer
    https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays/46314485#46314485

    Args:
        image: N-D array; its non-zero elements define the point set.

    Returns:
        Float array with the shape of ``image`` holding 1 at every grid
        position that lies inside the convex hull of the points and 0
        elsewhere.

    Raises:
        Whatever ``scipy.spatial.ConvexHull`` raises when the points cannot
        form a hull (for example an empty or degenerate point set).
    """
    points = np.transpose(np.nonzero(image))
    hull = scipy.spatial.ConvexHull(points)
    # Triangulating only the hull vertices is enough to answer "inside?".
    deln = scipy.spatial.Delaunay(points[hull.vertices])

    # The hull cannot extend past the bounding box of its points, so only the
    # grid positions inside that box are tested instead of the whole volume.
    lower = points.min(axis=0)
    upper = points.max(axis=0) + 1
    box = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    idx = np.stack(np.indices(tuple(upper - lower)), axis=-1) + lower
    # find_simplex returns -1 for positions outside the triangulation.
    inside = deln.find_simplex(idx) >= 0

    out_img = np.zeros(image.shape)
    out_img[box][inside] = 1  # out_img[box] is a view, so this writes through
    return out_img
def getClassBinaryMask(TSOutArray, classNum):
    """Build a 0/1 mask marking where a label array equals one class id.

    Args:
        TSOutArray: Label array (one class id per element).
        classNum: Class id to extract.

    Returns:
        Float array with the shape of ``TSOutArray``: 1 where the label
        equals ``classNum``, 0 elsewhere.
    """
    binaryMask = np.zeros(TSOutArray.shape)
    # Adding the boolean comparison in place turns matching positions into 1.
    binaryMask += TSOutArray == classNum
    return binaryMask
def loadNiftis(TSNiftiPath, imageNiftiPath):
    """Load the segmentation and the scan NIfTI files as NumPy arrays.

    Args:
        TSNiftiPath: Path of the segmentation NIfTI file.
        imageNiftiPath: Path of the scan NIfTI file.

    Returns:
        Tuple ``(TSArray, scanArray)`` in that order.
    """
    return loadNiiToArray(TSNiftiPath), loadNiiToArray(imageNiftiPath)
# Select one slice from a 3D SimpleITK image volume.
def selectSlice(scanImage, zslice):
    """Extract the slice at index ``zslice`` of the third axis of ``scanImage``.

    Args:
        scanImage: SimpleITK image volume.
        zslice: Index along the third image axis of the slice to extract.

    Returns:
        The image produced by ``sitk.ExtractImageFilter`` for that slice.
    """
    extent = list(scanImage.GetSize())
    # A size of 0 along an axis makes the extract filter collapse that axis.
    extent[2] = 0
    origin = [0, 0, zslice]

    sliceExtractor = sitk.ExtractImageFilter()
    sliceExtractor.SetIndex(origin)
    sliceExtractor.SetSize(extent)
    return sliceExtractor.Execute(scanImage)
# Apply intensity windowing and convert the result to 8 bit.
def windowing(sliceImage, center=400, width=400):
    """Rescale ``sliceImage`` intensities to 0-255 and cast to unsigned 8 bit.

    Args:
        sliceImage: SimpleITK image to rescale.
        center: Window centre used to derive the window bounds.
        width: Window width used to derive the window bounds.

    Returns:
        ``sitk.Image`` of pixel type ``sitkUInt8``.
    """
    halfWidth = width / 2
    # NOTE(review): the lower bound is negated before it is handed to
    # IntensityWindowing, so the window actually applied is
    # [-(center - width / 2), center + width / 2] -- [-200, 600] with the
    # defaults -- rather than the conventional [center - width / 2,
    # center + width / 2]. Kept as-is because the saved sample images depend
    # on it; confirm whether this is intentional.
    lowerBound = -(center - halfWidth)
    upperBound = center + halfWidth
    rescaled = sitk.IntensityWindowing(
        sliceImage,
        windowMinimum=lowerBound,
        windowMaximum=upperBound,
        outputMinimum=0.0,
        outputMaximum=255.0,
    )
    return sitk.Cast(rescaled, sitk.sitkUInt8)
def _middleForegroundSlice(mask):
    """Return the index midway between the first and last non-empty slices of ``mask``."""
    # Summing over the first two axes leaves one total per slice of the
    # third axis; np.where yields the occupied slice indices in ascending order.
    occupied = np.where(mask.sum(axis=(0, 1)) > 0)[0]
    if occupied.size == 0:
        # Same exception type the bare indexing would raise, with a usable message.
        raise IndexError("mask has no non-zero voxels; cannot pick a middle slice")
    first, last = occupied[0], occupied[-1]
    return int(first + int((last - first) / 2))


def selectSampleSlice(kidneyLMask, adRMask, scanImage):
    """Pick one windowed sample slice through each of two masked structures.

    For each mask the slice halfway between its first and last non-empty
    slices (along the third axis) is extracted from ``scanImage`` and passed
    through ``windowing`` with its default parameters.

    Args:
        kidneyLMask: 3-D mask array used to pick the first sample slice.
        adRMask: 3-D mask array used to pick the second sample slice.
        scanImage: SimpleITK image volume the slices are extracted from.

    Returns:
        Tuple ``(sliceImageK, sliceImageA)``: the windowed slices chosen from
        ``kidneyLMask`` and ``adRMask`` respectively.

    Raises:
        IndexError: If either mask contains no non-zero voxels.
    """
    # Both slice indices are resolved before any extraction, matching the
    # order in which an empty mask would be detected.
    sliceImageK = selectSlice(scanImage, _middleForegroundSlice(kidneyLMask))
    sliceImageA = selectSlice(scanImage, _middleForegroundSlice(adRMask))
    return windowing(sliceImageK), windowing(sliceImageA)
def _erodeMask(mask, structure):
    """Binary-erode ``mask`` with ``structure`` and cast back to the mask's dtype."""
    return ndi.binary_erosion(mask, structure=structure).astype(mask.dtype)


def _renalPelvisValues(scanArray, kidneyMask):
    """Return scan values inside a kidney's convex hull but outside the kidney itself."""
    hull = fill_hull(kidneyMask)
    # Keep only the part of the hull not occupied by the kidney mask.
    hull = hull * (kidneyMask == 0)
    # Erode to drop the thin rim left along the hull/kidney edges.
    hull = _erodeMask(hull, np.ones((3, 3, 3)))
    return keep_masked_values(scanArray, hull)


def getFeatures(TSArray, scanArray):
    """Compute the intensity feature vector used for contrast-phase prediction.

    Seven regions are sampled from ``scanArray`` using labels in ``TSArray``:
    aorta (7), IVC (8), portal vein (9), left kidney (3), right kidney (2)
    and the left/right renal-pelvis regions derived from the kidney convex
    hulls. The atrium label (45) is used to cut the aorta and IVC masks, and
    the adrenal label (11) is only returned for slice selection.
    NOTE(review): the label ids are taken as written here; presumably they
    follow the segmentation model's class map -- verify against it.

    Args:
        TSArray: Label array of the segmentation.
        scanArray: Scan intensities, same shape as ``TSArray``.

    Returns:
        Tuple ``(stats, kidneyLMask, adRMask)`` where ``stats`` is a list of
        48 values: ``(max, min, mean, median, std, variance)`` for aorta, IVC,
        portal vein, left kidney, right kidney, left pelvis and right pelvis
        (in that order), followed by aorta-minus-portal ``max, min, mean``
        and aorta-minus-IVC ``max, min, mean``.

    Raises:
        ValueError: Propagated from ``get_stats`` when a region is empty.
    """
    aortaMask = getClassBinaryMask(TSArray, 7)
    IVCMask = getClassBinaryMask(TSArray, 8)
    portalMask = getClassBinaryMask(TSArray, 9)
    atriumMask = getClassBinaryMask(TSArray, 45)
    kidneyLMask = getClassBinaryMask(TSArray, 3)
    kidneyRMask = getClassBinaryMask(TSArray, 2)
    adRMask = getClassBinaryMask(TSArray, 11)

    # Remove the thoracic aorta and IVC: drop every voxel covered by the
    # extended atrium mask.
    anteriorAtriumMask = getMaskAnteriorAtrium(atriumMask)
    outsideAtriumRegion = anteriorAtriumMask == 0
    aortaMask = aortaMask * outsideAtriumRegion
    IVCMask = IVCMask * outsideAtriumRegion

    # Erode the vessels so only their centre is sampled.
    cube = np.ones((3, 3, 3))
    aortaMaskEroded = _erodeMask(aortaMask, cube)
    IVCMaskEroded = _erodeMask(IVCMask, cube)
    # NOTE(review): a 1x1x1 structuring element makes this erosion a no-op,
    # so the fallback below never changes the sampled voxels. Kept to
    # preserve the existing pipeline; confirm the intended structure size.
    portalMaskEroded = _erodeMask(portalMask, np.ones((1, 1, 1)))
    # If the eroded portal mask has fewer than 500 voxels, use the original.
    if np.count_nonzero(portalMaskEroded) < 500:
        portalMaskEroded = portalMask

    # Region order here defines the order of the feature vector.
    regionValues = [
        keep_masked_values(scanArray, aortaMaskEroded),
        keep_masked_values(scanArray, IVCMaskEroded),
        keep_masked_values(scanArray, portalMaskEroded),
        keep_masked_values(scanArray, kidneyLMask),
        keep_masked_values(scanArray, kidneyRMask),
        _renalPelvisValues(scanArray, kidneyLMask),
        _renalPelvisValues(scanArray, kidneyRMask),
    ]
    regionStats = [get_stats(values) for values in regionValues]
    aortaStats, IVCStats, portalStats = regionStats[:3]

    stats = [value for regionStat in regionStats for value in regionStat]
    # Extra columns for the decision tree: differences of max, min and mean
    # (the first three entries of each stats tuple).
    stats.extend(aortaStats[i] - portalStats[i] for i in range(3))
    stats.extend(aortaStats[i] - IVCStats[i] for i in range(3))
    return stats, kidneyLMask, adRMask
def loadModel():
    """Load the pickled phase-classification model shipped with the package.

    The file is looked up at ``<parent of sys.path[0]>/comp2comp/
    contrast_phase/xgboost.pkl``.

    Returns:
        The unpickled model object.

    Raises:
        OSError: If the model file cannot be opened.
    """
    c2cPath = os.path.dirname(sys.path[0])
    filename = os.path.join(c2cPath, "comp2comp", "contrast_phase", "xgboost.pkl")
    # NOTE: unpickling executes code embedded in the file; this is only safe
    # because the model file is part of the installation, never user input.
    with open(filename, "rb") as modelFile:
        return pickle.load(modelFile)
def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False):
    """Predict the contrast phase of a scan and write the results to disk.

    The predicted phase name is always printed. When ``outputPath`` is given,
    it is also written to ``<outputPath>/metrics/phase_prediction.txt`` and
    two sample slices are saved as ``sampleSliceKidney.png`` and
    ``sampleSliceAdrenal.png`` under ``<outputPath>/images``.

    Args:
        TS_path: Path of the segmentation NIfTI file.
        scan_path: Path of the scan NIfTI file.
        outputPath: Output directory. When ``None`` nothing is written to
            disk (previously this crashed in ``os.path.join``).
        save_sample: Currently not consulted -- the sample slices are written
            whenever ``outputPath`` is given.
            NOTE(review): confirm whether it should gate the image export.

    Raises:
        ValueError: If the model returns a class outside 0-3, or does not
            return exactly one prediction.
    """
    phaseNames = {0: "non-contrast", 1: "arterial", 2: "venous", 3: "delayed"}

    TS_array, image_array = loadNiftis(TS_path, scan_path)
    model = loadModel()
    featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array)

    y_pred = model.predict([featureArray])
    # A single sample was passed in, so exactly one label is expected;
    # .item() raises ValueError otherwise.
    predictedClass = np.asarray(y_pred).item()
    if predictedClass not in phaseNames:
        raise ValueError(f"unexpected phase class predicted: {predictedClass!r}")
    pred_phase = phaseNames[predictedClass]

    if outputPath is None:
        print(pred_phase)
        return

    output_path_metrics = os.path.join(outputPath, "metrics")
    os.makedirs(output_path_metrics, exist_ok=True)
    outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt")
    with open(outputTxt, "w") as text_file:
        text_file.write(pred_phase)
    print(pred_phase)

    output_path_images = os.path.join(outputPath, "images")
    os.makedirs(output_path_images, exist_ok=True)
    scanImage = loadNiiImageWithSitk(scan_path)
    sliceImageK, sliceImageA = selectSampleSlice(kidneyLMask, adRMask, scanImage)
    outPngK = os.path.join(output_path_images, "sampleSliceKidney.png")
    sitk.WriteImage(sliceImageK, outPngK)
    outPngA = os.path.join(output_path_images, "sampleSliceAdrenal.png")
    sitk.WriteImage(sliceImageA, outPngA)
if __name__ == "__main__":

    def _parseBool(text):
        """Convert a command-line token such as ``true``/``false`` to a bool."""
        # argparse's ``type=bool`` treats every non-empty string -- including
        # "False" -- as True, so the token is interpreted explicitly.
        token = text.strip().lower()
        if token in ("1", "true", "t", "yes", "y"):
            return True
        if token in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    # Command-line entry point: parse the arguments and run the prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--TS_path", type=str, required=True, help="Input segmentation NIfTI"
    )
    parser.add_argument(
        "--scan_path", type=str, required=True, help="Input scan NIfTI"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=False,
        help="Output directory for the prediction and sample images",
        default=None,
    )
    parser.add_argument(
        "--save_sample",
        type=_parseBool,
        required=False,
        help="Save sample slice images (true/false)",
        default=False,
    )
    args = parser.parse_args()
    predict_phase(args.TS_path, args.scan_path, args.output_dir, args.save_sample)