kcelia commited on
Commit
6f25c7a
·
unverified ·
1 Parent(s): e6267cc

chore: reshape the image to (100, 100, 3) if not and add a check for rgb format

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """A local gradio app that filters images using FHE."""
2
-
3
  import os
4
  import shutil
5
  import subprocess
@@ -191,6 +191,18 @@ def encrypt(user_id, input_image, filter_name):
191
 
192
  if input_image is None:
193
  raise gr.Error("Please choose an image first.")
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  # Retrieve the client API
196
  client = get_client(user_id, filter_name)
@@ -482,7 +494,7 @@ with demo:
482
  )
483
 
484
  output_image = gr.Image(
485
- label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
486
  interactive=False,
487
  height=256,
488
  width=256,
@@ -513,7 +525,7 @@ with demo:
513
  # Button to send the encodings to the server using post method
514
  get_output_button.click(
515
  get_output,
516
- inputs=[user_id, filter_name],
517
  outputs=[encrypted_output_representation]
518
  )
519
 
 
1
  """A local gradio app that filters images using FHE."""
2
+ from PIL import Image
3
  import os
4
  import shutil
5
  import subprocess
 
191
 
192
  if input_image is None:
193
  raise gr.Error("Please choose an image first.")
194
+
195
+ if input_image.shape[-1] != 3:
196
+ raise ValueError(f"Input image must have 3 channels (RGB). Current shape: {input_image.shape}")
197
+
198
+ # Resize the image if it hasn't the shape (100, 100, 3)
199
+ if input_image.shape != (100 , 100, 3):
200
+ print(f"Before: {type(input_image)=}, {input_image.shape=}")
201
+ input_image_pil = Image.fromarray(input_image)
202
+ # Resize the image
203
+ input_image_pil = input_image_pil.resize((100, 100))
204
+ input_image = numpy.array(input_image_pil)
205
+ print(f"After: {type(input_image)=}, {input_image.shape=}")
206
 
207
  # Retrieve the client API
208
  client = get_client(user_id, filter_name)
 
494
  )
495
 
496
  output_image = gr.Image(
497
+ label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
498
  interactive=False,
499
  height=256,
500
  width=256,
 
525
  # Button to send the encodings to the server using post method
526
  get_output_button.click(
527
  get_output,
528
+ inputs=[user_id, filter_name],
529
  outputs=[encrypted_output_representation]
530
  )
531