AdritRao's picture
Upload 62 files
a3290d1
raw
history blame
No virus
3.49 kB
import os
from pathlib import Path
from time import time
from typing import Union
from totalsegmentator.libs import (
download_pretrained_weights,
nostdout,
setup_nnunet,
)
from comp2comp.contrast_phase.contrast_inf import predict_phase
from comp2comp.inference_class_base import InferenceClass
class ContrastPhaseDetection(InferenceClass):
    """Contrast Phase Detection.

    Pipeline step that segments the converted input volume with the
    TotalSegmentator nnU-Net helpers and then calls ``predict_phase`` on the
    resulting segmentation and image files.
    """

    def __init__(self, input_path):
        """Store the input path.

        Args:
            input_path: Path to the input. It is only stored on the instance;
                none of the methods in this class read it.
        """
        super().__init__()
        self.input_path = input_path

    def __call__(self, inference_pipeline):
        """Run segmentation followed by contrast phase prediction.

        Reads ``output_dir`` and ``model_dir`` from ``inference_pipeline``,
        segments ``<output_dir>/segmentations/converted_dcm.nii.gz`` into
        ``<output_dir>/segmentations/s01.nii.gz`` and hands both files to
        ``predict_phase``.

        Args:
            inference_pipeline: Pipeline object exposing ``output_dir`` and
                ``model_dir`` attributes.

        Returns:
            dict: Always an empty dictionary.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir

        # Build each path once so segmentation and phase prediction are
        # guaranteed to look at the same files.
        img_nifti_path = os.path.join(
            self.output_dir_segmentations, "converted_dcm.nii.gz"
        )
        seg_nifti_path = os.path.join(self.output_dir_segmentations, "s01.nii.gz")

        # The returned objects are not needed here; predict_phase is given
        # the file paths instead.
        self.run_segmentation(img_nifti_path, seg_nifti_path, self.model_dir)

        predict_phase(seg_nifti_path, img_nifti_path, outputPath=self.output_dir)

        return {}

    def run_segmentation(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        model_dir: Union[str, Path],
    ):
        """Run segmentation.

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path.
            model_dir (Union[str, Path]): Directory exported through the
                ``SCRATCH`` environment variable before nnU-Net is set up.

        Returns:
            tuple: ``(seg, img)`` -- the two values returned by
            ``nnUNet_predict_image``, in swapped order.
            NOTE(review): presumably nibabel images; confirm against the
            installed TotalSegmentator version.
        """
        print("Segmenting...")
        st = time()

        # Use the argument rather than self.model_dir so the method also works
        # when called directly; os.fspath accepts both str and Path.
        os.environ["SCRATCH"] = os.fspath(model_dir)

        # Setup nnunet
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        # A single task, passed as a plain int to both calls below.
        task_id = 251

        setup_nnunet()
        download_pretrained_weights(task_id)

        # Deliberately imported after setup_nnunet(), as in the original code.
        # NOTE(review): presumably nnunet reads its environment at import
        # time -- confirm before moving this to module level.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag=None,
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for segmentation
        print(f"Total time for segmentation: {end-st:.2f}s.")

        return seg, img

    def convertNibToNumpy(self, TSNib, ImageNib):
        """Convert nifti to numpy array.

        Args:
            TSNib (nibabel.nifti1.Nifti1Image): TotalSegmentator output.
            ImageNib (nibabel.nifti1.Nifti1Image): Input image.

        Returns:
            numpy.ndarray: TotalSegmentator output.
            numpy.ndarray: Input image.
        """
        TS_array = TSNib.get_fdata()
        img_array = ImageNib.get_fdata()
        return TS_array, img_array