edureisMD commited on
Commit
73825ed
1 Parent(s): dd174bd

first commit

Browse files
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import tempfile
3
+ import os
4
+ from pathlib import Path
5
+ import SimpleITK as sitk
6
+ import numpy as np
7
+ import nibabel as nib
8
+ from totalsegmentator.python_api import totalsegmentator
9
+ import gradio as gr
10
+ from segmap import seg_map
11
+ import logging
12
+
13
+ # Logging configuration
14
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
+
17
+ sample_files = ["ct1.nii.gz", "ct2.nii.gz", "ct3.nii.gz"]
18
+
19
+
20
+ def map_labels(seg_array):
21
+ labels = []
22
+ count = 0
23
+ logger.debug("unique segs:")
24
+ logger.debug(str(len(np.unique(seg_array))))
25
+ for seg_class in np.unique(seg_array):
26
+ if seg_class == 0:
27
+ continue
28
+ labels.append((seg_array == seg_class, seg_map[seg_class]))
29
+ count += 1
30
+
31
+ return labels
32
+
33
+ def sitk_to_numpy(img_sitk, norm=False):
34
+ img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
35
+ img_np = sitk.GetArrayFromImage(img_sitk)
36
+ if norm:
37
+ min_val, max_val = np.min(img_np), np.max(img_np)
38
+ img_np = ((img_np - min_val) / (max_val - min_val)).clip(0, 1) * 255
39
+ img_np = img_np.astype(np.uint8)
40
+ return img_np
41
+
42
+
43
+ def load_image(path, norm=False):
44
+ img_sitk = sitk.ReadImage(path)
45
+ return sitk_to_numpy(img_sitk, norm)
46
+
47
+
48
+ def show_img_seg(img_np, seg_np=None, slice_idx=50):
49
+ if img_np is None or (isinstance(img_np, list) and len(img_np) == 0):
50
+ return None
51
+ if isinstance(img_np, list):
52
+ img_np = img_np[-1]
53
+ slice_pos = int(slice_idx * (img_np.shape[0] / 100))
54
+ img_slice = img_np[slice_pos, :, :]
55
+
56
+ if seg_np is None or (isinstance(seg_np, list) and len(seg_np) == 0):
57
+ seg_np = []
58
+ else:
59
+ if isinstance(seg_np, list):
60
+ seg_np = seg_np[-1]
61
+ seg_np = map_labels(seg_np[slice_pos, :, :])
62
+
63
+ return img_slice, seg_np
64
+
65
+
66
+ def load_img_to_state(path, img_state, seg_state):
67
+ img_state.clear()
68
+ seg_state.clear()
69
+
70
+ if path:
71
+ img_np = load_image(path, norm=True)
72
+ img_state.append(img_np)
73
+ return None, img_state, seg_state
74
+ else:
75
+ return None, img_state, seg_state
76
+
77
+
78
+ def save_seg(seg, path):
79
+ if Path(path).name in sample_files:
80
+ path = os.path.join("output_examples", f"{Path(Path(path).stem).stem}_seg.nii.gz")
81
+ else:
82
+ sitk.WriteImage(seg, path)
83
+
84
+ return path
85
+
86
+
87
+ @spaces.GPU(duration=150)
88
+ def run_inference(path):
89
+ with tempfile.TemporaryDirectory() as temp_dir:
90
+ input_nib = nib.load(path)
91
+ output_nib = totalsegmentator(input_nib, fast=True)
92
+ output_path = os.path.join(temp_dir, "totalseg_output.nii.gz")
93
+ nib.save(output_nib, output_path)
94
+ seg_sitk = sitk.ReadImage(output_path)
95
+ return seg_sitk
96
+
97
+
98
+ def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
99
+ file_name = Path(input_file).name
100
+
101
+ if file_name in sample_files:
102
+ seg_sitk = sitk.ReadImage(os.path.join("output_examples", f"{Path(Path(file_name).stem).stem}_seg.nii.gz"))
103
+ else:
104
+ seg_sitk = run_inference(input_file.name)
105
+
106
+ seg_path = save_seg(seg_sitk, input_file.name)
107
+ seg_state.append(sitk_to_numpy(seg_sitk))
108
+
109
+ if not img_state:
110
+ img_sitk = sitk.ReadImage(input_file.name)
111
+ img_state.append(sitk_to_numpy(img_sitk))
112
+
113
+ return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path
114
+
115
+
116
+ with gr.Blocks(title="TotalSegmentator") as interface:
117
+
118
+ gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
119
+ gr.Markdown("""
120
+ - **GitHub:** https://github.com/wasserth/TotalSegmentator
121
+ - **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
122
+ - Supports both CT and MR imaging modalities
123
+ - Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
124
+ """)
125
+
126
+ img_state = gr.State([])
127
+ seg_state = gr.State([])
128
+
129
+ with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
130
+ with gr.Row():
131
+ with gr.Column():
132
+
133
+ file_input = gr.File(
134
+ type="filepath", label="Upload a CT or MR Image (.nii/.nii.gz)", file_types=[".gz", ".nii.gz"]
135
+ )
136
+ gr.Examples(["input_examples/" + example for example in sample_files], file_input)
137
+
138
+ with gr.Row():
139
+ infer_button = gr.Button("Generate Segmentations", variant="primary")
140
+ clear_button = gr.ClearButton()
141
+
142
+ with gr.Column():
143
+ slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
144
+ img_viewer = gr.AnnotatedImage(label="Image Viewer")
145
+ download_seg = gr.File(label="Download Segmentation", interactive=False)
146
+
147
+ file_input.change(
148
+ load_img_to_state,
149
+ inputs=[file_input, img_state, seg_state],
150
+ outputs=[img_viewer, img_state, seg_state],
151
+ )
152
+ slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])
153
+
154
+ infer_button.click(
155
+ inference_wrapper,
156
+ inputs=[file_input, img_state, seg_state, slice_slider],
157
+ outputs=[img_viewer, seg_state, download_seg],
158
+ )
159
+
160
+ clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])
161
+
162
+
163
+ if __name__ == "__main__":
164
+ interface.queue()
165
+ interface.launch(debug=True)
input_examples/ct1.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3d562e8465ad99c783626094236b1067a6795aac04b6e39039bfc411d2e0506
3
+ size 9856205
input_examples/ct2.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:980664b91abff172ecda4cf75bade34e4916281ee2509e3b93e3cf8bc326709e
3
+ size 7895923
input_examples/ct3.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4d47d8b8261bc239484f8dfdc1706b8e95a7b284e6e9a10cbc3bd4c41cbf359
3
+ size 8035692
output_examples/ct1_seg.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:271468351f58881d9192b9d7620f236d8f5f807484c4db947ccdb60606ea26cb
3
+ size 142167
output_examples/ct2_seg.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f062135c10c78ccf01a6ce13660002239b4be9f781001f1fca41f25f4691473
3
+ size 61747
output_examples/ct3_seg.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65a95c55016e865e59c89b76da30f38ff84769598f1b5e47d681714ea7d9e12a
3
+ size 53422
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ TotalSegmentator
2
+ SimpleITK
3
+ spaces
segmap.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seg_map = [
2
+ "background",
3
+ "spleen",
4
+ "right_kidney",
5
+ "left_kidney",
6
+ "gallbladder",
7
+ "liver",
8
+ "stomach",
9
+ "pancreas",
10
+ "right_adrenal_gland",
11
+ "left_adrenal_gland",
12
+ "left_lung",
13
+ "right_lung",
14
+ "heart",
15
+ "aorta",
16
+ "inferior_vena_cava",
17
+ "portal_vein_and_splenic_vein",
18
+ "left_iliac_artery",
19
+ "right_iliac_artery",
20
+ "left_iliac_vein",
21
+ "right_iliac_vein",
22
+ "esophagus",
23
+ "small_bowel",
24
+ "duodenum",
25
+ "colon",
26
+ "urinary_bladder",
27
+ "spine",
28
+ "sacrum",
29
+ "left_hip",
30
+ "right_hip",
31
+ "left_femur",
32
+ "right_femur",
33
+ "left_autochthonous_muscle",
34
+ "right_autochthonous_muscle",
35
+ "left_iliopsoas_muscle",
36
+ "right_iliopsoas_muscle",
37
+ "left_gluteus_maximus",
38
+ "right_gluteus_maximus",
39
+ "left_gluteus_medius",
40
+ "right_gluteus_medius",
41
+ "left_gluteus_minimus",
42
+ "right_gluteus_minimus",
43
+ "trachea",
44
+ "thyroid_gland",
45
+ "prostate",
46
+ "kidney_cyst_left",
47
+ "kidney_cyst_right",
48
+ "vertebrae_S1",
49
+ "vertebrae_L5",
50
+ "vertebrae_L4",
51
+ "vertebrae_L3",
52
+ "vertebrae_L2",
53
+ "vertebrae_L1",
54
+ "vertebrae_T12",
55
+ "vertebrae_T11",
56
+ "vertebrae_T10",
57
+ "vertebrae_T9",
58
+ "vertebrae_T8",
59
+ "vertebrae_T7",
60
+ "vertebrae_T6",
61
+ "vertebrae_T5",
62
+ "vertebrae_T4",
63
+ "vertebrae_T3",
64
+ "vertebrae_T2",
65
+ "vertebrae_T1",
66
+ "vertebrae_C7",
67
+ "vertebrae_C6",
68
+ "vertebrae_C5",
69
+ "vertebrae_C4",
70
+ "vertebrae_C3",
71
+ "vertebrae_C2",
72
+ "vertebrae_C1",
73
+ "pulmonary_vein",
74
+ "brachiocephalic_trunk",
75
+ "subclavian_artery_right",
76
+ "subclavian_artery_left",
77
+ "common_carotid_artery_right",
78
+ "common_carotid_artery_left",
79
+ "brachiocephalic_vein_left",
80
+ "brachiocephalic_vein_right",
81
+ "atrial_appendage_left",
82
+ "superior_vena_cava",
83
+ "humerus_left",
84
+ "humerus_right",
85
+ "scapula_left",
86
+ "scapula_right",
87
+ "clavicula_left",
88
+ "clavicula_right",
89
+ "spinal_cord",
90
+ "brain",
91
+ "skull",
92
+ "rib_left_1",
93
+ "rib_left_2",
94
+ "rib_left_3",
95
+ "rib_left_4",
96
+ "rib_left_5",
97
+ "rib_left_6",
98
+ "rib_left_7",
99
+ "rib_left_8",
100
+ "rib_left_9",
101
+ "rib_left_10",
102
+ "rib_left_11",
103
+ "rib_left_12",
104
+ "rib_right_1",
105
+ "rib_right_2",
106
+ "rib_right_3",
107
+ "rib_right_4",
108
+ "rib_right_5",
109
+ "rib_right_6",
110
+ "rib_right_7",
111
+ "rib_right_8",
112
+ "rib_right_9",
113
+ "rib_right_10",
114
+ "rib_right_11",
115
+ "rib_right_12",
116
+ "sternum",
117
+ "costal_cartilages",
118
+ "114","115","116","117","118","119","120","121","122","123",
119
+ "124","125","126","127","128","129","130","131","132","133"
120
+ ]