kcelia commited on
Commit
f5aa6c7
1 Parent(s): 3d845fb

chore: update

Browse files
Files changed (2) hide show
  1. app.py +93 -45
  2. server.py +49 -44
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
2
  import shutil
3
  import subprocess
 
4
  from pathlib import Path
5
- from time import time
6
  from typing import List, Tuple, Union
7
 
8
  import gradio as gr
9
  import numpy as np
10
  import pandas as pd
 
11
  from preprocessing import pretty_print
12
  from symptoms_categories import SYMPTOMS_LIST
13
 
@@ -16,16 +17,21 @@ from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
16
  from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
17
 
18
  INPUT_BROWSER_LIMIT = 635
19
-
20
  # This repository's main necessary folders
21
  REPO_DIR = Path(__file__).parent
22
  MODEL_PATH = REPO_DIR / "client_folder"
23
  KEYS_PATH = REPO_DIR / ".fhe_keys"
24
- CLIENT_PATH = MODEL_PATH / "client.zip"
25
- SERVER_PATH = MODEL_PATH / "server.zip"
 
 
 
 
 
26
 
27
- # subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
28
- # time.sleep(3)
29
 
30
 
31
  def clean_directory():
@@ -169,8 +175,8 @@ def key_gen_fn(user_symptoms):
169
 
170
  # np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
171
  evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
172
- with evaluation_key_path.open("wb") as evaluation_key_file:
173
- evaluation_key_file.write(serialized_evaluation_keys)
174
 
175
  serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
176
 
@@ -200,7 +206,7 @@ def encrypt_fn(user_symptoms, user_id):
200
 
201
  quant_user_symptoms = client.model.quantize_input(user_symptoms)
202
  encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
203
-
204
  encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
205
 
206
  with encrypted_input_path.open("wb") as f:
@@ -227,46 +233,69 @@ def encrypt_fn(user_symptoms, user_id):
227
  }
228
 
229
 
230
- # def send_input(user_id, user_symptoms):
231
- # """Send the encrypted input image as well as the evaluation key to the server.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- # Args:
234
- # user_id (int): The current user's ID.
235
- # filter_name (str): The current filter to consider.
236
- # """
237
- # # Get the evaluation key path
238
 
 
 
239
 
240
- # evaluation_key_path = get_client_file_path("evaluation_key", user_id, filter_name)
 
 
241
 
242
- # if user_id == "" or not evaluation_key_path.is_file():
243
- # raise gr.Error("Please generate the private key first.")
244
 
245
- # encrypted_input_path = get_client_file_path("encrypted_image", user_id, filter_name)
246
- # encrypted_symptoms_path = KEYS_PATH / f"{user_id}" / "encrypted_symtoms"
 
 
 
 
247
 
248
- # if not encrypted_input_path.is_file():
249
- # raise gr.Error("Please generate the private key and then encrypt an image first.")
 
 
 
250
 
251
- # # Define the data and files to post
252
- # data = {
253
- # "user_id": user_id,
254
- # "filter": filter_name,
255
- # }
256
 
257
- # files = [
258
- # ("files", open(encrypted_input_path, "rb")),
259
- # ("files", open(evaluation_key_path, "rb")),
260
- # ]
 
 
 
 
261
 
262
- # # Send the encrypted input image and evaluation key to the server
263
- # url = SERVER_URL + "send_input"
264
- # with requests.post(
265
- # url=url,
266
- # data=data,
267
- # files=files,
268
- # ) as response:
269
- # return response.ok
270
 
271
 
272
  # def decrypt_prediction(encrypted_quantized_vect, user_id):
@@ -277,11 +306,13 @@ def encrypt_fn(user_symptoms, user_id):
277
  # return predictions
278
 
279
 
 
280
 
 
281
 
282
- def clear_all_btn():
283
  return {
284
  box_default: None,
 
285
  user_id_textbox: None,
286
  eval_key_textbox: None,
287
  quant_vect_textbox: None,
@@ -291,13 +322,14 @@ def clear_all_btn():
291
  error_box_1: gr.update(visible=False),
292
  error_box_2: gr.update(visible=False),
293
  error_box_3: gr.update(visible=False),
 
 
294
  **{box: None for box in check_boxes},
295
  }
296
 
297
 
298
  if __name__ == "__main__":
299
  print("Starting demo ...")
300
-
301
 
302
  (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
303
 
@@ -423,7 +455,7 @@ if __name__ == "__main__":
423
  gr.Markdown("# Step 3: Encode the message with the private key")
424
  gr.Markdown("Client side")
425
 
426
- encrypt_btn = gr.Button("Encode the message with the private key and send it to the server")
427
 
428
  error_box_3 = gr.Textbox(label="Error", visible=False)
429
 
@@ -452,12 +484,25 @@ if __name__ == "__main__":
452
  outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox, error_box_3],
453
  )
454
 
455
- gr.Markdown("# Step 4: Run the FHE evaluation")
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  gr.Markdown("Server side")
457
 
458
  run_fhe = gr.Button("Run the FHE evaluation")
459
 
460
- gr.Markdown("# Step 5: Decrypt the sentiment")
461
  gr.Markdown("Server side")
462
 
463
  decrypt_target_botton = gr.Button("Decrypt the sentiment")
@@ -478,10 +523,13 @@ if __name__ == "__main__":
478
  error_box_1,
479
  error_box_2,
480
  error_box_3,
 
 
481
  user_id_textbox,
482
  eval_key_textbox,
483
  quant_vect_textbox,
484
  user_vector_textbox,
 
485
  eval_key_len_textbox,
486
  encrypted_vect_textbox,
487
  *check_boxes,
 
1
  import os
2
  import shutil
3
  import subprocess
4
+ import time
5
  from pathlib import Path
 
6
  from typing import List, Tuple, Union
7
 
8
  import gradio as gr
9
  import numpy as np
10
  import pandas as pd
11
+ import requests
12
  from preprocessing import pretty_print
13
  from symptoms_categories import SYMPTOMS_LIST
14
 
 
17
  from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
18
 
19
  INPUT_BROWSER_LIMIT = 635
20
+ SERVER_URL = "http://localhost:8000/"
21
  # This repository's main necessary folders
22
  REPO_DIR = Path(__file__).parent
23
  MODEL_PATH = REPO_DIR / "client_folder"
24
  KEYS_PATH = REPO_DIR / ".fhe_keys"
25
+ CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
26
+ SERVER_TMP_PATH = REPO_DIR / "server_tmp"
27
+
28
+ # Create the necessary folders
29
+ KEYS_PATH.mkdir(exist_ok=True)
30
+ CLIENT_TMP_PATH.mkdir(exist_ok=True)
31
+ SERVER_TMP_PATH.mkdir(exist_ok=True)
32
 
33
+ subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
34
+ time.sleep(3)
35
 
36
 
37
  def clean_directory():
 
175
 
176
  # np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
177
  evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
178
+ with evaluation_key_path.open("wb") as f:
179
+ f.write(serialized_evaluation_keys)
180
 
181
  serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
182
 
 
206
 
207
  quant_user_symptoms = client.model.quantize_input(user_symptoms)
208
  encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
209
+ assert isinstance(encrypted_quantized_user_symptoms, bytes)
210
  encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
211
 
212
  with encrypted_input_path.open("wb") as f:
 
233
  }
234
 
235
 
236
+ def is_nan(input):
237
+ return input is None or (input is not None and len(input) < 1)
238
+
239
+
240
+ def send_input_fn(user_id, user_symptoms):
241
+ """Send the encrypted input image as well as the evaluation key to the server.
242
+
243
+ Args:
244
+ user_id (int): The current user's ID.
245
+ filter_name (str): The current filter to consider.
246
+ """
247
+ # Get the evaluation key path
248
+
249
+ if is_nan(user_id) or is_nan(user_symptoms):
250
+ return {
251
+ error_box_4: gr.update(
252
+ visible=True,
253
+ value="Please ensure that the evaluation key has been generated "
254
+ "and the symptoms have been submitted before sending the data to the server",
255
+ )
256
+ }
257
 
258
+ evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
259
+ encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
 
 
 
260
 
261
+ if not evaluation_key_path.is_file():
262
+ print(f"Please generate the private key, first.{evaluation_key_path.is_file()=}")
263
 
264
+ return {
265
+ error_box_4: gr.update(visible=True, value="Please generate the private key first.")
266
+ }
267
 
268
+ if not encrypted_input_path.is_file():
269
+ print(f"Please submit your symptoms, first.{encrypted_input_path.is_file()=}")
270
 
271
+ return {
272
+ error_box_4: gr.update(
273
+ visible=True,
274
+ value="Please generate the private key and then encrypt an image first.",
275
+ )
276
+ }
277
 
278
+ # Define the data and files to post
279
+ data = {
280
+ "user_id": user_id,
281
+ "filter": user_symptoms,
282
+ }
283
 
284
+ files = [
285
+ ("files", open(encrypted_input_path, "rb")),
286
+ ("files", open(evaluation_key_path, "rb")),
287
+ ]
 
288
 
289
+ # Send the encrypted input image and evaluation key to the server
290
+ url = SERVER_URL + "send_input"
291
+ with requests.post(
292
+ url=url,
293
+ data=data,
294
+ files=files,
295
+ ) as response:
296
+ print(f"response.ok: {response.ok}")
297
 
298
+ return {error_box_4: gr.update(visible=False), server_response_box: gr.update(visible=True)}
 
 
 
 
 
 
 
299
 
300
 
301
  # def decrypt_prediction(encrypted_quantized_vect, user_id):
 
306
  # return predictions
307
 
308
 
309
+ def clear_all_btn():
310
 
311
+ clean_directory()
312
 
 
313
  return {
314
  box_default: None,
315
+ vect_textbox: None,
316
  user_id_textbox: None,
317
  eval_key_textbox: None,
318
  quant_vect_textbox: None,
 
322
  error_box_1: gr.update(visible=False),
323
  error_box_2: gr.update(visible=False),
324
  error_box_3: gr.update(visible=False),
325
+ error_box_4: gr.update(visible=False),
326
+ server_response_box: gr.update(visible=False),
327
  **{box: None for box in check_boxes},
328
  }
329
 
330
 
331
  if __name__ == "__main__":
332
  print("Starting demo ...")
 
333
 
334
  (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
335
 
 
455
  gr.Markdown("# Step 3: Encode the message with the private key")
456
  gr.Markdown("Client side")
457
 
458
+ encrypt_btn = gr.Button("Encode the message with the private key")
459
 
460
  error_box_3 = gr.Textbox(label="Error", visible=False)
461
 
 
484
  outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox, error_box_3],
485
  )
486
 
487
+ gr.Markdown("# Step 4: Send the encrypted data to the server.")
488
+ gr.Markdown("Client side")
489
+
490
+ send_input_btn = gr.Button("Send the encrypted data to the server..")
491
+ error_box_4 = gr.Textbox(label="Error", visible=False)
492
+ server_response_box = gr.Textbox(value="Data sent", visible=False, show_label=False)
493
+
494
+ send_input_btn.click(
495
+ send_input_fn,
496
+ inputs=[user_id_textbox, user_vector_textbox],
497
+ outputs=[error_box_4, server_response_box],
498
+ )
499
+
500
+ gr.Markdown("# Step 5: Run the FHE evaluation")
501
  gr.Markdown("Server side")
502
 
503
  run_fhe = gr.Button("Run the FHE evaluation")
504
 
505
+ gr.Markdown("# Step 6: Decrypt the sentiment")
506
  gr.Markdown("Server side")
507
 
508
  decrypt_target_botton = gr.Button("Decrypt the sentiment")
 
523
  error_box_1,
524
  error_box_2,
525
  error_box_3,
526
+ error_box_4,
527
+ vect_textbox,
528
  user_id_textbox,
529
  eval_key_textbox,
530
  quant_vect_textbox,
531
  user_vector_textbox,
532
+ server_response_box,
533
  eval_key_len_textbox,
534
  encrypted_vect_textbox,
535
  *check_boxes,
server.py CHANGED
@@ -9,6 +9,11 @@ from fastapi.responses import JSONResponse, Response
9
 
10
  from concrete.ml.deployment import FHEModelServer
11
 
 
 
 
 
 
12
  # Initialize an instance of FastAPI
13
  app = FastAPI()
14
 
@@ -29,65 +34,65 @@ def send_input(
29
  filter: str = Form(),
30
  files: List[UploadFile] = File(),
31
  ):
 
32
  """Send the inputs to the server."""
33
  # Retrieve the encrypted input image and the evaluation key paths
34
- encrypted_image_path = 0 # Tcurrent_dir("encrypted_image", user_id, filter)
35
- evaluation_key_path = current_dir / ".fhe_keys/{user_id}"
36
 
37
- # Write the files using the above paths
38
- with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open(
39
  "wb"
40
  ) as evaluation_key:
41
- encrypted_image.write(files[0].file.read())
42
  evaluation_key.write(files[1].file.read())
43
 
44
 
45
- @app.post("/run_fhe")
46
- def run_fhe(
47
- user_id: str = Form(),
48
- filter: str = Form(),
49
- ):
50
- """Execute the filter on the encrypted input image using FHE."""
51
- # Retrieve the encrypted input image and the evaluation key paths
52
- encrypted_image_path = get_server_file_path("encrypted_image", user_id, filter)
53
- evaluation_key_path = get_server_file_path("evaluation_key", user_id, filter)
54
 
55
- # Read the files using the above paths
56
- with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
57
- "rb"
58
- ) as evaluation_key_file:
59
- encrypted_image = encrypted_image_file.read()
60
- evaluation_key = evaluation_key_file.read()
 
 
 
61
 
62
- # Load the FHE server
63
- fhe_server = FHEServer(FILTERS_PATH / f"{filter}/deployment")
 
 
 
 
64
 
65
- # Run the FHE execution
66
- start = time.time()
67
- encrypted_output_image = fhe_server.run(encrypted_image, evaluation_key)
68
- fhe_execution_time = round(time.time() - start, 2)
69
 
70
- # Retrieve the encrypted output image path
71
- encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
 
 
72
 
73
- # Write the file using the above path
74
- with encrypted_output_path.open("wb") as encrypted_output:
75
- encrypted_output.write(encrypted_output_image)
76
 
77
- return JSONResponse(content=fhe_execution_time)
 
 
78
 
 
79
 
80
- @app.post("/get_output")
81
- def get_output(
82
- user_id: str = Form(),
83
- filter: str = Form(),
84
- ):
85
- """Retrieve the encrypted output image."""
86
- # Retrieve the encrypted output image path
87
- encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
88
 
89
- # Read the file using the above path
90
- with encrypted_output_path.open("rb") as encrypted_output_file:
91
- encrypted_output = encrypted_output_file.read()
 
 
 
 
 
92
 
93
- return Response(encrypted_output)
 
 
 
9
 
10
  from concrete.ml.deployment import FHEModelServer
11
 
12
+ REPO_DIR = Path(__file__).parent
13
+ KEYS_PATH = REPO_DIR / ".fhe_keys"
14
+ MODEL_PATH = REPO_DIR / "client_folder"
15
+
16
+ SERVER_TMP_PATH = REPO_DIR / "server_tmp"
17
  # Initialize an instance of FastAPI
18
  app = FastAPI()
19
 
 
34
  filter: str = Form(),
35
  files: List[UploadFile] = File(),
36
  ):
37
+
38
  """Send the inputs to the server."""
39
  # Retrieve the encrypted input image and the evaluation key paths
40
+ evaluation_key_path = SERVER_TMP_PATH / f"{user_id}_valuation_key"
41
+ encrypted_input_path = SERVER_TMP_PATH / f"{user_id}_encrypted_symptoms"
42
 
43
+ # # Write the files using the above paths
44
+ with encrypted_input_path.open("wb") as encrypted_input, evaluation_key_path.open(
45
  "wb"
46
  ) as evaluation_key:
47
+ encrypted_input.write(files[0].file.read())
48
  evaluation_key.write(files[1].file.read())
49
 
50
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # @app.post("/run_fhe")
53
+ # def run_fhe(
54
+ # user_id: str = Form(),
55
+ # filter: str = Form(),
56
+ # ):
57
+ # """Execute the filter on the encrypted input image using FHE."""
58
+ # Retrieve the encrypted input image and the evaluation key paths
59
+ # encrypted_image_path = get_server_file_path("encrypted_image", user_id, filter)
60
+ # evaluation_key_path = get_server_file_path("evaluation_key", user_id, filter)
61
 
62
+ # Read the files using the above paths
63
+ # with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
64
+ # "rb"
65
+ # ) as evaluation_key_file:
66
+ # encrypted_image = encrypted_image_file.read()
67
+ # evaluation_key = evaluation_key_file.read()
68
 
69
+ # Load the FHE server
70
+ # fhe_server = FHEServer(FILTERS_PATH / f"{filter}/deployment")
 
 
71
 
72
+ # Run the FHE execution
73
+ # start = time.time()
74
+ # encrypted_output_image = fhe_server.run(encrypted_image, evaluation_key)
75
+ # fhe_execution_time = round(time.time() - start, 2)
76
 
77
+ # Retrieve the encrypted output image path
78
+ # encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
 
79
 
80
+ # Write the file using the above path
81
+ # with encrypted_output_path.open("wb") as encrypted_output:
82
+ # encrypted_output.write(encrypted_output_image)
83
 
84
+ # return JSONResponse(content=fhe_execution_time)
85
 
 
 
 
 
 
 
 
 
86
 
87
+ # @app.post("/get_output")
88
+ # def get_output(
89
+ # user_id: str = Form(),
90
+ # filter: str = Form(),
91
+ # ):
92
+ # """Retrieve the encrypted output image."""
93
+ # Retrieve the encrypted output image path
94
+ # encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
95
 
96
+ # Read the file using the above path
97
+ # with encrypted_output_path.open("rb") as encrypted_output_file:
98
+ # encrypted_output = encrypted_output_file.read()