msong97 commited on
Commit
1eb9e66
·
1 Parent(s): fbe2693

add inference time, remove global variable for user-specific variable, consistency in idx_slider

Browse files
Files changed (1) hide show
  1. app.py +62 -56
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  import os
3
  import random
 
4
  from functools import partial
5
  from pathlib import Path
6
  from typing import List
@@ -20,7 +21,6 @@ torch.set_grad_enabled(False) # stops tracking values for gradients
20
 
21
 
22
  ### Gradio Utils
23
-
24
  def generate_imgs_from_user(image,
25
  model: EvalModel, baseline: BaselineModel,
26
  physics: PhysicsWithGenerator, use_gen: bool,
@@ -60,37 +60,43 @@ def generate_imgs(x: torch.Tensor,
60
  physics: PhysicsWithGenerator, use_gen: bool,
61
  metrics: List[Metric]):
62
 
63
- with torch.no_grad():
64
- ### Compute y
65
- y = physics(x, use_gen) # possible reduction in img shape due to Blurring
66
-
67
- ### Compute x_hat
68
- out = model(y=y, physics=physics.physics)
69
- out_baseline = baseline(y=y, physics=physics.physics)
70
-
71
- ### Process tensors before metric computation
72
- if "Blur" in physics.name:
73
- w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
74
- h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
75
-
76
- x = x[..., w_1:w_2, h_1:h_2]
77
- out = out[..., w_1:w_2, h_1:h_2]
78
- if out_baseline.shape != out.shape:
79
- out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
80
-
81
- ### Metrics
82
- metrics_y = ""
83
- metrics_out = ""
84
- metrics_out_baseline = ""
85
- for metric in metrics:
86
- if y.shape == x.shape:
87
- metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
88
- metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
89
- metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
 
 
 
 
90
 
91
  ### Process y when y shape is different from x shape
92
- if physics.name == "MRI" in physics.name:
93
  y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
 
 
94
  else:
95
  y_plot = y.clone()
96
 
@@ -114,28 +120,26 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
114
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
115
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
116
 
117
- AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
118
- 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
119
  def get_dataset(dataset_name):
120
- global AVAILABLE_PHYSICS
121
  if dataset_name == 'MRI':
122
- AVAILABLE_PHYSICS = ['MRI']
123
- baseline_name = 'DPIR_MRI'
124
  physics_name = 'MRI'
 
125
  elif dataset_name == 'CT':
126
- AVAILABLE_PHYSICS = ['CT']
127
- baseline_name = 'DPIR_CT'
128
  physics_name = 'CT'
 
129
  else:
130
- AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
131
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
132
- baseline_name = 'DPIR'
133
  physics_name = 'MotionBlur_easy'
 
134
 
135
  dataset = get_dataset_on_DEVICE_STR(dataset_name)
 
136
  physics = get_physics_on_DEVICE_STR(physics_name)
137
  baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
138
- return dataset, physics, baseline
139
 
140
 
141
  ### Gradio Blocks interface
@@ -144,28 +148,30 @@ title = "Inverse problem playground" # displayed on gradio tab and in the gradi
144
  with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
145
  gr.Markdown("## " + title)
146
 
147
- # DEFAULT VALUES
148
  # Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
149
  # Solution: using lambda expression
150
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
151
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
 
 
152
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
153
  physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
154
- idx_placeholder = gr.State(0)
 
155
 
156
- metric_names = ["PSNR"]
157
- metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(metric_names))
158
 
159
- @gr.render(inputs=[dataset_placeholder, physics_placeholder])
160
- def dynamic_layout(dataset, physics):
161
- ### LAYOUT
 
 
 
 
162
 
163
- # Display images
164
- with gr.Row():
165
- gt_img = gr.Image(label=f"Ground-truth IMAGE", interactive=True)
166
- observed_img = gr.Image(label=f"Observed IMAGE", interactive=False)
167
- model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
168
- model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
169
 
170
  # Manage datasets and display metric values
171
  with gr.Row():
@@ -174,7 +180,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
174
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
175
  label="Datasets",
176
  value=dataset.name)
177
- idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index")
178
  with gr.Row():
179
  load_button = gr.Button("Run on index image from dataset")
180
  load_random_button = gr.Button("Run on random image from dataset")
@@ -191,14 +197,14 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
191
  # Manage physics
192
  with gr.Row():
193
  with gr.Column(scale=1):
194
- choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
195
  label="Physics",
196
  value=physics.name)
197
  use_generator_button = gr.Checkbox(label="Generate physics parameters during inference")
198
  with gr.Column(scale=1):
199
  with gr.Row():
200
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
201
- label="Updatable Parameter Key")
202
  value_text = gr.Textbox(label="Update Value")
203
  update_button = gr.Button("Manually update parameter value")
204
  with gr.Column(scale=2):
@@ -211,7 +217,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
211
 
212
  choose_dataset.change(fn=get_dataset,
213
  inputs=choose_dataset,
214
- outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
215
  choose_physics.change(fn=get_physics_on_DEVICE_STR,
216
  inputs=choose_physics,
217
  outputs=[physics_placeholder])
 
1
  import json
2
  import os
3
  import random
4
+ import time
5
  from functools import partial
6
  from pathlib import Path
7
  from typing import List
 
21
 
22
 
23
  ### Gradio Utils
 
24
  def generate_imgs_from_user(image,
25
  model: EvalModel, baseline: BaselineModel,
26
  physics: PhysicsWithGenerator, use_gen: bool,
 
60
  physics: PhysicsWithGenerator, use_gen: bool,
61
  metrics: List[Metric]):
62
 
63
+ ### Compute y
64
+ y = physics(x, use_gen) # possible reduction in img shape due to Blurring
65
+
66
+ ### Compute x_hat from RAM & DPIR
67
+ ram_time = time.time()
68
+ out = model(y=y, physics=physics.physics)
69
+ ram_time = time.time() - ram_time
70
+
71
+ dpir_time = time.time()
72
+ out_baseline = baseline(y=y, physics=physics.physics)
73
+ dpir_time = time.time() - dpir_time
74
+
75
+ ### Process tensors before metric computation
76
+ if "Blur" in physics.name:
77
+ w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
78
+ h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
79
+
80
+ x = x[..., w_1:w_2, h_1:h_2]
81
+ out = out[..., w_1:w_2, h_1:h_2]
82
+ if out_baseline.shape != out.shape:
83
+ out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
84
+
85
+ ### Metrics
86
+ metrics_y = ""
87
+ metrics_out = f"Inference time = {ram_time:.3f}s" + "\n"
88
+ metrics_out_baseline = f"Inference time = {dpir_time:.3f}s" + "\n"
89
+ for metric in metrics:
90
+ if y.shape == x.shape:
91
+ metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
92
+ metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
93
+ metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
94
 
95
  ### Process y when y shape is different from x shape
96
+ if physics.name == "MRI":
97
  y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
98
+ elif physics.name == "CT":
99
+ y_plot = physics.physics.A_adjoint(y)
100
  else:
101
  y_plot = y.clone()
102
 
 
120
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
121
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
122
 
 
 
123
  def get_dataset(dataset_name):
 
124
  if dataset_name == 'MRI':
125
+ available_physics = ['MRI']
 
126
  physics_name = 'MRI'
127
+ baseline_name = 'DPIR_MRI'
128
  elif dataset_name == 'CT':
129
+ available_physics = ['CT']
 
130
  physics_name = 'CT'
131
+ baseline_name = 'DPIR_CT'
132
  else:
133
+ available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
134
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
 
135
  physics_name = 'MotionBlur_easy'
136
+ baseline_name = 'DPIR'
137
 
138
  dataset = get_dataset_on_DEVICE_STR(dataset_name)
139
+ idx = 0
140
  physics = get_physics_on_DEVICE_STR(physics_name)
141
  baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
142
+ return dataset, idx, physics, baseline, available_physics
143
 
144
 
145
  ### Gradio Blocks interface
 
148
  with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
149
  gr.Markdown("## " + title)
150
 
151
+ ### DEFAULT VALUES
152
  # Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
153
  # Solution: using lambda expression
154
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
155
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
156
+ metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
157
+
158
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
159
  physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
160
+ available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
161
+ 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
162
 
 
 
163
 
164
+ ### LAYOUT
165
+ # Display images
166
+ with gr.Row():
167
+ gt_img = gr.Image(label="Ground-truth IMAGE", interactive=True)
168
+ observed_img = gr.Image(label="Observed IMAGE", interactive=False)
169
+ model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
170
+ model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
171
 
172
+ @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
173
+ def dynamic_layout(dataset, physics, available_physics):
174
+ ### LAYOUT
 
 
 
175
 
176
  # Manage datasets and display metric values
177
  with gr.Row():
 
180
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
181
  label="Datasets",
182
  value=dataset.name)
183
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=0)
184
  with gr.Row():
185
  load_button = gr.Button("Run on index image from dataset")
186
  load_random_button = gr.Button("Run on random image from dataset")
 
197
  # Manage physics
198
  with gr.Row():
199
  with gr.Column(scale=1):
200
+ choose_physics = gr.Radio(choices=available_physics,
201
  label="Physics",
202
  value=physics.name)
203
  use_generator_button = gr.Checkbox(label="Generate physics parameters during inference")
204
  with gr.Column(scale=1):
205
  with gr.Row():
206
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
207
+ label="Updatable Parameter Key")
208
  value_text = gr.Textbox(label="Update Value")
209
  update_button = gr.Button("Manually update parameter value")
210
  with gr.Column(scale=2):
 
217
 
218
  choose_dataset.change(fn=get_dataset,
219
  inputs=choose_dataset,
220
+ outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder, available_physics_placeholder])
221
  choose_physics.change(fn=get_physics_on_DEVICE_STR,
222
  inputs=choose_physics,
223
  outputs=[physics_placeholder])