DiGuaQiu commited on
Commit
52811bb
1 Parent(s): 9236ec2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -46
app.py CHANGED
@@ -3,27 +3,15 @@
3
  import tempfile
4
  from pathlib import Path
5
 
6
- import nibabel as nib
7
- import numpy as np
8
- from PIL import ImageDraw
9
- from streamlit_drawable_canvas import st_canvas
10
- from streamlit_image_coordinates import streamlit_image_coordinates
11
- import nibabel as nib
12
  import SimpleITK as sitk
 
 
13
  import streamlit as st
14
  import utils
15
- from utils import (
16
- initial_rectangle,
17
- make_fig,
18
- reflect_box_into_model,
19
- reflect_json_data_to_3D_box,
20
- run,
21
- )
22
-
23
- # from viewer import BasicViewer
24
 
25
  print("script run")
26
  st.title("MRSegmentator")
 
27
 
28
  #############################################
29
  # init session_state
@@ -50,8 +38,9 @@ if "transparency" not in st.session_state:
50
  st.session_state.transparency = 0.25
51
 
52
  case_list = [
53
- "images/amos_0059.nii.gz",
54
- "images/amos_0555.nii.gz",
 
55
  ]
56
 
57
  #############################################
@@ -66,8 +55,11 @@ def clear_prompts():
66
  def reset_demo_case():
67
  st.session_state.data_item = None
68
  st.session_state.reset_demo_case = True
 
 
69
  clear_prompts()
70
 
 
71
  def clear_file():
72
  st.session_state.option = None
73
  reset_demo_case()
@@ -85,26 +77,33 @@ with arxive_col:
85
  st.write("Paper: https://arxiv.org/abs/2405.06463")
86
 
87
  # modify demo case here
88
- demo_type = st.radio("Demo case source", ["Select", "Upload"], on_change=clear_file)
89
 
90
  with tempfile.TemporaryDirectory() as tmpdirname:
91
 
92
  # modify demo case here
93
- if demo_type == "Select":
94
- uploaded_file = st.selectbox(
95
  "Select a demo case",
96
  case_list,
97
  index=None,
98
  placeholder="Select a demo case...",
99
  on_change=reset_demo_case,
100
  )
 
 
 
 
 
 
 
 
 
101
  else:
102
- uploaded_file = st.file_uploader(
103
- "Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case
104
- )
105
 
106
- if( uploaded_file is not None ):
107
- with open(tmpdirname + "/" + uploaded_file.name, 'wb') as f:
108
  f.write(uploaded_file.getvalue())
109
  uploaded_file = tmpdirname + "/" + uploaded_file.name
110
 
@@ -117,39 +116,46 @@ with tempfile.TemporaryDirectory() as tmpdirname:
117
  ):
118
 
119
  st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
120
- st.session_state.data_item_ori = sitk.ReadImage(Path(__file__).parent / str(uploaded_file))
 
121
  st.session_state.reset_demo_case = False
122
- st.session_state.preds_3D = None
123
- st.session_state.preds_path = None
124
-
125
 
126
  if st.session_state.option is None:
127
  st.write("please select demo case first")
128
  else:
129
  image_3D = st.session_state.data_item
130
- px_range = st.slider( "Select intensity range",
131
- int(image_3D.min()),
132
- int(image_3D.max()),
133
- (int(image_3D.min()), int(image_3D.max()))
134
- )
 
135
  col_control1, col_control2 = st.columns(2)
136
 
137
  with col_control1:
138
  selected_index_z = st.slider(
139
- "Axial view", 0, image_3D.shape[0] - 1, image_3D.shape[0] // 2, key="xy", disabled=st.session_state.running
 
 
 
 
 
140
  )
141
 
142
  with col_control2:
143
  selected_index_y = st.slider(
144
- "Coronal view", 0, image_3D.shape[1] - 1, image_3D.shape[1] // 2, key="xz", disabled=st.session_state.running
 
 
 
 
 
145
  )
146
 
147
  col_image1, col_image2 = st.columns(2)
148
 
149
  if st.session_state.preds_3D is not None:
150
- st.session_state.transparency = st.slider(
151
- "Mask opacity", 0.0, 1.0, 0.5, disabled=st.session_state.running
152
- )
153
 
154
  with col_image1:
155
 
@@ -159,7 +165,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
159
  if st.session_state.preds_3D is not None:
160
  preds_z_array = st.session_state.preds_3D[selected_index_z]
161
 
162
- image_z = make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
163
  st.image(image_z, use_column_width=False)
164
 
165
  with col_image2:
@@ -169,7 +175,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
169
  if st.session_state.preds_3D is not None:
170
  preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
171
 
172
- image_y = make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
173
  st.image(image_y, use_column_width=False)
174
 
175
  ######################################################
@@ -177,6 +183,9 @@ with tempfile.TemporaryDirectory() as tmpdirname:
177
  col1, col2, col3 = st.columns(3)
178
 
179
  with col1:
 
 
 
180
  if st.button(
181
  "Clear",
182
  use_container_width=True,
@@ -188,19 +197,21 @@ with tempfile.TemporaryDirectory() as tmpdirname:
188
  st.rerun()
189
 
190
  with col2:
 
 
 
191
 
192
  if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
193
 
194
  with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
195
 
196
  preds = st.session_state.preds_3D_ori
197
- #result_image.CopyInformation(inputImage)
198
  sitk.WriteImage(preds, tmpfile.name)
199
- #nib.save(st.session_state.preds_3D, tmpfile.name)
200
  with open(tmpfile.name, "rb") as f:
201
  bytes_data = f.read()
202
  st.download_button(
203
- label="Download result(.nii.gz)",
204
  data=bytes_data,
205
  file_name="segmentation.nii.gz",
206
  mime="application/octet-stream",
@@ -208,12 +219,25 @@ with tempfile.TemporaryDirectory() as tmpdirname:
208
  )
209
 
210
  with col3:
 
 
 
 
 
 
 
 
 
 
 
 
211
  run_button_name = "Run" if not st.session_state.running else "Running"
212
  if st.button(
213
  run_button_name,
214
  type="primary",
215
  use_container_width=True,
216
- disabled=(st.session_state.data_item is None or st.session_state.running),
 
217
  ):
218
  st.session_state.running = True
219
  st.rerun()
@@ -221,5 +245,5 @@ with tempfile.TemporaryDirectory() as tmpdirname:
221
  if st.session_state.running:
222
  st.session_state.running = False
223
  with st.status("Running...", expanded=False) as status:
224
- run(tmpdirname)
225
  st.rerun()
 
3
  import tempfile
4
  from pathlib import Path
5
 
 
 
 
 
 
 
6
  import SimpleITK as sitk
7
+ from mrsegmentator.utils import add_postfix
8
+
9
  import streamlit as st
10
  import utils
 
 
 
 
 
 
 
 
 
11
 
12
  print("script run")
13
  st.title("MRSegmentator")
14
+ st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")
15
 
16
  #############################################
17
  # init session_state
 
38
  st.session_state.transparency = 0.25
39
 
40
  case_list = [
41
+ "amos_0517_MRI.nii.gz",
42
+ "amos_0541_MRI.nii.gz",
43
+ "amos_0571_MRI.nii.gz",
44
  ]
45
 
46
  #############################################
 
55
  def reset_demo_case():
56
  st.session_state.data_item = None
57
  st.session_state.reset_demo_case = True
58
+ st.session_state.preds_3D = None
59
+ st.session_state.preds_3D_ori = None
60
  clear_prompts()
61
 
62
+
63
  def clear_file():
64
  st.session_state.option = None
65
  reset_demo_case()
 
77
  st.write("Paper: https://arxiv.org/abs/2405.06463")
78
 
79
  # modify demo case here
80
+ demo_type = st.radio("Demo case source", ["Select (presegmented)", "Upload"], on_change=clear_file)
81
 
82
  with tempfile.TemporaryDirectory() as tmpdirname:
83
 
84
  # modify demo case here
85
+ if demo_type == "Select (presegmented)":
86
+ selection = st.selectbox(
87
  "Select a demo case",
88
  case_list,
89
  index=None,
90
  placeholder="Select a demo case...",
91
  on_change=reset_demo_case,
92
  )
93
+
94
+ if selection:
95
+ uploaded_file = "images/" + selection
96
+ seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
97
+ st.session_state.preds_3D = utils.read_image(seg_path)
98
+ st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
99
+ else:
100
+ uploaded_file = None
101
+
102
  else:
103
+ uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
 
 
104
 
105
+ if uploaded_file is not None:
106
+ with open(tmpdirname + "/" + uploaded_file.name, "wb") as f:
107
  f.write(uploaded_file.getvalue())
108
  uploaded_file = tmpdirname + "/" + uploaded_file.name
109
 
 
116
  ):
117
 
118
  st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
119
+ # st.session_state.preds_3D = None
120
+ # st.session_state.preds_3D_ori = None
121
  st.session_state.reset_demo_case = False
 
 
 
122
 
123
  if st.session_state.option is None:
124
  st.write("please select demo case first")
125
  else:
126
  image_3D = st.session_state.data_item
127
+ px_range = st.slider(
128
+ "Select intensity range",
129
+ int(image_3D.min()),
130
+ int(image_3D.max()),
131
+ (int(image_3D.min()), int(image_3D.max())),
132
+ )
133
  col_control1, col_control2 = st.columns(2)
134
 
135
  with col_control1:
136
  selected_index_z = st.slider(
137
+ "Axial view",
138
+ 0,
139
+ image_3D.shape[0] - 1,
140
+ image_3D.shape[0] // 2,
141
+ key="xy",
142
+ disabled=st.session_state.running,
143
  )
144
 
145
  with col_control2:
146
  selected_index_y = st.slider(
147
+ "Coronal view",
148
+ 0,
149
+ image_3D.shape[1] - 1,
150
+ image_3D.shape[1] // 2,
151
+ key="xz",
152
+ disabled=st.session_state.running,
153
  )
154
 
155
  col_image1, col_image2 = st.columns(2)
156
 
157
  if st.session_state.preds_3D is not None:
158
+ st.session_state.transparency = st.slider("Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running)
 
 
159
 
160
  with col_image1:
161
 
 
165
  if st.session_state.preds_3D is not None:
166
  preds_z_array = st.session_state.preds_3D[selected_index_z]
167
 
168
+ image_z = utils.make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
169
  st.image(image_z, use_column_width=False)
170
 
171
  with col_image2:
 
175
  if st.session_state.preds_3D is not None:
176
  preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
177
 
178
+ image_y = utils.make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
179
  st.image(image_y, use_column_width=False)
180
 
181
  ######################################################
 
183
  col1, col2, col3 = st.columns(3)
184
 
185
  with col1:
186
+ st.markdown("#")
187
+ st.markdown("####")
188
+ st.markdown("####")
189
  if st.button(
190
  "Clear",
191
  use_container_width=True,
 
197
  st.rerun()
198
 
199
  with col2:
200
+ st.markdown("#")
201
+ st.markdown("####")
202
+ st.markdown("####")
203
 
204
  if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
205
 
206
  with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
207
 
208
  preds = st.session_state.preds_3D_ori
 
209
  sitk.WriteImage(preds, tmpfile.name)
210
+
211
  with open(tmpfile.name, "rb") as f:
212
  bytes_data = f.read()
213
  st.download_button(
214
+ label="Download result (.nii.gz)",
215
  data=bytes_data,
216
  file_name="segmentation.nii.gz",
217
  mime="application/octet-stream",
 
219
  )
220
 
221
  with col3:
222
+ folds = st.radio("", ["Model of Fold 1 (fast)", "Ensemble Segmentation"])
223
+ if folds == "Model of Fold 1":
224
+ st.session_state.folds = (0,)
225
+ else:
226
+ st.session_state.folds = (
227
+ 0,
228
+ 1,
229
+ 2,
230
+ 3,
231
+ 4,
232
+ )
233
+
234
  run_button_name = "Run" if not st.session_state.running else "Running"
235
  if st.button(
236
  run_button_name,
237
  type="primary",
238
  use_container_width=True,
239
+ disabled=True,
240
+ # disabled=(st.session_state.data_item is None or st.session_state.running),
241
  ):
242
  st.session_state.running = True
243
  st.rerun()
 
245
  if st.session_state.running:
246
  st.session_state.running = False
247
  with st.status("Running...", expanded=False) as status:
248
+ utils.run(tmpdirname)
249
  st.rerun()