"""
@author: louisblankemeier
"""

import os
from pathlib import Path
from time import time
from typing import Union

import pandas as pd
from totalsegmentator.libs import (
    download_pretrained_weights,
    nostdout,
    setup_nnunet,
)

from comp2comp.hip import hip_utils
from comp2comp.hip.hip_visualization import (
    hip_report_visualizer,
    hip_roi_visualizer,
)
from comp2comp.inference_class_base import InferenceClass
from comp2comp.models.models import Models

# Keys of the femur results dict are "<side>_<region>"; these tuples drive the
# metric columns and the visualizations so the three stay in sync.
_SIDES = ("left", "right")
_REGIONS = ("head", "intertrochanter", "neck")


class HipSegmentation(InferenceClass):
    """Hip segmentation."""

    def __init__(self, model_name):
        """Store the model name and resolve it through ``Models``.

        Args:
            model_name: Name passed to ``Models.model_from_name``. Only
                ``"ts_hip"`` is accepted by ``hip_seg``.
        """
        super().__init__()
        self.model_name = model_name
        self.model = Models.model_from_name(model_name)

    def __call__(self, inference_pipeline):
        """Segment the converted volume and publish the results on the pipeline.

        Reads ``output_dir`` and ``model_dir`` from ``inference_pipeline`` and
        sets ``model``, ``segmentation`` and ``medical_volume`` on it.

        Args:
            inference_pipeline: Pipeline object carrying the shared state.

        Returns:
            dict: Always an empty dict.
        """
        self.output_dir = inference_pipeline.output_dir
        # The trailing slash is kept so the attribute value is unchanged for
        # any other code that reads it.
        self.output_dir_segmentations = os.path.join(
            self.output_dir, "segmentations/"
        )
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.hip_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            os.path.join(self.output_dir_segmentations, "hip.nii.gz"),
            inference_pipeline.model_dir,
        )

        inference_pipeline.model = self.model
        inference_pipeline.segmentation = seg
        inference_pipeline.medical_volume = mv

        return {}

    def hip_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run hip segmentation.

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path.
            model_dir: Directory exported as the ``SCRATCH`` environment
                variable before the nnU-Net setup runs.

        Returns:
            tuple: ``(seg, img)`` as produced by ``nnUNet_predict_image``
            (which returns them in ``img, seg`` order).

        Raises:
            ValueError: If ``self.model_name`` is not ``"ts_hip"``.
        """
        # Fail before touching the environment or downloading anything.
        if self.model_name != "ts_hip":
            raise ValueError("Invalid model name.")

        print("Segmenting hip...")
        st = time()

        # Use the argument rather than self.model_dir, which only exists after
        # __call__ has run; fspath also accepts Path objects.
        os.environ["SCRATCH"] = os.fspath(model_dir)

        # Setup nnunet
        task_id = [254]
        setup_nnunet()
        download_pretrained_weights(task_id[0])

        # NOTE(review): imported only after setup_nnunet(), as in the original;
        # presumably the nnunet module depends on that setup - keep this order.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model="3d_fullres",
                folds=[0],
                trainer="nnUNetTrainerV2_ep4000_nomirror",
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=None,
                task_name="total",
                nora_tag=None,
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for hip segmentation
        print(f"Total time for hip segmentation: {end-st:.2f}s.")

        return seg, img


class HipComputeROIs(InferenceClass):
    """Compute femur ROIs from the segmentation stored on the pipeline."""

    def __init__(self, hip_model):
        """Store the hip model name and resolve it through ``Models``.

        Args:
            hip_model: Name passed to ``Models.model_from_name``.
        """
        super().__init__()
        self.hip_model_name = hip_model
        self.hip_model_type = Models.model_from_name(self.hip_model_name)

    def __call__(self, inference_pipeline):
        """Run ``hip_utils.compute_rois`` and store its result on the pipeline.

        Reads ``segmentation``, ``medical_volume``, ``model`` and
        ``output_dir`` from ``inference_pipeline`` and sets
        ``femur_results_dict`` on it.

        Args:
            inference_pipeline: Pipeline object carrying the shared state.

        Returns:
            dict: Always an empty dict.
        """
        images_folder = os.path.join(inference_pipeline.output_dir, "dev")
        inference_pipeline.femur_results_dict = hip_utils.compute_rois(
            inference_pipeline.medical_volume,
            inference_pipeline.segmentation,
            inference_pipeline.model,
            images_folder,
        )
        return {}


class HipMetricsSaver(InferenceClass):
    """Save metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Write the per-region HU values to ``metrics/hip_metrics.csv``.

        Args:
            inference_pipeline: Pipeline object providing ``output_dir`` and
                ``femur_results_dict``.

        Returns:
            dict: Always an empty dict.

        Raises:
            KeyError: If a ``"<side>_<region>"`` entry or its ``"hu"`` field
                is missing from ``femur_results_dict``.
        """
        metrics_output_dir = os.path.join(inference_pipeline.output_dir, "metrics")
        os.makedirs(metrics_output_dir, exist_ok=True)

        results_dict = inference_pipeline.femur_results_dict

        # Region-major order yields the columns Left/Right Head,
        # Left/Right Intertrochanter, Left/Right Neck.
        columns = {
            f"{side.capitalize()} {region.capitalize()} (HU)": [
                results_dict[f"{side}_{region}"]["hu"]
            ]
            for region in _REGIONS
            for side in _SIDES
        }

        # save to csv
        df = pd.DataFrame(columns)
        df.to_csv(os.path.join(metrics_output_dir, "hip_metrics.csv"), index=False)
        return {}


class HipVisualizer(InferenceClass):
    """Render per-ROI images and per-region report images for the hip."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Call the ROI and report visualizers for every femur region.

        ``hip_roi_visualizer`` is called once per ``"<side>_<region>"`` entry
        and ``hip_report_visualizer`` once per region with the left and right
        ROIs added together. Output goes to ``<output_dir>/images``.

        Args:
            inference_pipeline: Pipeline object providing ``medical_volume``,
                ``femur_results_dict`` and ``output_dir``.

        Returns:
            dict: Always an empty dict.

        Raises:
            KeyError: If an entry, or its ``"roi"``, ``"centroid"`` or ``"hu"``
                field, is missing from ``femur_results_dict``.
        """
        medical_volume = inference_pipeline.medical_volume
        results = inference_pipeline.femur_results_dict

        # Read every field up front so a missing key fails before any
        # directory or image is created.
        fields = {}
        for side in _SIDES:
            for region in _REGIONS:
                name = f"{side}_{region}"
                entry = results[name]
                fields[name] = (entry["roi"], entry["centroid"], entry["hu"])

        images_output_dir = os.path.join(inference_pipeline.output_dir, "images")
        os.makedirs(images_output_dir, exist_ok=True)

        # Side-major order: left_head, left_intertrochanter, left_neck, then
        # the same three for the right side.
        for name, (roi, centroid, hu) in fields.items():
            hip_roi_visualizer(
                medical_volume,
                roi,
                centroid,
                hu,
                images_output_dir,
                name,
            )

        for region in _REGIONS:
            left_roi, left_centroid, left_hu = fields[f"left_{region}"]
            right_roi, right_centroid, right_hu = fields[f"right_{region}"]
            title = region.capitalize()
            hip_report_visualizer(
                medical_volume.get_fdata(),
                left_roi + right_roi,
                [left_centroid, right_centroid],
                images_output_dir,
                region,
                {
                    f"Left {title} HU": round(left_hu),
                    f"Right {title} HU": round(right_hu),
                },
            )

        return {}