andreped commited on
Commit
35de495
2 Parent(s): f6ffe33 8291c90

Merge pull request #59 from andreped/download-button

Browse files

Download button in demo app; upgraded gradio to latest

demo/requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib
2
- gradio==3.50.2
 
1
  raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib
2
+ gradio==4.29.0
demo/src/css_style.py CHANGED
@@ -9,8 +9,12 @@ margin: auto;
9
  #upload {
10
  height: 110px;
11
  }
 
 
 
 
12
  #run-button {
13
- height: 110px;
14
  width: 150px;
15
  }
16
  #toggle-button {
 
9
  #upload {
10
  height: 110px;
11
  }
12
+ #download {
13
+ height: 47px;
14
+ width: 150px;
15
+ }
16
  #run-button {
17
+ height: 47px;
18
  width: 150px;
19
  }
20
  #toggle-button {
demo/src/gui.py CHANGED
@@ -59,7 +59,8 @@ class WebUI:
59
  visible=True,
60
  elem_id="model-3d",
61
  camera_position=[90, 180, 768],
62
- ).style(height=512)
 
63
 
64
  def set_class_name(self, value):
65
  LOGGER.info(f"Changed task to: {value}")
@@ -75,30 +76,44 @@ class WebUI:
75
 
76
  def process(self, mesh_file_name):
77
  path = mesh_file_name.name
 
 
 
 
 
78
  run_model(
79
  path,
80
  model_path=os.path.join(self.cwd, "resources/models/"),
81
  task=self.class_names[self.class_name],
82
  name=self.result_names[self.class_name],
 
83
  )
84
  LOGGER.info("Converting prediction NIfTI to OBJ...")
85
- nifti_to_obj("prediction.nii.gz")
86
 
87
  LOGGER.info("Loading CT to numpy...")
88
  self.images = load_ct_to_numpy(path)
89
 
90
  LOGGER.info("Loading prediction volume to numpy..")
91
- self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
 
 
92
 
93
  return "./prediction.obj"
94
 
 
 
 
 
 
 
 
95
  def get_img_pred_pair(self, k):
96
  k = int(k)
97
  out = gr.AnnotatedImage(
98
  self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
99
  visible=True,
100
  elem_id="model-2d",
101
- ).style(
102
  color_map={self.class_name: "#ffae00"},
103
  height=512,
104
  width=512,
@@ -117,20 +132,18 @@ class WebUI:
117
  placeholder="\n" * 16,
118
  label="Logs",
119
  info="Verbose from inference will be displayed below.",
120
- lines=38,
121
- max_lines=38,
122
  autoscroll=True,
123
  elem_id="logs",
124
  show_copy_button=True,
125
- scroll_to_output=False,
126
  container=True,
127
- line_breaks=True,
128
  )
129
  demo.load(read_logs, None, logs, every=1)
130
 
131
  with gr.Column():
132
  with gr.Row():
133
- with gr.Column(scale=0.2, min_width=150):
134
  sidebar_state = gr.State(True)
135
 
136
  btn_toggle_sidebar = gr.Button(
@@ -149,7 +162,9 @@ class WebUI:
149
  btn_clear_logs.click(flush_logs, [], [])
150
 
151
  file_output = gr.File(
152
- file_count="single", elem_id="upload"
 
 
153
  )
154
  file_output.upload(
155
  self.upload_file, file_output, file_output
@@ -160,7 +175,7 @@ class WebUI:
160
  label="Task",
161
  info="Which structure to segment.",
162
  multiselect=False,
163
- size="sm",
164
  )
165
  model_selector.input(
166
  fn=lambda x: self.set_class_name(x),
@@ -168,14 +183,11 @@ class WebUI:
168
  outputs=None,
169
  )
170
 
171
- with gr.Column(scale=0.2, min_width=150):
172
  run_btn = gr.Button(
173
  "Run analysis",
174
  variant="primary",
175
  elem_id="run-button",
176
- ).style(
177
- full_width=False,
178
- size="lg",
179
  )
180
  run_btn.click(
181
  fn=lambda x: self.process(x),
@@ -183,6 +195,18 @@ class WebUI:
183
  outputs=self.volume_renderer,
184
  )
185
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  with gr.Row():
187
  gr.Examples(
188
  examples=[
@@ -202,17 +226,16 @@ class WebUI:
202
  )
203
 
204
  with gr.Row():
205
- with gr.Box():
206
  with gr.Column():
207
  # create dummy image to be replaced by loaded images
208
  t = gr.AnnotatedImage(
209
- visible=True, elem_id="model-2d"
210
- ).style(
211
  color_map={self.class_name: "#ffae00"},
212
- height=512,
213
- width=512,
214
  )
215
-
216
  self.slider.input(
217
  self.get_img_pred_pair,
218
  self.slider,
@@ -221,7 +244,7 @@ class WebUI:
221
 
222
  self.slider.render()
223
 
224
- with gr.Box():
225
  self.volume_renderer.render()
226
 
227
  # sharing app publicly -> share=True:
 
59
  visible=True,
60
  elem_id="model-3d",
61
  camera_position=[90, 180, 768],
62
+ height=512,
63
+ )
64
 
65
  def set_class_name(self, value):
66
  LOGGER.info(f"Changed task to: {value}")
 
76
 
77
  def process(self, mesh_file_name):
78
  path = mesh_file_name.name
79
+ curr = path.split("/")[-1]
80
+ self.extension = ".".join(curr.split(".")[1:])
81
+ self.filename = (
82
+ curr.split(".")[0] + "-" + self.class_names[self.class_name]
83
+ )
84
  run_model(
85
  path,
86
  model_path=os.path.join(self.cwd, "resources/models/"),
87
  task=self.class_names[self.class_name],
88
  name=self.result_names[self.class_name],
89
+ output_filename=self.filename + "." + self.extension,
90
  )
91
  LOGGER.info("Converting prediction NIfTI to OBJ...")
92
+ nifti_to_obj(path=self.filename + "." + self.extension)
93
 
94
  LOGGER.info("Loading CT to numpy...")
95
  self.images = load_ct_to_numpy(path)
96
 
97
  LOGGER.info("Loading prediction volume to numpy..")
98
+ self.pred_images = load_pred_volume_to_numpy(
99
+ self.filename + "." + self.extension
100
+ )
101
 
102
  return "./prediction.obj"
103
 
104
+ def download_prediction(self):
105
+ if (not self.filename) or (not self.extension):
106
+ LOGGER.error(
107
+ "The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
108
+ )
109
+ return self.filename + "." + self.extension
110
+
111
  def get_img_pred_pair(self, k):
112
  k = int(k)
113
  out = gr.AnnotatedImage(
114
  self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
115
  visible=True,
116
  elem_id="model-2d",
 
117
  color_map={self.class_name: "#ffae00"},
118
  height=512,
119
  width=512,
 
132
  placeholder="\n" * 16,
133
  label="Logs",
134
  info="Verbose from inference will be displayed below.",
135
+ lines=36,
136
+ max_lines=36,
137
  autoscroll=True,
138
  elem_id="logs",
139
  show_copy_button=True,
 
140
  container=True,
 
141
  )
142
  demo.load(read_logs, None, logs, every=1)
143
 
144
  with gr.Column():
145
  with gr.Row():
146
+ with gr.Column(scale=1, min_width=150):
147
  sidebar_state = gr.State(True)
148
 
149
  btn_toggle_sidebar = gr.Button(
 
162
  btn_clear_logs.click(flush_logs, [], [])
163
 
164
  file_output = gr.File(
165
+ file_count="single",
166
+ elem_id="upload",
167
+ scale=3,
168
  )
169
  file_output.upload(
170
  self.upload_file, file_output, file_output
 
175
  label="Task",
176
  info="Which structure to segment.",
177
  multiselect=False,
178
+ scale=1,
179
  )
180
  model_selector.input(
181
  fn=lambda x: self.set_class_name(x),
 
183
  outputs=None,
184
  )
185
 
186
+ with gr.Column(scale=1, min_width=150):
187
  run_btn = gr.Button(
188
  "Run analysis",
189
  variant="primary",
190
  elem_id="run-button",
 
 
 
191
  )
192
  run_btn.click(
193
  fn=lambda x: self.process(x),
 
195
  outputs=self.volume_renderer,
196
  )
197
 
198
+ download_btn = gr.DownloadButton(
199
+ "Download prediction",
200
+ visible=True,
201
+ variant="secondary",
202
+ elem_id="download",
203
+ )
204
+ download_btn.click(
205
+ fn=self.download_prediction,
206
+ inputs=None,
207
+ outputs=download_btn,
208
+ )
209
+
210
  with gr.Row():
211
  gr.Examples(
212
  examples=[
 
226
  )
227
 
228
  with gr.Row():
229
+ with gr.Group():
230
  with gr.Column():
231
  # create dummy image to be replaced by loaded images
232
  t = gr.AnnotatedImage(
233
+ visible=True,
234
+ elem_id="model-2d",
235
  color_map={self.class_name: "#ffae00"},
236
+ # height=512,
237
+ # width=512,
238
  )
 
239
  self.slider.input(
240
  self.get_img_pred_pair,
241
  self.slider,
 
244
 
245
  self.slider.render()
246
 
247
+ with gr.Group(): # gr.Box():
248
  self.volume_renderer.render()
249
 
250
  # sharing app publicly -> share=True:
demo/src/inference.py CHANGED
@@ -11,6 +11,7 @@ def run_model(
11
  verbose: str = "info",
12
  task: str = "CT_Airways",
13
  name: str = "Airways",
 
14
  ):
15
  if verbose == "debug":
16
  logging.getLogger().setLevel(logging.DEBUG)
@@ -27,6 +28,9 @@ def run_model(
27
  if os.path.exists("./result/"):
28
  shutil.rmtree("./result/")
29
 
 
 
 
30
  patient_directory = ""
31
  output_path = ""
32
  try:
@@ -84,7 +88,7 @@ def run_model(
84
  + "-t1gd_annotation-"
85
  + name
86
  + ".nii.gz",
87
- "./prediction.nii.gz",
88
  )
89
  # Clean-up
90
  if os.path.exists(patient_directory):
 
11
  verbose: str = "info",
12
  task: str = "CT_Airways",
13
  name: str = "Airways",
14
+ output_filename: str = None,
15
  ):
16
  if verbose == "debug":
17
  logging.getLogger().setLevel(logging.DEBUG)
 
28
  if os.path.exists("./result/"):
29
  shutil.rmtree("./result/")
30
 
31
+ if output_filename is None:
32
+ raise ValueError("Please, set output_filename.")
33
+
34
  patient_directory = ""
35
  output_path = ""
36
  try:
 
88
  + "-t1gd_annotation-"
89
  + name
90
  + ".nii.gz",
91
+ output_filename,
92
  )
93
  # Clean-up
94
  if os.path.exists(patient_directory):