AdritRao's picture
Upload 62 files
a3290d1
raw
history blame
No virus
10.1 kB
"""
@author: louisblankemeier
"""
import os
from pathlib import Path
from time import time
from typing import Union
import pandas as pd
from totalsegmentator.libs import (
download_pretrained_weights,
nostdout,
setup_nnunet,
)
from comp2comp.hip import hip_utils
from comp2comp.hip.hip_visualization import (
hip_report_visualizer,
hip_roi_visualizer,
)
from comp2comp.inference_class_base import InferenceClass
from comp2comp.models.models import Models
class HipSegmentation(InferenceClass):
    """Hip segmentation stage (TotalSegmentator nnU-Net, task 254)."""

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.model = Models.model_from_name(model_name)

    def __call__(self, inference_pipeline):
        """Segment the hip and attach the results to ``inference_pipeline``.

        Reads ``<output_dir>/segmentations/converted_dcm.nii.gz`` and passes
        ``<output_dir>/segmentations/hip.nii.gz`` as the output path.

        Sets ``inference_pipeline.model``, ``.segmentation`` and
        ``.medical_volume``.

        Returns:
            An empty dict.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.output_dir_segmentations, exist_ok=True)
        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.hip_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            os.path.join(self.output_dir_segmentations, "hip.nii.gz"),
            inference_pipeline.model_dir,
        )

        inference_pipeline.model = self.model
        inference_pipeline.segmentation = seg
        inference_pipeline.medical_volume = mv
        return {}

    def hip_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run hip segmentation.

        Args:
            input_path (Union[str, Path]): Input image path.
            output_path (Union[str, Path]): Output path forwarded to
                ``nnUNet_predict_image``.
            model_dir: Directory exported as the ``SCRATCH`` environment
                variable before the nnU-Net setup and weight download.

        Returns:
            ``(seg, img)``: the two values returned by
            ``nnUNet_predict_image``, in swapped order (it returns
            ``img, seg``).

        Raises:
            ValueError: If ``self.model_name`` is not ``"ts_hip"``.
        """
        # Validate before touching the process environment.
        if self.model_name != "ts_hip":
            raise ValueError("Invalid model name.")

        print("Segmenting hip...")
        st = time()
        # Use the argument (not self.model_dir, which only exists after
        # __call__ ran); fspath also accepts pathlib.Path values.
        os.environ["SCRATCH"] = os.fspath(model_dir)

        # nnU-Net configuration
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [254]

        setup_nnunet()
        download_pretrained_weights(task_id[0])

        # Imported lazily, after setup_nnunet() has run.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag=None,
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for hip segmentation
        print(f"Total time for hip segmentation: {end-st:.2f}s.")

        return seg, img
class HipComputeROIs(InferenceClass):
    """Compute femoral ROIs from the hip segmentation via ``hip_utils``."""

    def __init__(self, hip_model):
        super().__init__()
        self.hip_model_name = hip_model
        # Resolved here, but __call__ uses inference_pipeline.model instead.
        self.hip_model_type = Models.model_from_name(hip_model)

    def __call__(self, inference_pipeline):
        """Run ``hip_utils.compute_rois`` and store its result.

        The result is written to ``inference_pipeline.femur_results_dict``.
        ``<output_dir>/dev`` is passed to ``compute_rois`` as its images
        folder.

        Returns:
            An empty dict.
        """
        seg = inference_pipeline.segmentation
        volume = inference_pipeline.medical_volume
        pipeline_model = inference_pipeline.model
        dev_dir = os.path.join(inference_pipeline.output_dir, "dev")

        inference_pipeline.femur_results_dict = hip_utils.compute_rois(
            volume, seg, pipeline_model, dev_dir
        )
        return {}
class HipMetricsSaver(InferenceClass):
    """Save hip HU metrics to ``<output_dir>/metrics/hip_metrics.csv``."""

    # (results key part, CSV label part). CSV columns are emitted per site,
    # left then right: Left Head, Right Head, Left Intertrochanter, ...
    _SITES = (
        ("head", "Head"),
        ("intertrochanter", "Intertrochanter"),
        ("neck", "Neck"),
    )
    _SIDES = (("left", "Left"), ("right", "Right"))

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Write one CSV row with the ``"hu"`` value of each femoral ROI.

        Reads ``inference_pipeline.femur_results_dict[f"{side}_{site}"]["hu"]``
        for every side/site pair.

        Returns:
            An empty dict.

        Raises:
            KeyError: If an expected ROI entry is missing from the results.
        """
        metrics_output_dir = os.path.join(inference_pipeline.output_dir, "metrics")
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(metrics_output_dir, exist_ok=True)

        results_dict = inference_pipeline.femur_results_dict

        # Single-row frame: each column holds a one-element list.
        columns = {
            f"{side_label} {site_label} (HU)": [results_dict[f"{side}_{site}"]["hu"]]
            for site, site_label in self._SITES
            for side, side_label in self._SIDES
        }

        df = pd.DataFrame(columns)
        df.to_csv(os.path.join(metrics_output_dir, "hip_metrics.csv"), index=False)
        return {}
class HipVisualizer(InferenceClass):
    """Render per-ROI images and per-site left/right report images."""

    # Iteration order matches the original per-ROI call order:
    # left_head, left_intertrochanter, left_neck, right_head, ...
    _SIDES = ("left", "right")
    _SITES = ("head", "intertrochanter", "neck")

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Write ROI and report images to ``<output_dir>/images``.

        For each ``f"{side}_{site}"`` entry in
        ``inference_pipeline.femur_results_dict``, the ``"roi"``,
        ``"centroid"`` and ``"hu"`` values are passed to
        ``hip_roi_visualizer``. Then, for each site, the left and right ROIs
        are summed and passed to ``hip_report_visualizer`` together with the
        rounded HU values.

        Returns:
            An empty dict.

        Raises:
            KeyError: If an expected ROI entry or field is missing. This is
                raised before any directory is created.
        """
        medical_volume = inference_pipeline.medical_volume
        results = inference_pipeline.femur_results_dict

        # Collect every ROI's (roi, centroid, hu) up front, in insertion order.
        regions = {}
        for side in self._SIDES:
            for site in self._SITES:
                key = f"{side}_{site}"
                entry = results[key]
                regions[key] = (entry["roi"], entry["centroid"], entry["hu"])

        images_output_dir = os.path.join(inference_pipeline.output_dir, "images")
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(images_output_dir, exist_ok=True)

        # One image per individual ROI, named after its results key.
        for key, (roi, centroid, hu) in regions.items():
            hip_roi_visualizer(
                medical_volume,
                roi,
                centroid,
                hu,
                images_output_dir,
                key,
            )

        # One combined left+right report image per site.
        for site in self._SITES:
            left_roi, left_centroid, left_hu = regions[f"left_{site}"]
            right_roi, right_centroid, right_hu = regions[f"right_{site}"]
            site_label = site.capitalize()
            hip_report_visualizer(
                medical_volume.get_fdata(),
                left_roi + right_roi,
                [left_centroid, right_centroid],
                images_output_dir,
                site,
                {
                    f"Left {site_label} HU": round(left_hu),
                    f"Right {site_label} HU": round(right_hu),
                },
            )
        return {}