Spaces:
Sleeping
Sleeping
import math | |
import operator | |
import os | |
import zipfile | |
from pathlib import Path | |
from time import time | |
from tkinter import Tcl | |
from typing import Union | |
import cv2 | |
import matplotlib.pyplot as plt | |
import moviepy.video.io.ImageSequenceClip | |
import nibabel as nib | |
import numpy as np | |
import pandas as pd | |
import pydicom | |
import wget | |
from totalsegmentator.libs import nostdout | |
from comp2comp.inference_class_base import InferenceClass | |
class AortaSegmentation(InferenceClass):
    """Segment the aorta with an nnU-Net model run through TotalSegmentator.

    The weights used are the ``Task253_Aorta`` model. Several method names,
    log messages and the output file name still say "spine"; they are kept
    unchanged so existing callers and downstream files keep working
    (presumably this class was adapted from a spine pipeline).
    """

    def __init__(self, save=True):
        """Create the segmentation step.

        Args:
            save: Stored as ``self.save_segmentations``; it is not read
                anywhere else in this class.
        """
        super().__init__()
        self.model_name = "totalsegmentator"
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        """Run the segmentation and publish per-slice arrays on the pipeline.

        Reads ``output_dir`` and ``model_dir`` from ``inference_pipeline``,
        segments ``<output_dir>/segmentations/converted_dcm.nii.gz`` and sets
        ``inference_pipeline.ct_image`` and ``inference_pipeline.axial_masks``
        to lists of 2D slices taken along the third array axis.

        Returns:
            dict: Always an empty dict.
        """
        # inference_pipeline.dicom_series_path = self.input_path
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.spine_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            os.path.join(self.output_dir_segmentations, "spine.nii.gz"),
            inference_pipeline.model_dir,
        )

        seg = seg.get_fdata()
        medical_volume = mv.get_fdata()

        # Save input axial slices to pipeline
        inference_pipeline.ct_image = [
            medical_volume[:, :, i] for i in range(medical_volume.shape[2])
        ]
        # Save aorta masks to pipeline
        inference_pipeline.axial_masks = [seg[:, :, i] for i in range(seg.shape[2])]

        return {}

    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Adapted from TotalSegmentator.

        Creates the nnU-Net results tree under ``<model_dir>/.<model_name>``,
        stores its root in ``self.weights_dir`` and points the nnU-Net
        environment variables at it.

        Args:
            model_dir: Directory under which the hidden config folder is made.
        """
        model_dir = Path(model_dir)
        config_dir = model_dir / Path("." + self.model_name)
        (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(
            exist_ok=True, parents=True
        )
        (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
        weights_dir = config_dir / "nnunet/results"
        self.weights_dir = weights_dir

        # The first two are not needed; they just need to be existing directories.
        os.environ["nnUNet_raw_data_base"] = str(weights_dir)
        os.environ["nnUNet_preprocessed"] = str(weights_dir)
        os.environ["RESULTS_FOLDER"] = str(weights_dir)

    def download_spine_model(self, model_dir: Union[str, Path]):
        """Download the aorta model weights if any part of them is missing.

        ``setup_nnunet_c2c`` must have been called first, because the target
        directory is derived from ``self.weights_dir``.

        Args:
            model_dir: Unused; kept so the signature stays compatible.
        """
        download_dir = (
            Path(self.weights_dir)
            / "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1"
        )
        print(download_dir)

        fold_0_path = download_dir / "fold_0"
        plans_path = download_dir / "plans.pkl"

        # Check both artefacts: a run interrupted between the two downloads
        # would otherwise be reported as complete forever.
        if fold_0_path.exists() and plans_path.exists():
            print("Spine model already downloaded.")
            return

        download_dir.mkdir(parents=True, exist_ok=True)

        if not fold_0_path.exists():
            zip_path = download_dir / "fold_0.zip"
            # A leftover archive from an aborted run would make wget save the
            # new file under a different name while the old one gets extracted.
            if zip_path.exists():
                zip_path.unlink()
            try:
                wget.download(
                    "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip",
                    out=str(zip_path),
                )
                with zipfile.ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(download_dir)
            finally:
                # Never leave the archive behind, even when extraction fails.
                if zip_path.exists():
                    zip_path.unlink()

        if not plans_path.exists():
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl",
                out=str(plans_path),
            )

        print("Spine model downloaded.")

    def spine_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run the segmentation model on one NIfTI volume.

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path.
            model_dir: Directory that holds (or will receive) the model weights.

        Returns:
            tuple: ``(seg, img)`` where ``seg`` is a ``nib.Nifti1Image`` rebuilt
            from the predicted segmentation and ``img`` is the first value
            returned by ``nnUNet_predict_image``.
        """
        print("Segmenting spine...")
        st = time()
        # os.environ only accepts str, and the parameter is used instead of
        # self.model_dir so the method also works when called directly.
        os.environ["SCRATCH"] = str(model_dir)
        print(model_dir)

        # Setup nnunet
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        task_id = [253]

        self.setup_nnunet_c2c(model_dir)
        self.download_spine_model(model_dir)

        # Imported after the nnU-Net environment variables have been set
        # (presumably required by nnU-Net; the original order is preserved).
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=None,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for spine segmentation
        print(f"Total time for spine segmentation: {end-st:.2f}s.")

        # Rebuild the image around the float data array, as the original did.
        seg_data = seg.get_fdata()
        seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)

        return seg, img
class AortaDiameter(InferenceClass):
    """Measure the aorta on each axial slice and render annotated outputs.

    For every slice with a usable mask contour an ellipse is fitted, its minor
    axis is taken as the diameter, and an annotated PNG is written. A bar chart,
    a side-by-side image of the widest slice and an MP4 of all annotated slices
    are written to the summary directory.
    """

    def __init__(self):
        super().__init__()

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Normalize the image.

        Args:
            img (np.ndarray): Input image.

        Returns:
            np.ndarray: Image rescaled to [0, 1]; all zeros when the input is
            constant (which would otherwise divide by zero).
        """
        lowest = img.min()
        span = img.max() - lowest
        if span == 0:
            return np.zeros_like(img, dtype=float)
        return (img - lowest) / span

    def _slice_to_bgr(self, ct_slice: np.ndarray) -> np.ndarray:
        """Clip a CT slice to [-300, 1800], scale to 0-255 and make it 3-channel."""
        windowed = np.clip(ct_slice, -300, 1800)
        scaled = self.normalize_img(windowed) * 255.0
        return np.tile(scaled[:, :, np.newaxis], (1, 1, 3))

    @staticmethod
    def _axis_endpoints(xc, yc, angle_deg, radius):
        """Return the two integer end points of an axis through (xc, yc)."""
        forward = math.radians(angle_deg)
        backward = math.radians(angle_deg + 180)
        start = (
            int(xc + math.cos(forward) * radius),
            int(yc + math.sin(forward) * radius),
        )
        end = (
            int(xc + math.cos(backward) * radius),
            int(yc + math.sin(backward) * radius),
        )
        return start, end

    def __call__(self, inference_pipeline):
        """Measure every slice, write the images and set the maximum diameter.

        Reads ``axial_masks``, ``ct_image``, ``output_dir`` and
        ``dicom_series_path`` from ``inference_pipeline`` and sets
        ``inference_pipeline.max_diameter`` (in cm).

        Returns:
            dict: Always an empty dict.

        Raises:
            ValueError: If no slice yields a measurable aorta contour.
        """
        axial_masks = inference_pipeline.axial_masks  # list of 2D mask arrays
        ct_img = inference_pipeline.ct_image  # list of 2D CT slice arrays

        # image output directories
        output_dir = inference_pipeline.output_dir
        output_dir_slices = os.path.join(output_dir, "images/slices/")
        output_dir_summary = os.path.join(output_dir, "images/summary/")
        os.makedirs(output_dir_slices, exist_ok=True)
        os.makedirs(output_dir_summary, exist_ok=True)

        dicom_path = inference_pipeline.dicom_series_path
        dicom = pydicom.dcmread(os.path.join(dicom_path, os.listdir(dicom_path)[0]))
        # NOTE(review): no pixel data is decoded below, so this assignment
        # appears to have no effect; kept from the original.
        dicom.PhotometricInterpretation = "YBR_FULL"

        pixel_conversion = dicom.PixelSpacing
        print("Pixel conversion: " + str(pixel_conversion))
        row_spacing_mm = float(pixel_conversion[0])
        col_spacing_mm = float(pixel_conversion[1])

        # Debug output only; .get() avoids a KeyError when the tag is absent.
        print(dicom.get("InstanceNumber"))

        slice_count = len(ct_img)
        diameter_by_slice = {}

        for i, (mask, ct_slice) in enumerate(zip(axial_masks, ct_img)):
            contours, _ = cv2.findContours(
                mask.astype("uint8"), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE
            )
            if not contours:
                continue

            # Largest connected region on this slice (first one wins on ties).
            contour = max(contours, key=cv2.contourArea)
            if len(contour) < 5:
                # cv2.fitEllipse needs at least five points.
                continue

            ellipse = cv2.fitEllipse(contour)
            (xc, yc), (d1, d2), angle = ellipse
            if not all(math.isfinite(v) for v in (xc, yc, d1, d2, angle)):
                # Degenerate fit: nothing meaningful to draw or measure.
                continue

            img = self._slice_to_bgr(ct_slice)

            # Blend a filled contour over the slice at 25 % opacity.
            overlay = img.copy()
            cv2.drawContours(overlay, [contour], 0, (0, 255, 0), -1)
            alpha = 0.25
            img = cv2.addWeighted(img, 1 - alpha, overlay, alpha, 0)

            cv2.ellipse(img, ellipse, (0, 255, 0), 1)
            cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)

            rmajor = max(d1, d2) / 2
            rminor = min(d1, d2) / 2

            # Draw major axis: turn the fitted angle by a quarter turn.
            angle = angle - 90 if angle > 90 else angle + 90
            print(angle)
            major_start, major_end = self._axis_endpoints(xc, yc, angle, rmajor)
            cv2.line(img, major_start, major_end, (0, 0, 255), 3)

            # Draw minor axis: a further quarter turn from the major axis.
            angle = angle - 90 if angle > 90 else angle + 90
            print(angle)
            minor_start, minor_end = self._axis_endpoints(xc, yc, angle, rminor)
            cv2.line(img, minor_start, minor_end, (255, 0, 0), 3)

            # The minor axis of the fitted ellipse is reported as the diameter.
            pixel_length = rminor * 2
            print("Pixel_length_minor: " + str(pixel_length))

            # Area scales with both in-plane spacings (px^2 -> mm^2), and
            # 1 cm^2 is 100 mm^2.
            area_px = cv2.contourArea(contour)
            area_mm = round(area_px * row_spacing_mm * col_spacing_mm)
            area_cm = area_mm / 100

            diameter_mm = round(pixel_length * row_spacing_mm)
            diameter_cm = diameter_mm / 10

            # Slices are numbered in reverse of their array index.
            slice_number = slice_count - i
            diameter_by_slice[slice_number] = diameter_cm

            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
            h, w, _ = img.shape

            lbls = [
                "Area (mm^2): " + str(area_mm) + "mm^2",
                "Area (cm^2): " + str(area_cm) + "cm^2",
                "Diameter (mm): " + str(diameter_mm) + "mm",
                "Diameter (cm): " + str(diameter_cm) + "cm",
                "Slice: " + str(slice_number),
            ]
            font = cv2.FONT_HERSHEY_SIMPLEX
            scale = 0.03
            font_scale = min(w, h) / (25 / scale)
            for row, text in enumerate(lbls):
                cv2.putText(
                    img, text, (10, 40 + 30 * row), font, font_scale, (0, 255, 0), 2
                )

            cv2.imwrite(
                os.path.join(output_dir_slices, "slice" + str(slice_number) + ".png"),
                img,
            )

        if not diameter_by_slice:
            raise ValueError(
                "No aorta contour could be measured on any axial slice."
            )

        # Use a dedicated figure and close it so repeated runs do not stack
        # bars on the shared pyplot figure.
        plt.figure()
        plt.bar(
            list(diameter_by_slice.keys()),
            list(diameter_by_slice.values()),
            color="b",
        )
        plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
        plt.xlabel("Slice Number")
        plt.ylabel("Diameter Measurement (cm)")
        plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500)
        plt.close()

        max_slice, max_diameter = max(
            diameter_by_slice.items(), key=operator.itemgetter(1)
        )
        print(diameter_by_slice)
        print(max_slice)
        print(max_diameter)
        inference_pipeline.max_diameter = max_diameter

        # Side-by-side summary: plain slice on the left, annotated on the right.
        plain = self._slice_to_bgr(ct_img[slice_count - max_slice])
        plain = cv2.rotate(plain, cv2.ROTATE_90_COUNTERCLOCKWISE)
        annotated = cv2.imread(
            os.path.join(output_dir_slices, "slice" + str(max_slice) + ".png")
        )

        border_size = 3
        annotated = cv2.copyMakeBorder(
            annotated,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[0, 244, 0],
        )
        plain = cv2.copyMakeBorder(
            plain,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[244, 0, 0],
        )

        vis = np.concatenate((plain, annotated), axis=1)
        cv2.imwrite(output_dir_summary + "out.png", vis)

        # Video of all annotated slices; Tcl's dictionary sort orders
        # "slice2.png" before "slice10.png".
        fps = 20
        image_files = [
            os.path.join(output_dir_slices, name)
            for name in Tcl().call("lsort", "-dict", os.listdir(output_dir_slices))
            if name.endswith(".png")
        ]
        clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(
            image_files, fps=fps
        )
        clip.write_videofile(output_dir_summary + "aaa.mp4")

        return {}
class AortaMetricsSaver(InferenceClass):
    """Save metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Copy the needed values off the pipeline and write the CSV.

        Reads ``max_diameter``, ``dicom_series_path`` and ``output_dir`` from
        ``inference_pipeline`` and writes to ``<output_dir>/metrics``.

        Returns:
            dict: Always an empty dict.
        """
        self.max_diameter = inference_pipeline.max_diameter
        self.dicom_series_path = inference_pipeline.dicom_series_path
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        # exist_ok already covers the "directory exists" case.
        os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()
        return {}

    def save_results(self):
        """Write ``aorta_metrics.csv`` with the series name and max diameter."""
        # Normalise first so a trailing separator ("/data/series/") still
        # yields "series" rather than an empty file name.
        filename = os.path.basename(os.path.normpath(self.dicom_series_path))
        data = [[filename, str(self.max_diameter)]]
        df = pd.DataFrame(data, columns=["Filename", "Max Diameter"])
        df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)