Spaces:
Running
on
T4
Running
on
T4
Load automatically the right baseline and default physics when choosing a dataset
Browse files
app.py
CHANGED
@@ -112,25 +112,23 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
|
|
112 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
113 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
114 |
|
115 |
-
def get_physics(physics_name):
|
116 |
-
if physics_name == 'MRI':
|
117 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR_MRI')
|
118 |
-
elif physics_name == 'CT':
|
119 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR_CT')
|
120 |
-
else:
|
121 |
-
baseline = get_baseline_model_on_DEVICE_STR('DPIR')
|
122 |
-
return get_physics_on_DEVICE_STR(physics_name), baseline
|
123 |
-
|
124 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
125 |
def get_dataset(dataset_name):
|
126 |
global AVAILABLE_PHYSICS
|
127 |
if dataset_name == 'MRI':
|
128 |
AVAILABLE_PHYSICS = ['MRI']
|
|
|
|
|
129 |
elif dataset_name == 'CT':
|
130 |
AVAILABLE_PHYSICS = ['CT']
|
|
|
|
|
131 |
else:
|
132 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
133 |
-
|
|
|
|
|
|
|
134 |
|
135 |
### Gradio Blocks interface
|
136 |
|
@@ -212,10 +210,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
212 |
### Event listeners
|
213 |
choose_dataset.change(fn=get_dataset,
|
214 |
inputs=choose_dataset,
|
215 |
-
outputs=dataset_placeholder)
|
216 |
-
choose_physics.change(fn=
|
217 |
inputs=choose_physics,
|
218 |
-
outputs=[physics_placeholder
|
219 |
update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
|
220 |
choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
|
221 |
inputs=choose_metrics,
|
|
|
112 |
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
|
113 |
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
|
116 |
def get_dataset(dataset_name):
|
117 |
global AVAILABLE_PHYSICS
|
118 |
if dataset_name == 'MRI':
|
119 |
AVAILABLE_PHYSICS = ['MRI']
|
120 |
+
baseline_name = 'DPIR_MRI'
|
121 |
+
physics_name = 'MRI'
|
122 |
elif dataset_name == 'CT':
|
123 |
AVAILABLE_PHYSICS = ['CT']
|
124 |
+
baseline_name = 'DPIR_CT'
|
125 |
+
physics_name = 'CT'
|
126 |
else:
|
127 |
AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
128 |
+
baseline_name = 'DPIR'
|
129 |
+
physics_name = 'MotionBlur_easy'
|
130 |
+
return get_dataset_on_DEVICE_STR(dataset_name), get_physics_on_DEVICE_STR(physics_name), get_baseline_model_on_DEVICE_STR(baseline_name)
|
131 |
+
|
132 |
|
133 |
### Gradio Blocks interface
|
134 |
|
|
|
210 |
### Event listeners
|
211 |
choose_dataset.change(fn=get_dataset,
|
212 |
inputs=choose_dataset,
|
213 |
+
outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
|
214 |
+
choose_physics.change(fn=get_physics_on_DEVICE_STR,
|
215 |
inputs=choose_physics,
|
216 |
+
outputs=[physics_placeholder])
|
217 |
update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
|
218 |
choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
|
219 |
inputs=choose_metrics,
|