"""Gradio app that runs TotalSegmentator on NIfTI volumes and shows the result.

Users upload a volume (or pick one of the bundled examples), browse slices
with a slider, run TotalSegmentator, view the label overlay, and download the
segmentation as a NIfTI file.
"""

# NOTE: `spaces` is kept as the first import; Hugging Face ZeroGPU presumably
# needs it imported before packages that initialise CUDA (totalsegmentator is
# likely one of them) — confirm before reordering.
import spaces
import tempfile
import os
from pathlib import Path
import SimpleITK as sitk
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import gradio as gr
from segmap import seg_map
import logging

# Logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Bundled example inputs (served from input_examples/). Their segmentations are
# precomputed under output_examples/ as "<stem>_seg.nii.gz" instead of being
# produced by running the model.
sample_files = ["ct1.nii.gz", "ct2.nii.gz", "ct3.nii.gz"]
_EXAMPLE_OUTPUT_DIR = "output_examples"


def _nifti_stem(path):
    """Return the file name of *path* without a ``.nii.gz`` / ``.nii`` suffix.

    Falls back to stripping a single suffix for any other extension.
    """
    name = Path(path).name
    for suffix in (".nii.gz", ".nii"):
        if name.endswith(suffix):
            return name[: -len(suffix)]
    return Path(name).stem


def _example_seg_path(path):
    """Return the precomputed segmentation path for the example file *path*."""
    return os.path.join(_EXAMPLE_OUTPUT_DIR, f"{_nifti_stem(path)}_seg.nii.gz")


def _file_path(file_obj):
    """Return the filesystem path of a Gradio file value as a plain ``str``.

    The original code treated the value both as a path-like string and as an
    object with a ``.name`` attribute; this accepts either form.
    """
    return str(getattr(file_obj, "name", file_obj))


def map_labels(seg_array):
    """Convert a 2-D label slice into ``gr.AnnotatedImage`` annotations.

    Args:
        seg_array: numpy array of integer label values for one slice.

    Returns:
        A list of ``(boolean_mask, label_name)`` pairs, one per non-zero label
        value present in the slice. Label names come from ``seg_map``; a value
        missing from ``seg_map`` is shown as its numeric string rather than
        raising.
    """
    seg_classes = np.unique(seg_array)
    logger.debug("unique segs:")
    logger.debug(str(len(seg_classes)))
    labels = []
    for seg_class in seg_classes:
        if seg_class == 0:  # 0 is background; nothing to annotate
            continue
        label_value = int(seg_class)
        try:
            label_name = seg_map[label_value]
        except (KeyError, IndexError):
            label_name = str(label_value)
        labels.append((seg_array == seg_class, label_name))
    return labels


def sitk_to_numpy(img_sitk, norm=False):
    """Reorient a SimpleITK image to LPS and return it as a numpy array.

    Args:
        img_sitk: SimpleITK image.
        norm: If True, min-max scale intensities to ``uint8`` in [0, 255] for
            display. A constant image maps to all zeros.

    Returns:
        The numpy array from ``sitk.GetArrayFromImage``. Callers index axis 0
        as the slice axis.
    """
    img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
    img_np = sitk.GetArrayFromImage(img_sitk)
    if norm:
        # Work in float: subtracting the minimum directly on e.g. int16 data
        # can overflow and wrap around.
        img_f = img_np.astype(np.float32)
        min_val, max_val = np.min(img_f), np.max(img_f)
        value_range = max_val - min_val
        if value_range > 0:
            scaled = (img_f - min_val) / value_range
        else:
            # Constant image: avoid 0/0 -> NaN, which casts to garbage uint8.
            scaled = np.zeros(img_f.shape, dtype=np.float32)
        img_np = (scaled.clip(0, 1) * 255).astype(np.uint8)
    return img_np


def load_image(path, norm=False):
    """Read the image at *path* and convert it with :func:`sitk_to_numpy`."""
    img_sitk = sitk.ReadImage(path)
    return sitk_to_numpy(img_sitk, norm)


def show_img_seg(img_np, seg_np=None, slice_idx=50):
    """Build the ``gr.AnnotatedImage`` value for one slice.

    Args:
        img_np: Image volume, or a list of volumes (the last one is used).
        seg_np: Optional label volume, or a list of them (the last one is used).
        slice_idx: Relative slice position in percent of the volume depth.

    Returns:
        ``None`` when there is no image. Otherwise, a tuple of the 2-D image
        slice and a list of ``(mask, label)`` annotations. The list is empty
        when there is no segmentation.
    """
    if img_np is None or (isinstance(img_np, list) and len(img_np) == 0):
        return None
    if isinstance(img_np, list):
        img_np = img_np[-1]
    n_slices = img_np.shape[0]
    # Clamp: slice_idx == 100 would otherwise index one past the last slice.
    slice_pos = min(max(int(slice_idx * (n_slices / 100)), 0), n_slices - 1)
    img_slice = img_np[slice_pos, :, :]
    if seg_np is None or (isinstance(seg_np, list) and len(seg_np) == 0):
        annotations = []
    else:
        if isinstance(seg_np, list):
            seg_np = seg_np[-1]
        annotations = map_labels(seg_np[slice_pos, :, :])
    return img_slice, annotations


def load_img_to_state(path, img_state, seg_state):
    """Reset both state lists and load the normalised image at *path*, if any.

    Returns:
        ``(None, img_state, seg_state)``. The ``None`` clears the viewer.
    """
    img_state.clear()
    seg_state.clear()
    if path:
        img_state.append(load_image(_file_path(path), norm=True))
    return None, img_state, seg_state


def save_seg(seg, path):
    """Persist the segmentation for the input at *path* and return its path.

    For bundled examples, the precomputed file under ``output_examples/`` is
    returned and nothing is written. Otherwise *seg* is written next to the
    input as ``<stem>_seg.nii.gz``. The input file itself is never
    overwritten, since it is re-read on later runs.
    """
    if Path(path).name in sample_files:
        return _example_seg_path(path)
    seg_path = str(Path(path).with_name(f"{_nifti_stem(path)}_seg.nii.gz"))
    sitk.WriteImage(seg, seg_path)
    return seg_path


# Hugging Face `spaces` GPU decorator; duration is presumably in seconds.
@spaces.GPU(duration=150)
def run_inference(path):
    """Run TotalSegmentator (fast mode) on the NIfTI file at *path*.

    Returns:
        The segmentation as a SimpleITK image. It is round-tripped through a
        temporary NIfTI file; ``sitk.ReadImage`` loads it into memory, so it
        stays valid after the temporary directory is removed.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        input_nib = nib.load(path)
        output_nib = totalsegmentator(input_nib, fast=True)
        output_path = os.path.join(temp_dir, "totalseg_output.nii.gz")
        nib.save(output_nib, output_path)
        seg_sitk = sitk.ReadImage(output_path)
    return seg_sitk


def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
    """Handle the "Generate Segmentations" click.

    Example files (matched by file name only) use their precomputed
    segmentation; anything else runs :func:`run_inference`.

    Returns:
        ``(viewer_value, seg_state, seg_path)`` for the viewer, the state and
        the download widget.

    Raises:
        gr.Error: if no file has been provided.
    """
    if not input_file:
        raise gr.Error("Please upload a NIfTI file (or pick an example) first.")
    input_path = _file_path(input_file)
    file_name = Path(input_path).name
    if file_name in sample_files:
        seg_sitk = sitk.ReadImage(_example_seg_path(file_name))
    else:
        seg_sitk = run_inference(input_path)
    seg_path = save_seg(seg_sitk, input_path)
    seg_state.append(sitk_to_numpy(seg_sitk))
    if not img_state:
        # Same normalisation as load_img_to_state so the viewer is consistent.
        img_state.append(load_image(input_path, norm=True))
    return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path


with gr.Blocks(title="TotalSegmentator") as interface:
    # NOTE(review): run_inference calls totalsegmentator() without a `task`,
    # which presumably selects the CT model; MR inputs may need
    # task="total_mr" — confirm before advertising MR support below.
    gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
    gr.Markdown("""
- **GitHub:** https://github.com/wasserth/TotalSegmentator
- **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
- Supports both CT and MR imaging modalities
- Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
""")

    # Per-session lists; the last element is the one displayed.
    img_state = gr.State([])
    seg_state = gr.State([])

    with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
        with gr.Row():
            with gr.Column():
                file_input = gr.File(
                    type="filepath",
                    label="Upload a CT or MR Image (.nii/.nii.gz)",
                    # ".nii" included so uncompressed NIfTI matches the label.
                    file_types=[".nii", ".nii.gz", ".gz"],
                )
                gr.Examples(["input_examples/" + example for example in sample_files], file_input)
                with gr.Row():
                    infer_button = gr.Button("Generate Segmentations", variant="primary")
                    clear_button = gr.ClearButton()
            with gr.Column():
                slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
                img_viewer = gr.AnnotatedImage(label="Image Viewer")
                download_seg = gr.File(label="Download Segmentation", interactive=False)

    file_input.change(
        load_img_to_state,
        inputs=[file_input, img_state, seg_state],
        outputs=[img_viewer, img_state, seg_state],
    )
    slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])
    infer_button.click(
        inference_wrapper,
        inputs=[file_input, img_state, seg_state, slice_slider],
        outputs=[img_viewer, seg_state, download_seg],
    )
    clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])


if __name__ == "__main__":
    interface.queue()
    interface.launch(debug=True)