Alexander Becker committed
Commit b5731be · 1 Parent(s): 1df2a85

Add model selector

Files changed (2)
  1. app.py +32 -7
  2. gradio_dualvision/app_template.py +1 -1
app.py CHANGED
@@ -13,23 +13,29 @@ from huggingface_hub import hf_hub_download
 from model import build_thera
 from super_resolve import process
 
-REPO_ID = "prs-eth/thera-edsr-plus"
+REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
+REPO_ID_RDN = "prs-eth/thera-rdn-pro"
 
 print(f"JAX devices: {jax.devices()}")
 print(f"JAX device type: {jax.devices()[0].device_kind}")
 
-# load model
-model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
+model_path = hf_hub_download(repo_id=REPO_ID_EDSR, filename="model.pkl")
 with open(model_path, 'rb') as fh:
     check = pickle.load(fh)
-params, backbone, size = check['model'], check['backbone'], check['size']
+params_edsr, backbone, size = check['model'], check['backbone'], check['size']
+model_edsr = build_thera(3, backbone, size)
 
-model = build_thera(3, backbone, size)
+model_path = hf_hub_download(repo_id=REPO_ID_RDN, filename="model.pkl")
+with open(model_path, 'rb') as fh:
+    check = pickle.load(fh)
+params_rdn, backbone, size = check['model'], check['backbone'], check['size']
+model_rdn = build_thera(3, backbone, size)
 
 
 class TheraApp(DualVisionApp):
     DEFAULT_SCALE = 3.92
     DEFAULT_DO_ENSEMBLE = False
+    DEFAULT_MODEL = 'edsr'
 
     def make_header(self):
         gr.Markdown(
@@ -47,7 +53,7 @@ class TheraApp(DualVisionApp):
             </a>
             </p>
             <p align="center" style="margin-top: 0px;">
-                Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!
+                <strong>Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!</strong>
             </p>
             """
         )
@@ -61,6 +67,14 @@ class TheraApp(DualVisionApp):
             step=0.01,
             value=self.DEFAULT_SCALE,
         )
+        model = gr.Radio(
+            [
+                ("EDSR", 'edsr'),
+                ("RDN", 'rdn'),
+            ],
+            label="Backbone",
+            value=self.DEFAULT_MODEL,
+        )
         do_ensemble = gr.Radio(
             [
                 ("No", False),
@@ -71,12 +85,14 @@ class TheraApp(DualVisionApp):
         )
         return {
             "scale": scale,
+            "model": model,
             "do_ensemble": do_ensemble,
         }
 
     def process(self, image_in: Image.Image, **kwargs):
         scale = kwargs.get("scale", self.DEFAULT_SCALE)
         do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE)
+        model = kwargs.get("model", self.DEFAULT_MODEL)
 
         source = np.asarray(image_in) / 255.
 
@@ -86,7 +102,14 @@ class TheraApp(DualVisionApp):
             round(source.shape[1] * scale),
         )
 
-        out = process(source, model, params, target_shape, do_ensemble=do_ensemble)
+        if model == 'edsr':
+            m, p = model_edsr, params_edsr
+        elif model == 'rdn':
+            m, p = model_rdn, params_rdn
+        else:
+            raise NotImplementedError('model:', model)
+
+        out = process(source, m, p, target_shape, do_ensemble=do_ensemble)
         out = Image.fromarray(np.asarray(out))
 
         nearest = image_in.resize(out.size, Image.NEAREST)
@@ -97,6 +120,7 @@ class TheraApp(DualVisionApp):
         }
         out_settings = {
             'scale': scale,
+            'model': model,
             'do_ensemble': do_ensemble,
         }
         return out_modalities, out_settings
@@ -147,6 +171,7 @@ class TheraApp(DualVisionApp):
         )
         if any(k not in results_settings for k in self.input_keys):
             raise gr.Error(f"Mismatching setgings keys")
+
         results_settings = {
             k: cls(**ctor_args, value=results_settings[k])
             for k, cls, ctor_args in zip(
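
Note on the loading code above: the checkpoint-loading block is now duplicated for the EDSR and RDN backbones, and process() dispatches with an if/elif chain. One possible consolidation is a small registry keyed by the radio value. This is only a sketch, assuming build_thera(3, backbone, size) and the repositories named in the diff behave as shown; it is not code from this commit:

# Sketch (not part of this commit): load each backbone once into a registry.
import pickle
from huggingface_hub import hf_hub_download
from model import build_thera  # assumed signature: build_thera(out_channels, backbone, size)

MODEL_REPOS = {
    'edsr': "prs-eth/thera-edsr-pro",
    'rdn': "prs-eth/thera-rdn-pro",
}

def load_thera(repo_id: str):
    """Download one checkpoint and return (model, params)."""
    path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    with open(path, 'rb') as fh:
        check = pickle.load(fh)
    return build_thera(3, check['backbone'], check['size']), check['model']

MODELS = {name: load_thera(repo) for name, repo in MODEL_REPOS.items()}

# Inside process(), the if/elif dispatch then reduces to a single lookup:
# m, p = MODELS[model]

With a registry, an unknown selection surfaces as a KeyError (or could be raised as a gr.Error) instead of the NotImplementedError branch.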
gradio_dualvision/app_template.py CHANGED
@@ -228,7 +228,7 @@ class DualVisionApp(gr.Blocks):
             }}
             #settings-accordion {{
                 margin: 0 auto;
-                max-width: 500px;
+                max-width: 650px;
             }}
             """
             if squeeze_canvas:
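
Aside on this hunk: the doubled braces suggest the CSS is embedded in a Python f-string (or str.format template), where {{ and }} render as literal braces and only single-brace expressions are interpolated. A minimal standalone illustration, with max_width as a hypothetical variable not taken from the repository:

# Illustration only: f-string brace escaping for embedded CSS.
max_width = 650  # the value this commit bumps from 500
css = f"""
#settings-accordion {{
    margin: 0 auto;
    max-width: {max_width}px;
}}
"""
print(css)  # prints literal braces; only {max_width} is substituted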