Spaces:
Sleeping
Sleeping
""" | |
@author: louisblankemeier | |
""" | |
import math | |
import os | |
import shutil | |
import zipfile | |
from pathlib import Path | |
from time import time | |
from typing import Union | |
import nibabel as nib | |
import numpy as np | |
import pandas as pd | |
import wget | |
from PIL import Image | |
from totalsegmentatorv2.python_api import totalsegmentator | |
from comp2comp.inference_class_base import InferenceClass | |
from comp2comp.io import io_utils | |
from comp2comp.models.models import Models | |
from comp2comp.spine import spine_utils | |
from comp2comp.visualization.dicom import to_dicom | |
# from totalsegmentator.libs import ( | |
# download_pretrained_weights, | |
# nostdout, | |
# setup_nnunet, | |
# ) | |
class SpineSegmentation(InferenceClass):
    """Spine segmentation.

    ``__call__`` runs the ``totalsegmentator`` Python API on
    ``<output_dir>/segmentations/converted_dcm.nii.gz`` and attaches the returned
    segmentation and the loaded input volume to the inference pipeline.
    ``spine_seg`` is an alternative nnU-Net based path that ``__call__`` does not
    currently invoke.
    """

    def __init__(self, model_name, save=True):
        """
        Args:
            model_name (str): Spine model name. ``spine_seg`` accepts
                ``"ts_spine"`` or ``"stanford_spine_v0.0.1"``.
            save (bool, optional): Value copied to
                ``inference_pipeline.save_segmentations``. Defaults to True.
        """
        super().__init__()
        self.model_name = model_name
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        """Segment the converted CT volume and attach the results to the pipeline.

        Reads ``inference_pipeline.output_dir`` and ``inference_pipeline.model_dir``
        and sets ``segmentation``, ``medical_volume`` and ``save_segmentations``
        on the pipeline.

        Returns:
            dict: Always empty.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir
        # os.environ only accepts str values; model_dir may be a Path.
        os.environ["TOTALSEG_WEIGHTS_PATH"] = str(self.model_dir)

        input_path = os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz")
        # The returned image is used directly (skip_saving=True); the output
        # path is still passed because the API takes one.
        seg = totalsegmentator(
            input=input_path,
            output=os.path.join(self.output_dir_segmentations, "segmentation.nii"),
            task_ids=[292],
            ml=True,
            nr_thr_resamp=1,
            nr_thr_saving=6,
            fast=False,
            nora_tag="None",
            preview=False,
            task="total",
            roi_subset=None,
            statistics=False,
            radiomics=False,
            crop_path=None,
            body_seg=False,
            force_split=False,
            output_type="nifti",
            quiet=False,
            verbose=False,
            test=0,
            skip_saving=True,
            device="gpu",
            license_number=None,
            statistics_exclude_masks_at_border=True,
            no_derived_masks=False,
            v1_order=False,
        )
        mv = nib.load(input_path)

        inference_pipeline.segmentation = seg
        inference_pipeline.medical_volume = mv
        inference_pipeline.save_segmentations = self.save_segmentations
        return {}

    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Create the nnU-Net folder layout for the Comp2Comp spine model.

        Adapted from TotalSegmentator. Creates
        ``<model_dir>/.<model_name>/nnunet/results/nnUNet/{3d_fullres,2d}``,
        stores the results folder in ``self.weights_dir`` and points the
        nnU-Net environment variables at it.

        Args:
            model_dir (Union[str, Path]): Root directory for model weights.
        """
        config_dir = Path(model_dir) / f".{self.model_name}"
        weights_dir = config_dir / "nnunet" / "results"
        for subdir in ("3d_fullres", "2d"):
            (weights_dir / "nnUNet" / subdir).mkdir(exist_ok=True, parents=True)
        self.weights_dir = weights_dir
        # The raw-data and preprocessed folders are not needed; they only have
        # to name an existing directory.
        os.environ["nnUNet_raw_data_base"] = str(weights_dir)
        os.environ["nnUNet_preprocessed"] = str(weights_dir)
        os.environ["RESULTS_FOLDER"] = str(weights_dir)

    def download_spine_model(self, model_dir: Union[str, Path]):
        """Download the Stanford spine nnU-Net weights unless ``fold_0`` exists.

        Must be called after ``setup_nnunet_c2c``, which sets ``self.weights_dir``.

        Args:
            model_dir (Union[str, Path]): Unused; kept for interface compatibility.
        """
        download_dir = Path(
            os.path.join(
                self.weights_dir,
                "nnUNet/3d_fullres/Task252_Spine/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
            )
        )
        fold_0_path = download_dir / "fold_0"
        if os.path.exists(fold_0_path):
            print("Spine model already downloaded.")
            return

        download_dir.mkdir(parents=True, exist_ok=True)
        # wget.download returns the path it actually saved to; use that rather
        # than assuming the requested file name.
        zip_path = wget.download(
            "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/fold_0.zip",
            out=os.path.join(download_dir, "fold_0.zip"),
        )
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(download_dir)
        os.remove(zip_path)
        wget.download(
            "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/plans.pkl",
            out=os.path.join(download_dir, "plans.pkl"),
        )
        print("Spine model downloaded.")

    def spine_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run spine segmentation with nnU-Net (not used by ``__call__``).

        Requires ``self.model_dir`` to be set (``__call__`` sets it).

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path; replaced by None when
                ``self.save_segmentations`` is False.
            model_dir: Root weights directory used for ``"stanford_spine_v0.0.1"``.

        Returns:
            tuple: ``(seg, img)``, i.e. the two values unpacked from
            ``nnUNet_predict_image`` in reversed order. For the Stanford model,
            nonzero labels of ``seg`` are shifted down by 17.

        Raises:
            ValueError: If ``self.model_name`` is not a supported model.
        """
        print("Segmenting spine...")
        st = time()
        # os.environ only accepts str values; model_dir may be a Path.
        os.environ["SCRATCH"] = str(self.model_dir)
        os.environ["TOTALSEG_WEIGHTS_PATH"] = str(self.model_dir)

        # Setup nnunet
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [252]

        if self.model_name == "ts_spine":
            # Imported here because the module-level import of these helpers
            # is disabled; without it this branch raised NameError.
            from totalsegmentator.libs import (
                download_pretrained_weights,
                setup_nnunet,
            )

            setup_nnunet()
            download_pretrained_weights(task_id[0])
        elif self.model_name == "stanford_spine_v0.0.1":
            self.setup_nnunet_c2c(model_dir)
            self.download_spine_model(model_dir)
        else:
            raise ValueError("Invalid model name.")

        if not self.save_segmentations:
            output_path = None

        from totalsegmentator.libs import nostdout
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for spine segmentation
        print(f"Total time for spine segmentation: {end-st:.2f}s.")

        if self.model_name == "stanford_spine_v0.0.1":
            seg_data = seg.get_fdata()
            # Subtract 17 from every nonzero label (background 0 is kept).
            seg_data = np.where(seg_data == 0, 0, seg_data - 17)
            seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)

        return seg, img
class AxialCropper(InferenceClass):
    """Crop the CT image (medical_volume) and segmentation based on user-specified
    lower and upper levels of the spine.

    Along the third axis, the crop runs from the lowest voxel labelled
    ``lower_level`` through the highest voxel labelled ``upper_level``, both
    inclusive.
    """

    def __init__(self, lower_level: str = "L5", upper_level: str = "L1", save=True):
        """
        Args:
            lower_level (str, optional): Lower level of the spine. Defaults to "L5".
            upper_level (str, optional): Upper level of the spine. Defaults to "L1".
            save (bool, optional): Save cropped image and segmentation. Defaults to True.

        Raises:
            ValueError: If lower_level or upper_level is not a valid spine level.
        """
        super().__init__()
        self.lower_level = lower_level
        self.upper_level = upper_level
        ts_spine_full_model = Models.model_from_name("ts_spine_full")
        categories = ts_spine_full_model.categories
        try:
            self.lower_level_index = categories[self.lower_level]
            self.upper_level_index = categories[self.upper_level]
        except KeyError:
            raise ValueError("Invalid spine level.") from None
        self.save = save

    @staticmethod
    def _axial_indices(segmentation_data, label, level_name):
        """Return the third-axis indices of every voxel equal to ``label``.

        Raises:
            ValueError: If no voxel carries ``label``.
        """
        axial = np.where(segmentation_data == label)[2]
        if axial.size == 0:
            raise ValueError(
                f"Spine level {level_name} (label {label}) not found in segmentation."
            )
        return axial

    def __call__(self, inference_pipeline):
        """Crop ``segmentation`` and ``medical_volume`` on the pipeline in place.

        Assumed orientation:
        First dim goes from L to R.
        Second dim goes from P to A.
        Third dim goes from I to S.

        When ``self.save`` is True, the cropped images overwrite
        ``segmentations/spine.nii.gz`` and ``segmentations/converted_dcm.nii.gz``
        under ``inference_pipeline.output_dir``.

        Raises:
            ValueError: If either spine level is absent from the segmentation.
        """
        segmentation = inference_pipeline.segmentation
        segmentation_data = segmentation.get_fdata()
        upper_level_index = self._axial_indices(
            segmentation_data, self.upper_level_index, self.upper_level
        ).max()
        lower_level_index = self._axial_indices(
            segmentation_data, self.lower_level_index, self.lower_level
        ).min()
        # Slice stops are exclusive; +1 keeps the top-most slice of the upper level.
        axial_range = slice(lower_level_index, upper_level_index + 1)

        segmentation = segmentation.slicer[:, :, axial_range]
        inference_pipeline.segmentation = segmentation

        medical_volume = inference_pipeline.medical_volume.slicer[:, :, axial_range]
        inference_pipeline.medical_volume = medical_volume

        if self.save:
            segmentation_dir = os.path.join(inference_pipeline.output_dir, "segmentations")
            nib.save(segmentation, os.path.join(segmentation_dir, "spine.nii.gz"))
            nib.save(medical_volume, os.path.join(segmentation_dir, "converted_dcm.nii.gz"))

        return {}
class SpineComputeROIs(InferenceClass):
    """Compute per-level spine ROIs and HU values via ``spine_utils.compute_rois``."""

    def __init__(self, spine_model):
        """
        Args:
            spine_model: Model name, resolved with ``Models.model_from_name``.
        """
        super().__init__()
        self.spine_model_name = spine_model
        self.spine_model_type = Models.model_from_name(spine_model)

    def __call__(self, inference_pipeline):
        """Run ROI computation and store its four outputs on the pipeline.

        Also records the resolved model type as ``spine_model_type``.

        Returns:
            dict: Always empty.
        """
        pipeline = inference_pipeline
        model_type = self.spine_model_type
        pipeline.spine_model_type = model_type

        results = spine_utils.compute_rois(
            pipeline.segmentation,
            pipeline.medical_volume,
            model_type,
        )
        spine_hus, rois, segmentation_hus, centroids_3d = results

        pipeline.spine_hus = spine_hus
        pipeline.segmentation_hus = segmentation_hus
        pipeline.rois = rois
        pipeline.centroids_3d = centroids_3d
        return {}
class SpineMetricsSaver(InferenceClass):
    """Save metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Save metrics to a CSV file.

        Writes ``<output_dir>/metrics/spine_metrics.csv``. If the pipeline has a
        ``dicom_ds`` attribute and ``<output_dir>/dicom_metadata.csv`` does not
        exist yet, that file is written too.

        Returns:
            dict: Always empty.
        """
        self.spine_hus = inference_pipeline.spine_hus
        self.seg_hus = inference_pipeline.segmentation_hus
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()

        metadata_path = os.path.join(self.output_dir, "dicom_metadata.csv")
        if hasattr(inference_pipeline, "dicom_ds") and not os.path.exists(metadata_path):
            io_utils.write_dicom_metadata_to_csv(inference_pipeline.dicom_ds, metadata_path)
        return {}

    def save_results(self):
        """Save results to ``<csv_output_dir>/spine_metrics.csv``.

        Writes one row per level in ``self.spine_hus``, in reverse of the dict's
        iteration order, with columns Level, ROI HU and Seg HU.
        """
        # Build every row first and construct the DataFrame once, instead of
        # enlarging it one ``.loc`` assignment at a time and reversing afterwards.
        rows = [
            [level, self.spine_hus[level], self.seg_hus[level]]
            for level in reversed(list(self.spine_hus))
        ]
        df = pd.DataFrame(rows, columns=["Level", "ROI HU", "Seg HU"])
        df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False)
class SpineFindDicoms(InferenceClass):
    """Locate per-level slice centers and record them on the pipeline."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Find slice centers from ``centroids_3d`` and save the selected slices.

        Sets ``dicom_file_paths`` (centers as strings), ``names`` and
        ``dicom_file_names`` (ROI level names, as two separate lists) and
        ``inferior_superior_centers`` on the pipeline.

        Returns:
            dict: Always empty.
        """
        centers = spine_utils.find_spine_dicoms(inference_pipeline.centroids_3d)
        spine_utils.save_nifti_select_slices(inference_pipeline.output_dir, centers)

        rois = inference_pipeline.rois
        inference_pipeline.dicom_file_paths = list(map(str, centers))
        inference_pipeline.names = list(rois.keys())
        inference_pipeline.dicom_file_names = list(rois.keys())
        inference_pipeline.inferior_superior_centers = centers
        return {}
class SpineCoronalSagittalVisualizer(InferenceClass):
    """Render coronal and sagittal spine visualizations."""

    def __init__(self, format="png"):
        """
        Args:
            format (str, optional): Output format forwarded to
                ``spine_utils.visualize_coronal_sagittal_spine``. Defaults to "png".
        """
        super().__init__()
        self.format = format

    def __call__(self, inference_pipeline):
        """Produce the visualizations and store them on the pipeline.

        Sets ``spine_vis_sagittal``, ``spine_vis_coronal`` and ``spine`` (True).
        Deletes ``<output_dir>/segmentations`` when
        ``inference_pipeline.save_segmentations`` is falsy.

        Returns:
            dict: Always empty.
        """
        pipeline = inference_pipeline
        output_dir = pipeline.output_dir

        img_sagittal, img_coronal = spine_utils.visualize_coronal_sagittal_spine(
            pipeline.segmentation.get_fdata(),
            list(pipeline.rois.values()),
            pipeline.medical_volume.get_fdata(),
            list(pipeline.centroids_3d.values()),
            output_dir,
            spine_hus=pipeline.spine_hus,
            seg_hus=pipeline.segmentation_hus,
            model_type=pipeline.spine_model_type,
            pixel_spacing=pipeline.pixel_spacing_list,
            format=self.format,
        )

        pipeline.spine_vis_sagittal = img_sagittal
        pipeline.spine_vis_coronal = img_coronal
        pipeline.spine = True

        # Remove the on-disk segmentations unless the pipeline keeps them.
        if not pipeline.save_segmentations:
            shutil.rmtree(os.path.join(output_dir, "segmentations"))
        return {}
class SpineReport(InferenceClass):
    """Combine the coronal and sagittal spine visualizations into one image."""

    def __init__(self, format="png"):
        """
        Args:
            format (str, optional): "png" or "dcm". Defaults to "png".
        """
        super().__init__()
        self.format = format

    def __call__(self, inference_pipeline):
        """Write ``<output_dir>/images/spine_report.<format>``.

        The coronal and sagittal arrays are concatenated along axis 1, coronal
        first. Formats other than "png" and "dcm" write nothing.

        Returns:
            dict: Always empty.
        """
        report = np.concatenate(
            (inference_pipeline.spine_vis_coronal, inference_pipeline.spine_vis_sagittal),
            axis=1,
        )
        base_path = os.path.join(inference_pipeline.output_dir, "images", "spine_report")

        if self.format == "png":
            Image.fromarray(report).save(base_path + ".png")
        elif self.format == "dcm":
            to_dicom(report, base_path + ".dcm")
        return {}
class SpineMuscleAdiposeTissueReport(InferenceClass):
    """Spine muscle adipose tissue report class.

    Builds ``spine_muscle_adipose_tissue_report.png`` from the coronal and
    sagittal spine images plus whichever per-level images (T12-L5) exist.
    """

    def __init__(self):
        super().__init__()
        # Index 0: coronal, index 1: sagittal, remaining: per-level images.
        self.image_files = [
            "spine_coronal.png",
            "spine_sagittal.png",
            "T12.png",
            "L1.png",
            "L2.png",
            "L3.png",
            "L4.png",
            "L5.png",
        ]

    def __call__(self, inference_pipeline):
        image_dir = Path(inference_pipeline.output_dir) / "images"
        self.generate_panel(image_dir)
        return {}

    def generate_panel(self, image_dir: Union[str, Path]):
        """Generate panel.

        Layout:
        - The left column holds the coronal image (top) and the sagittal image
          (bottom). Both are thumbnailed to a width derived from the coronal
          aspect ratio and a height of at most 512.
        - The existing per-level images are thumbnailed to fit 512x512 and placed
          in a two-row grid to the right, filled row by row.

        Args:
            image_dir (Union[str, Path]): Path to the image directory.

        Raises:
            FileNotFoundError: If the coronal or sagittal image is missing.
        """
        margin, tile = 8, 512
        step = tile + margin
        # Open coronal/sagittal by explicit name, so a missing file cannot
        # shift another image into their slots.
        coronal_path = os.path.join(image_dir, self.image_files[0])
        sagittal_path = os.path.join(image_dir, self.image_files[1])
        # Per-level images are optional; only those present on disk are placed.
        candidate_paths = [os.path.join(image_dir, name) for name in self.image_files[2:]]
        level_paths = [path for path in candidate_paths if os.path.exists(path)]

        with Image.open(coronal_path) as im_cor, Image.open(sagittal_path) as im_sag:
            im_cor_width = int(im_cor.width / im_cor.height * tile)
            grid_x0 = margin + im_cor_width + margin
            num_cols = math.ceil(len(level_paths) / 2)
            width = grid_x0 + step * num_cols
            height = margin + 2 * step  # two rows of tiles plus margins (1048)
            # Top-left corners of the grid cells, row by row; there are always
            # at least as many cells as level images.
            cells = [
                (x, y)
                for y in range(margin, height, step)
                for x in range(grid_x0, width, step)
            ]

            new_im = Image.new("RGB", (width, height))
            try:
                for cell, path in zip(cells, level_paths):
                    try:
                        with Image.open(path) as im:
                            im.thumbnail((tile, tile))
                            new_im.paste(im, cell)
                    except OSError as err:
                        # Best effort: an unreadable level image leaves its
                        # cell empty instead of aborting the whole panel.
                        print(f"Skipping unreadable image {path}: {err}")

                im_cor.thumbnail((im_cor_width, tile))
                new_im.paste(im_cor, (margin, margin))
                im_sag.thumbnail((im_cor_width, tile))
                new_im.paste(im_sag, (margin, margin + step))
                new_im.save(
                    os.path.join(image_dir, "spine_muscle_adipose_tissue_report.png")
                )
            finally:
                new_im.close()