ohayonguy commited on
Commit
20ac05d
·
1 Parent(s): 8d5efa4

trying to fix interface

Browse files
Files changed (1) hide show
  1. app.py +19 -87
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  if os.getenv('SPACES_ZERO_GPU') == "true":
3
  os.environ['SPACES_ZERO_GPU'] = "1"
4
  os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
@@ -29,7 +30,8 @@ if not os.path.exists(realesr_model_path):
29
  # background enhancer with RealESRGAN
30
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
31
  half = True if torch.cuda.is_available() else False
32
- upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=400, tile_pad=10, pre_pad=0, half=half)
 
33
 
34
  pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device=device)
35
 
@@ -43,8 +45,6 @@ face_helper_dummy = FaceRestoreHelper(
43
  device=device,
44
  model_rootpath=None)
45
 
46
- os.makedirs('output', exist_ok=True)
47
-
48
 
49
  def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, device):
50
  source_dist_samples = pmrf_model.create_source_distribution_samples(x, y, non_noisy_z0)
@@ -58,6 +58,7 @@ def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, dev
58
 
59
  return x_t_next.clip(0, 1).to(torch.float32)
60
 
 
61
  @torch.inference_mode()
62
  @spaces.GPU()
63
  def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face=False, paste_back=True, scale=2):
@@ -73,21 +74,19 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face
73
  # align and warp each face
74
  face_helper.align_warp_face()
75
  # face restoration
76
- for cropped_face in face_helper.cropped_faces:
77
  # prepare data
78
  h, w = cropped_face.shape[0], cropped_face.shape[1]
79
  cropped_face = cv2.resize(cropped_face, (512, 512), interpolation=cv2.INTER_LINEAR)
 
80
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
81
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
82
 
83
  dummy_x = torch.zeros_like(cropped_face_t)
84
- # with torch.autocast("cuda", dtype=torch.bfloat16):
85
  output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, num_flow_steps, device)
86
  restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
87
- # restored_face = cropped_face
88
  restored_face = cv2.resize(restored_face, (h, w), interpolation=cv2.INTER_LINEAR)
89
 
90
-
91
  restored_face = restored_face.astype('uint8')
92
  face_helper.add_restored_face(restored_face)
93
 
@@ -124,9 +123,6 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps):
124
  print('Image size too large.')
125
  return None, None
126
 
127
- if h < 300:
128
- img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
129
-
130
  face_helper = FaceRestoreHelper(
131
  scale,
132
  face_size=512,
@@ -139,7 +135,8 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps):
139
 
140
  has_aligned = True if aligned == 'Yes' else False
141
  cropped_face, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
142
- paste_back=True, num_flow_steps=num_flow_steps, scale=scale)
 
143
  if has_aligned:
144
  output = restored_aligned[0]
145
  input = cropped_face[0].astype('uint8')
@@ -147,14 +144,12 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps):
147
  output = restored_img
148
  input = img
149
 
150
- save_path = f'output/out.png'
151
- cv2.imwrite(save_path, output)
152
-
153
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
154
  h, w = output.shape[0:2]
155
  input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
156
  input = cv2.resize(input, (h, w), interpolation=cv2.INTER_LINEAR)
157
- return [[input, output, seed], save_path]
 
158
 
159
  intro = """
160
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration</h1>
@@ -166,17 +161,18 @@ intro = """
166
  """
167
  markdown_top = """
168
  Gradio demo for the blind face image restoration version of [Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration](https://arxiv.org/abs/2410.00418).
 
169
 
170
  Please refer to our project's page for more details: https://pmrf-ml.github.io/.
171
 
172
  ---
173
 
174
- You may use this demo to enhance the quality of any image which contains faces.
175
-
176
  *Notes* :
177
 
178
  1. Our model is designed to restore aligned face images, but here we incorporate mechanisms that allow restoring the quality of any image that contains any number of faces. Thus, the resulting quality of such general images is not guaranteed.
179
  2. Images that are too large won't work due to memory constraints.
 
 
180
  """
181
 
182
  article = r"""
@@ -186,7 +182,6 @@ If you find our work useful, please help to ⭐ our <a href='https://github.com/
186
 
187
  📝 **Citation**
188
 
189
- If our work is useful for your research, please consider citing:
190
  ```bibtex
191
  @article{ohayon2024pmrf,
192
  author = {Guy Ohayon and Tomer Michaeli and Michael Elad},
@@ -214,15 +209,10 @@ css = """
214
  }
215
  """
216
 
217
-
218
-
219
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
220
  gr.HTML(intro)
221
  gr.Markdown(markdown_top)
222
 
223
- with gr.Row():
224
- run_button = gr.Button(value="Submit", variant="primary")
225
-
226
  with gr.Row():
227
  with gr.Column(scale=2):
228
  input_im = gr.Image(label="Input Image", type="filepath")
@@ -250,54 +240,13 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
250
  )
251
 
252
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
253
- aligned = gr.Checkbox(label="The input is an aligned face image", value=True)
254
 
255
  with gr.Row():
256
- result = ImageSlider(label="Input / Output", type="numpy", interactive=True)
 
257
  with gr.Row():
258
- file = gr.File(label="Download the output image")
259
-
260
- # examples = gr.Examples(
261
- # examples=[
262
- # # [42, False, "examples/image_1.jpg", 28, 4, 0.6],
263
- # # [42, False, "examples/image_2.jpg", 28, 4, 0.6],
264
- # # [42, False, "examples/image_3.jpg", 28, 4, 0.6],
265
- # # [42, False, "examples/image_4.jpg", 28, 4, 0.6],
266
- # # [42, False, "examples/image_5.jpg", 28, 4, 0.6],
267
- # # [42, False, "examples/image_6.jpg", 28, 4, 0.6],
268
- # ],
269
- # inputs=[
270
- # seed,
271
- # randomize_seed,
272
- # input_im,
273
- # num_inference_steps,
274
- # upscale_factor,
275
- # controlnet_conditioning_scale,
276
- # ],
277
- # fn=infer,
278
- # outputs=result,
279
- # cache_examples="lazy",
280
- # )
281
-
282
- # examples = gr.Examples(
283
- # examples=[
284
- # #[42, False, "examples/image_1.jpg", 28, 4, 0.6],
285
- # [42, False, "examples/image_2.jpg", 28, 4, 0.6],
286
- # #[42, False, "examples/image_3.jpg", 28, 4, 0.6],
287
- # #[42, False, "examples/image_4.jpg", 28, 4, 0.6],
288
- # [42, False, "examples/image_5.jpg", 28, 4, 0.6],
289
- # [42, False, "examples/image_6.jpg", 28, 4, 0.6],
290
- # [42, False, "examples/image_7.jpg", 28, 4, 0.6],
291
- # ],
292
- # inputs=[
293
- # seed,
294
- # randomize_seed,
295
- # input_im,
296
- # num_inference_steps,
297
- # upscale_factor,
298
- # controlnet_conditioning_scale,
299
- # ],
300
- # )
301
 
302
  gr.Markdown(article)
303
  gr.on(
@@ -311,27 +260,10 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
311
  upscale_factor,
312
  num_inference_steps,
313
  ],
314
- outputs=[result, file],
315
  show_api=False,
316
  # show_progress="minimal",
317
  )
318
 
319
-
320
- # demo = gr.Interface(
321
- # inference, [
322
- # gr.Image(type="filepath", label="Input"),
323
- # gr.Radio(['Yes', 'No'], type="value", value='aligned', label='Is the input an aligned face image?'),
324
- # gr.Slider(label="Scale factor for the background upsampler. Applicable only to non-aligned face images.", minimum=1, maximum=4, value=2, step=0.1, interactive=True),
325
- # gr.Number(label="Number of flow steps. A higher value should result in better image quality, but will inference will take a longer time.", value=25),
326
- # ], [
327
- # gr.ImageSlider(type="numpy", label="Input / Output", interactive=True),
328
- # gr.File(label="Download the output image")
329
- # ],
330
- # title=title,
331
- # description=description,
332
- # article=article,
333
- # )
334
-
335
-
336
  demo.queue()
337
- demo.launch(state_session_capacity=15, show_api=False, share=False)
 
1
  import os
2
+
3
  if os.getenv('SPACES_ZERO_GPU') == "true":
4
  os.environ['SPACES_ZERO_GPU'] = "1"
5
  os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
 
30
  # background enhancer with RealESRGAN
31
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
32
  half = True if torch.cuda.is_available() else False
33
+ upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=400, tile_pad=10, pre_pad=0,
34
+ half=half)
35
 
36
  pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device=device)
37
 
 
45
  device=device,
46
  model_rootpath=None)
47
 
 
 
48
 
49
  def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, device):
50
  source_dist_samples = pmrf_model.create_source_distribution_samples(x, y, non_noisy_z0)
 
58
 
59
  return x_t_next.clip(0, 1).to(torch.float32)
60
 
61
+
62
  @torch.inference_mode()
63
  @spaces.GPU()
64
  def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face=False, paste_back=True, scale=2):
 
74
  # align and warp each face
75
  face_helper.align_warp_face()
76
  # face restoration
77
+ for i, cropped_face in enumerate(face_helper.cropped_faces):
78
  # prepare data
79
  h, w = cropped_face.shape[0], cropped_face.shape[1]
80
  cropped_face = cv2.resize(cropped_face, (512, 512), interpolation=cv2.INTER_LINEAR)
81
+ face_helper.cropped_faces[i] = cropped_face
82
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
83
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
84
 
85
  dummy_x = torch.zeros_like(cropped_face_t)
 
86
  output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, num_flow_steps, device)
87
  restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
 
88
  restored_face = cv2.resize(restored_face, (h, w), interpolation=cv2.INTER_LINEAR)
89
 
 
90
  restored_face = restored_face.astype('uint8')
91
  face_helper.add_restored_face(restored_face)
92
 
 
123
  print('Image size too large.')
124
  return None, None
125
 
 
 
 
126
  face_helper = FaceRestoreHelper(
127
  scale,
128
  face_size=512,
 
135
 
136
  has_aligned = True if aligned == 'Yes' else False
137
  cropped_face, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
138
+ paste_back=True, num_flow_steps=num_flow_steps,
139
+ scale=scale)
140
  if has_aligned:
141
  output = restored_aligned[0]
142
  input = cropped_face[0].astype('uint8')
 
144
  output = restored_img
145
  input = img
146
 
 
 
 
147
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
148
  h, w = output.shape[0:2]
149
  input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
150
  input = cv2.resize(input, (h, w), interpolation=cv2.INTER_LINEAR)
151
+ return [input, output, seed]
152
+
153
 
154
  intro = """
155
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration</h1>
 
161
  """
162
  markdown_top = """
163
  Gradio demo for the blind face image restoration version of [Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration](https://arxiv.org/abs/2410.00418).
164
+ You may use this demo to enhance the quality of any image which contains faces.
165
 
166
  Please refer to our project's page for more details: https://pmrf-ml.github.io/.
167
 
168
  ---
169
 
 
 
170
  *Notes* :
171
 
172
  1. Our model is designed to restore aligned face images, but here we incorporate mechanisms that allow restoring the quality of any image that contains any number of faces. Thus, the resulting quality of such general images is not guaranteed.
173
  2. Images that are too large won't work due to memory constraints.
174
+
175
+ ---
176
  """
177
 
178
  article = r"""
 
182
 
183
  📝 **Citation**
184
 
 
185
  ```bibtex
186
  @article{ohayon2024pmrf,
187
  author = {Guy Ohayon and Tomer Michaeli and Michael Elad},
 
209
  }
210
  """
211
 
 
 
212
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
213
  gr.HTML(intro)
214
  gr.Markdown(markdown_top)
215
 
 
 
 
216
  with gr.Row():
217
  with gr.Column(scale=2):
218
  input_im = gr.Image(label="Input Image", type="filepath")
 
240
  )
241
 
242
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
243
+ aligned = gr.Checkbox(label="The input is an aligned face image", value=False)
244
 
245
  with gr.Row():
246
+ run_button = gr.Button(value="Submit", variant="primary")
247
+
248
  with gr.Row():
249
+ result = ImageSlider(label="Input / Output", type="numpy", interactive=True, show_label=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  gr.Markdown(article)
252
  gr.on(
 
260
  upscale_factor,
261
  num_inference_steps,
262
  ],
263
+ outputs=result,
264
  show_api=False,
265
  # show_progress="minimal",
266
  )
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  demo.queue()
269
+ demo.launch(state_session_capacity=15, show_api=False)