AdritRao's picture
Upload 62 files
a3290d1
raw
history blame
No virus
12.4 kB
"""
@author: louisblankemeier
"""
import math
import os
import shutil
import cv2
import nibabel as nib
import numpy as np
import scipy.ndimage as ndi
from scipy.ndimage import zoom
from skimage.morphology import ball, binary_erosion
from comp2comp.hip.hip_visualization import method_visualizer
def compute_rois(medical_volume, segmentation, model, output_dir, save=False):
    """Compute head, intertrochanter and neck ROIs for both femurs.

    Args:
        medical_volume: Image whose voxels are sampled by the ROI helpers; its
            ``affine`` is reused when the combined ROI is saved.
        segmentation: Image whose ``get_fdata()`` holds the label volume.
        model: Object whose ``categories`` dict maps ``"femur_left"`` and
            ``"femur_right"`` to label values in ``segmentation``.
        output_dir: Directory forwarded to the ROI helpers. When ``save`` is
            true, the combined ROI is written to
            ``<dirname(output_dir)>/rois/roi.nii.gz``.
        save: If true, save the voxel-wise sum of all six ROIs as NIfTI and copy
            ``tunnelvision.ipynb`` from this module's directory into
            ``dirname(output_dir)``.

    Returns:
        dict with keys ``left_head``, ``right_head``, ``left_intertrochanter``,
        ``right_intertrochanter``, ``left_neck`` and ``right_neck`` (in that
        order), each mapping to ``{"roi": ..., "centroid": ..., "hu": ...}`` as
        returned by the corresponding ROI helper.
    """
    # Read the label volume once and derive both femur masks from it.
    labels = segmentation.get_fdata()
    femur_masks = {
        "left": (labels == model.categories["femur_left"]).astype(np.uint8),
        "right": (labels == model.categories["femur_right"]).astype(np.uint8),
    }

    results = {}
    # Heads first, then intertrochanters, left before right: this keeps both the
    # helper call order and the key order of the returned dict.
    for region in ("head", "intertrochanter"):
        for side in ("left", "right"):
            anatomy = f"{side}_{region}"
            roi, centroid, hu = get_femural_head_roi(
                femur_masks[side], medical_volume, output_dir, anatomy
            )
            results[anatomy] = {"roi": roi, "centroid": centroid, "hu": hu}

    # The neck cylinder spans the gap between each side's intertrochanter and head ROIs.
    for side in ("left", "right"):
        intertrochanter = results[f"{side}_intertrochanter"]
        head = results[f"{side}_head"]
        roi, centroid, hu = get_femural_neck_roi(
            femur_masks[side],
            medical_volume,
            intertrochanter["roi"],
            intertrochanter["centroid"],
            head["roi"],
            head["centroid"],
            output_dir,
        )
        results[f"{side}_neck"] = {"roi": roi, "centroid": centroid, "hu": hu}

    if save:
        # Plain voxel-wise sum of the six uint8 masks; a voxel covered by more
        # than one ROI gets a value above 1.
        combined_roi = sum(entry["roi"] for entry in results.values())

        parent_output_dir = os.path.dirname(output_dir)
        roi_output_dir = os.path.join(parent_output_dir, "rois")
        os.makedirs(roi_output_dir, exist_ok=True)

        roi_nifti = nib.Nifti1Image(combined_roi, medical_volume.affine)
        nib.save(roi_nifti, os.path.join(roi_output_dir, "roi.nii.gz"))

        notebook_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "tunnelvision.ipynb",
        )
        shutil.copy(notebook_path, parent_output_dir)

    return results
def _find_two_component_slice(femur_mask, min_pixel_count):
    """Return ``(top_mask, labeled, labels)`` for the highest two-component slice.

    Slices along the last axis are scanned from the highest occupied index down
    to 0. The first slice with exactly two connected components of at least
    ``min_pixel_count`` pixels is returned, together with its label image and the
    two component labels.

    Raises:
        ValueError: if ``femur_mask`` is empty or no slice qualifies.
    """
    occupied_slices = np.flatnonzero(femur_mask.sum(axis=(0, 1)))
    if occupied_slices.size == 0:
        raise ValueError("Femur mask is empty.")
    for top in range(occupied_slices.max(), -1, -1):
        top_mask = femur_mask[:, :, top]
        labeled, _ = ndi.label(top_mask)
        component_sizes = np.bincount(labeled.ravel())
        # Exclude label 0 (background) explicitly instead of assuming it is the
        # first label that passes the size threshold.
        valid_components = np.flatnonzero(component_sizes[1:] >= min_pixel_count) + 1
        if len(valid_components) == 2:
            return top_mask, labeled, valid_components
    print("Two connected components not found in the femur mask.")
    raise ValueError("Two connected components not found in the femur mask.")


def get_femural_head_roi(
    femur_mask,
    medical_volume,
    output_dir,
    anatomy,
    visualize_method=False,
    min_pixel_count=20,
):
    """Locate a femoral head or intertrochanter ROI and its mean intensity.

    Steps:
    1. Find the highest slice (largest index on the last axis) of ``femur_mask``
       that holds two connected components.
    2. Seed the centroid from the component selected by ``anatomy``.
    3. Refine the centroid three times. Each pass inscribes the largest circle in
       a sagittal slice (resampled along its second axis by
       ``zooms[2] / zooms[1]``), then in an axial slice with the other
       component's side blanked out.
    4. Build an ellipsoid with ``compute_hip_roi`` and intersect it with
       ``femur_mask`` eroded by ``ball(3)``.

    Args:
        femur_mask: 3-D binary mask indexed ``[i, j, k]`` (compute_rois passes
            uint8). It is not modified.
        medical_volume: Image providing ``header.get_zooms()`` and ``get_fdata()``.
        output_dir: Forwarded to ``method_visualizer``.
        anatomy: One of ``"left_head"``, ``"right_head"``,
            ``"left_intertrochanter"``, ``"right_intertrochanter"``.
        visualize_method: If true, call ``method_visualizer`` with the final
            slices and inscribed circles.
        min_pixel_count: Minimum component size, in pixels, for a component to
            count as one of the two.

    Returns:
        ``(roi, centroid, hu)``: uint8 ROI mask, integer voxel centroid
        ``[i, j, k]``, and the mean intensity from ``get_mean_roi_hu``.

    Raises:
        ValueError: if ``anatomy`` is unknown, the mask is empty, or no slice
            contains two components.
    """
    print(f"======== Computing {anatomy} femur ROIs ========")
    if anatomy in ("left_intertrochanter", "right_head"):
        take_lower = True
    elif anatomy in ("right_intertrochanter", "left_head"):
        take_lower = False
    else:
        raise ValueError(f"Unsupported anatomy: {anatomy!r}")

    top_mask, labeled, valid_components = _find_two_component_slice(
        femur_mask, min_pixel_count
    )
    center_of_mass_1 = list(ndi.center_of_mass(top_mask, labeled, valid_components[0]))
    center_of_mass_2 = list(ndi.center_of_mass(top_mask, labeled, valid_components[1]))
    # "Left" is the component with the lower index along axis 0.
    if center_of_mass_1[0] < center_of_mass_2[0]:
        left_center_of_mass, right_center_of_mass = center_of_mass_1, center_of_mass_2
    else:
        left_center_of_mass, right_center_of_mass = center_of_mass_2, center_of_mass_1
    print(f"Left center of mass: {left_center_of_mass}")
    print(f"Right center of mass: {right_center_of_mass}")
    center_of_mass = left_center_of_mass if take_lower else right_center_of_mass

    zooms = medical_volume.header.get_zooms()
    # Stretch sagittal slices along their second axis so both in-plane pixel
    # spacings equal zooms[1].
    zoom_factor = zooms[2] / zooms[1]

    centroid = [round(center_of_mass[0]), 0, 0]
    print(f"Starting centroid: {centroid}")
    for _ in range(3):
        sagittal_slice = zoom(
            femur_mask[centroid[0], :, :], (1, zoom_factor), order=1
        ).round()
        centroid[1], centroid[2], radius_sagittal = inscribe_sagittal(
            sagittal_slice, zoom_factor
        )
        print(f"Centroid after inscribe sagittal: {centroid}")
        # Basic slicing returns a view; copy so that blanking the other
        # component does not write zeros into the caller's femur_mask.
        axial_slice = femur_mask[:, :, centroid[2]].copy()
        if take_lower:
            axial_slice[round(right_center_of_mass[0]) :, :] = 0
        else:
            axial_slice[: round(left_center_of_mass[0]), :] = 0
        centroid[0], centroid[1], radius_axial = inscribe_axial(axial_slice)
        print(f"Centroid after inscribe axial: {centroid}")

    if visualize_method:
        volume_data = medical_volume.get_fdata()
        axial_image = volume_data[:, :, round(centroid[2])]
        sagittal_image = zoom(
            volume_data[round(centroid[0]), :, :], (1, zoom_factor), order=3
        ).round()
        method_visualizer(
            sagittal_image,
            axial_image,
            axial_slice,
            sagittal_slice,
            [centroid[2], centroid[1]],
            radius_sagittal,
            [centroid[1], centroid[0]],
            radius_axial,
            output_dir,
            anatomy,
        )

    roi = compute_hip_roi(medical_volume, centroid, radius_sagittal, radius_axial)
    # Keep only ROI voxels that survive eroding the femur mask by ball(3).
    femur_mask_eroded = binary_erosion(femur_mask, ball(3))
    roi_eroded = (roi * femur_mask_eroded).astype(np.uint8)
    hu = get_mean_roi_hu(medical_volume, roi_eroded)
    return (roi_eroded, centroid, hu)
def get_femural_neck_roi(
    femur_mask,
    medical_volume,
    intertrochanter_roi,
    intertrochanter_centroid,
    head_roi,
    head_centroid,
    output_dir,
    cylinder_radius=10,
):
    """Build a cylindrical femoral-neck ROI between the intertrochanter and head ROIs.

    The cylinder's axis runs from ``intertrochanter_centroid`` towards
    ``head_centroid``. Along that axis, and in voxel-index units, it starts at
    the farthest projection of any ``intertrochanter_roi`` voxel and ends at the
    nearest projection of any ``head_roi`` voxel. Its radius is measured after
    scaling each axis by the header zooms. The cylinder is intersected with
    ``femur_mask`` eroded by ``ball(3)``.

    Args:
        femur_mask: 3-D binary femur mask.
        medical_volume: Image providing ``header.get_zooms()`` and ``get_fdata()``.
        intertrochanter_roi, head_roi: Binary masks shaped like ``femur_mask``.
        intertrochanter_centroid, head_centroid: Voxel-index centroids ``[i, j, k]``.
        output_dir: Unused; kept for signature compatibility.
        cylinder_radius: Cylinder radius in the units of
            ``medical_volume.header.get_zooms()``. Defaults to 10.

    Returns:
        ``(neck_roi, centroid, hu)``: uint8 mask shaped like ``femur_mask``, the
        rounded voxel midpoint of the cylinder's axial extent, and the mean
        intensity from ``get_mean_roi_hu``.

    Raises:
        ValueError: if the two centroids coincide or either input ROI is empty.
    """
    zooms = np.asarray(medical_volume.header.get_zooms()[:3], dtype=float)
    start = np.asarray(intertrochanter_centroid)
    end = np.asarray(head_centroid)

    direction_vector = end - start
    axis_length = np.linalg.norm(direction_vector)
    if axis_length == 0:
        raise ValueError(
            "Head and intertrochanter centroids coincide; the neck axis is undefined."
        )
    unit_direction_vector = direction_vector / axis_length

    intertrochanter_points = np.argwhere(intertrochanter_roi)
    head_points = np.argwhere(head_roi)
    if len(intertrochanter_points) == 0 or len(head_points) == 0:
        raise ValueError("Intertrochanter and head ROIs must both be non-empty.")
    # Axial extent of the cylinder, measured from the intertrochanter centroid.
    t_start = np.dot(intertrochanter_points - start, unit_direction_vector).max()
    t_end = axis_length + np.dot(head_points - end, unit_direction_vector).min()

    femur_mask_eroded = binary_erosion(femur_mask, ball(3))
    neck_roi = np.zeros(femur_mask.shape, dtype=np.uint8)
    inside = np.nonzero(femur_mask_eroded)
    if inside[0].size:
        # Only voxels inside the eroded femur can end up in the ROI. Evaluate the
        # cylinder on their bounding box rather than allocating coordinate and
        # cross-product arrays for the whole volume.
        box = tuple(slice(idx.min(), idx.max() + 1) for idx in inside)
        offset = np.array([s.start for s in box])
        coordinates = (
            np.moveaxis(np.indices(femur_mask_eroded[box].shape), 0, -1) + offset
        )

        distance_to_line_origin = np.dot(coordinates - start, unit_direction_vector)

        coordinates_zoomed = coordinates * zooms
        start_zoomed = start * zooms
        unit_direction_vector_zoomed = unit_direction_vector * zooms
        # |(P - A) x (P - B)| / |B - A| is the distance from P to the line through A and B.
        distance_to_line = np.linalg.norm(
            np.cross(
                coordinates_zoomed - start_zoomed,
                coordinates_zoomed - (start_zoomed + unit_direction_vector_zoomed),
            ),
            axis=-1,
        ) / np.linalg.norm(unit_direction_vector_zoomed)

        cylinder_mask = (
            (distance_to_line <= cylinder_radius)
            & (distance_to_line_origin >= t_start)
            & (distance_to_line_origin <= t_end)
        )
        neck_roi[box] = cylinder_mask & femur_mask_eroded[box]

    hu = get_mean_roi_hu(medical_volume, neck_roi)
    midpoint = start + unit_direction_vector * (t_start + t_end) / 2
    centroid = [round(value) for value in midpoint]
    return neck_roi, centroid, hu
def compute_hip_roi(img, centroid, radius_sagittal, radius_axial):
    """Rasterize an axis-aligned ellipsoid ROI centred on ``centroid``.

    The semi-axes, in voxels, are ``0.75 * radius / spacing[axis]``.
    ``radius_axial`` is used for axes 0 and 1 and ``radius_sagittal`` for axis 2.
    A voxel ``(i, j, k)`` is included when
    ``sum(((idx - c) / semi_axis) ** 2) <= 1``. Voxels that fall outside the
    volume are dropped; they neither wrap around through negative indices nor
    raise ``IndexError``.

    Args:
        img: Image providing ``header.get_zooms()`` and ``shape``.
        centroid: Voxel indices ``[i, j, k]`` of the ellipsoid centre.
        radius_sagittal: Radius used for axis 2.
        radius_axial: Radius used for axes 0 and 1.

    Returns:
        uint8 array of ``img.shape`` with 1 inside the ellipsoid and 0 elsewhere.
        If any semi-axis is zero, the array is all zeros.
    """
    pixel_spacing = img.header.get_zooms()
    # NOTE(review): the radii come from pixel-unit distance transforms but are
    # divided by the voxel spacing here; confirm the intended units.
    semi_axes = (
        float(radius_axial * 0.75 / pixel_spacing[0]),
        float(radius_axial * 0.75 / pixel_spacing[1]),
        float(radius_sagittal * 0.75 / pixel_spacing[2]),
    )
    roi = np.zeros(img.shape, dtype=np.uint8)
    if min(semi_axes) <= 0:
        # A degenerate ellipsoid contains no voxels.
        return roi

    # Same per-axis search window as a floor/ceil bounding box around the
    # centre, clipped to the volume.
    window = []
    for center, semi_axis, size in zip(centroid, semi_axes, img.shape):
        lower = math.floor(center - semi_axis)
        upper = lower + 2 * math.ceil(semi_axis) + 1
        window.append(slice(max(lower, 0), min(upper, size)))
    window = tuple(window)

    i, j, k = np.ogrid[window]
    inside = (
        (i - centroid[0]) ** 2 / semi_axes[0] ** 2
        + (j - centroid[1]) ** 2 / semi_axes[1] ** 2
        + (k - centroid[2]) ** 2 / semi_axes[2] ** 2
    ) <= 1
    roi[window] = inside
    return roi
def inscribe_axial(axial_mask):
    """Find the centre and radius of the largest circle inside a 2-D axial mask.

    The centre is the pixel with the largest value of
    ``cv2.distanceTransform`` (L2 distance to the nearest zero pixel).

    Args:
        axial_mask: 2-D single-channel mask accepted by ``cv2.distanceTransform``.
            The masks built in this module are uint8.

    Returns:
        ``(row, column, radius)``. ``row`` indexes axis 0 and ``column`` indexes
        axis 1 of ``axial_mask``; ``radius`` is the maximum distance in pixels.
    """
    distance_map = cv2.distanceTransform(axial_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    # minMaxLoc reports locations as (x, y), i.e. (column, row).
    _, max_distance, _, (column, row) = cv2.minMaxLoc(distance_map)
    return round(row), round(column), max_distance
def inscribe_sagittal(sagittal_mask, zoom_factor):
    """Find the largest circle inside a resampled 2-D sagittal mask.

    Args:
        sagittal_mask: 2-D single-channel mask accepted by
            ``cv2.distanceTransform``, whose axis 1 was stretched by
            ``zoom_factor``.
        zoom_factor: Factor applied to axis 1 before this call. The axis-1
            position is divided by it to map back to the unresampled index.

    Returns:
        ``(row, slice_index, radius)``. ``row`` indexes axis 0; ``slice_index``
        is the axis-1 position divided by ``zoom_factor`` and rounded;
        ``radius`` is the maximum distance in resampled pixels.
    """
    distances = cv2.distanceTransform(
        sagittal_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
    )
    extrema = cv2.minMaxLoc(distances)
    radius_sagittal = extrema[1]
    # The location of the maximum is reported as (x, y) = (axis 1, axis 0).
    x_resampled, y = extrema[3]
    slice_index = round(round(x_resampled) / zoom_factor)
    return round(y), slice_index, radius_sagittal
def get_mean_roi_hu(medical_volume, roi):
    """Return the mean voxel value of ``medical_volume`` inside ``roi``.

    Voxels are selected wherever ``roi`` is non-zero, so voxels whose intensity
    is exactly 0 are still averaged. A voxel counts once regardless of its ROI
    label value.

    Args:
        medical_volume: Object providing ``get_fdata()``.
        roi: Mask with the same shape as the volume data.

    Returns:
        The mean intensity, or NaN when ``roi`` selects no voxels.
    """
    values = medical_volume.get_fdata()[np.asarray(roi) != 0]
    if values.size == 0:
        return np.nan
    return np.mean(values)