kcelia commited on
Commit
9b80a97
1 Parent(s): be82820

chore: version 4

Browse files
Files changed (2) hide show
  1. app.py +128 -60
  2. utils.py +1 -1
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import subprocess
2
  import time
3
- from pathlib import Path
4
- from typing import Dict, List, Tuple, Union
5
 
6
  import gradio as gr
7
  import numpy as np
@@ -28,15 +27,59 @@ from concrete.ml.deployment import FHEModelClient
28
  subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
29
  time.sleep(3)
30
 
31
- # pylint: disable=c-extension-no-member
32
- def is_nan(inputs):
 
 
 
 
 
 
 
 
 
 
 
33
  return inputs is None or (inputs is not None and len(inputs) < 1)
34
 
35
 
36
- def get_user_symptoms_from_checkboxgroup(checkbox_symptoms) -> np.array:
 
 
37
 
38
- symptoms_vector = {key: 0 for key in valid_columns}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  for pretty_symptom in checkbox_symptoms:
41
  original_symptom = "_".join((pretty_symptom.lower().split(" ")))
42
  if original_symptom not in symptoms_vector.keys():
@@ -53,20 +96,16 @@ def get_user_symptoms_from_checkboxgroup(checkbox_symptoms) -> np.array:
53
  return user_symptoms_vect
54
 
55
 
56
- def fill_in_fn(default_disease, *checkbox_symptoms):
57
-
58
- df = pd.read_csv(TRAINING_FILENAME)
59
- df_filtred = df[df[TARGET_COLUMNS[1]] == default_disease]
60
- symptoms = pretty_print(df_filtred.columns[df_filtred.eq(1).any()].to_list())
61
-
62
- if any(lst for lst in checkbox_symptoms if lst):
63
- for sublist in checkbox_symptoms:
64
- symptoms.extend(sublist)
65
-
66
- return {box: symptoms for box in check_boxes}
67
 
 
 
68
 
69
- def get_features(*checked_symptoms):
 
 
70
  if not any(lst for lst in checked_symptoms if lst):
71
  return {
72
  error_box1: gr.update(
@@ -118,7 +157,7 @@ def key_gen_fn(user_symptoms: List[str]) -> Dict:
118
  with evaluation_key_path.open("wb") as f:
119
  f.write(serialized_evaluation_keys)
120
 
121
- serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
122
 
123
  return {
124
  error_box2: gr.update(visible=False),
@@ -128,7 +167,14 @@ def key_gen_fn(user_symptoms: List[str]) -> Dict:
128
  }
129
 
130
 
131
- def encrypt_fn(user_symptoms, user_id):
 
 
 
 
 
 
 
132
 
133
  if is_nan(user_id) or is_nan(user_symptoms):
134
  print("Error in encryption step: Provide your symptoms and generate the evaluation keys.")
@@ -164,7 +210,7 @@ def encrypt_fn(user_symptoms, user_id):
164
  }
165
 
166
 
167
- def send_input_fn(user_id, user_symptoms):
168
  """Send the encrypted data and the evaluation key to the server.
169
 
170
  Args:
@@ -215,7 +261,7 @@ def send_input_fn(user_id, user_symptoms):
215
  ("files", open(evaluation_key_path, "rb")),
216
  ]
217
 
218
- # Send the encrypted input image and evaluation key to the server
219
  url = SERVER_URL + "send_input"
220
  with requests.post(
221
  url=url,
@@ -226,12 +272,11 @@ def send_input_fn(user_id, user_symptoms):
226
  return {error_box4: gr.update(visible=False), srv_resp_send_data_box: "Data sent"}
227
 
228
 
229
- def run_fhe_fn(user_id):
230
- """Send the encrypted input image as well as the evaluation key to the server.
231
 
232
  Args:
233
  user_id (int): The current user's ID.
234
- filter_name (str): The current filter to consider.
235
  """
236
  if is_nan(user_id): # or is_nan(user_symptoms):
237
  return {
@@ -246,7 +291,7 @@ def run_fhe_fn(user_id):
246
  "user_id": user_id,
247
  }
248
 
249
- # Trigger the FHE execution on the encrypted image previously sent
250
 
251
  url = SERVER_URL + "run_fhe"
252
 
@@ -268,7 +313,14 @@ def run_fhe_fn(user_id):
268
  }
269
 
270
 
271
- def get_output_fn(user_id, user_symptoms):
 
 
 
 
 
 
 
272
  if is_nan(user_id) or is_nan(user_symptoms):
273
  return {
274
  error_box6: gr.update(
@@ -278,11 +330,13 @@ def get_output_fn(user_id, user_symptoms):
278
  )
279
  }
280
 
 
 
281
  data = {
282
  "user_id": user_id,
283
  }
284
 
285
- # Retrieve the encrypted output image
286
  url = SERVER_URL + "get_output"
287
  with requests.post(
288
  url=url,
@@ -302,7 +356,17 @@ def get_output_fn(user_id, user_symptoms):
302
  return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}
303
 
304
 
305
- def decrypt_fn(user_id, user_symptoms):
 
 
 
 
 
 
 
 
 
 
306
  if is_nan(user_id) or is_nan(user_symptoms):
307
  return {
308
  error_box7: gr.update(
@@ -343,13 +407,14 @@ def decrypt_fn(user_id, user_symptoms):
343
  }
344
 
345
 
 
346
  def clear_all_btn():
347
  """Clear all the box outputs."""
348
 
349
  clean_directory()
350
 
351
  return {
352
- disease_box: None,
353
  user_id_box: None,
354
  user_vect_box1: None,
355
  user_vect_box2: None,
@@ -382,10 +447,12 @@ CSS = """
382
  """
383
 
384
  if __name__ == "__main__":
 
385
  print("Starting demo ...")
 
386
  clean_directory()
387
 
388
- (_, X_train, X_test), (df_test, y_train, y_test) = load_data()
389
 
390
  valid_columns = X_train.columns.to_list()
391
 
@@ -411,7 +478,7 @@ if __name__ == "__main__":
411
  </p>
412
 
413
  <p align="center">
414
- <img width="100%" height="30%" src="https://raw.githubusercontent.com/kcelia/Img/main/HEALTHCARE PREDICTION USING MACHINE LEARNING WITH FULLY HOMOMORPHIC ENCRYPTION.png">
415
  </p>
416
  """
417
  )
@@ -430,8 +497,8 @@ if __name__ == "__main__":
430
  check_boxes = []
431
  for i, category in enumerate(SYMPTOMS_LIST):
432
  with gr.Accordion(
433
- pretty_print(category.keys()), open=True, elem_classes="feedback"
434
- ):
435
  check_box = gr.CheckboxGroup(
436
  pretty_print(category.values()),
437
  label=pretty_print(category.keys()),
@@ -442,31 +509,30 @@ if __name__ == "__main__":
442
  error_box1 = gr.Textbox(label="Error", visible=False)
443
 
444
  # Default disease, picked from the dataframe
445
- disease_box = gr.Dropdown(list(sorted(set(df_test["prognosis"]))), label="Disease:")
446
-
447
- disease_box.change(
448
- fn=fill_in_fn,
449
- inputs=[disease_box, *check_boxes],
450
- outputs=[*check_boxes],
451
- )
452
 
453
  # User symptom vector
454
- with gr.Row():
455
- user_vect_box1 = gr.Textbox(label="User Symptoms Vector:", interactive=False)
456
 
457
- with gr.Row():
458
- # Submit botton
459
- submit_button = gr.Button("Submit")
460
 
461
  with gr.Row():
462
  # Clear botton
463
  clear_button = gr.Button("Reset")
464
 
465
  submit_button.click(
466
- fn=get_features,
467
  inputs=[*check_boxes],
468
  outputs=[user_vect_box1, error_box1],
469
  )
 
470
  with gr.TabItem("2. Data Encryption") as encryption_tab:
471
  gr.Markdown("<span style='color:orange'>Client Side</span>")
472
  gr.Markdown("## Step 2: Generate the keys")
@@ -482,14 +548,13 @@ if __name__ == "__main__":
482
  with gr.Column(scale=1, min_width=600):
483
  key_len_box = gr.Textbox(label="Evaluation Key Size:", interactive=False)
484
 
485
- with gr.Row():
486
- # Evaluation key (truncated)
487
- with gr.Column(scale=2, min_width=600):
488
- key_box = gr.Textbox(
489
- label="Evaluation key (truncated):",
490
- max_lines=2,
491
- interactive=False,
492
- )
493
 
494
  gen_key_btn.click(
495
  key_gen_fn,
@@ -553,7 +618,7 @@ if __name__ == "__main__":
553
  outputs=[error_box4, srv_resp_send_data_box],
554
  )
555
 
556
- with gr.TabItem("3. Processing Data") as fhe_tab:
557
  gr.Markdown("<span style='color:orange'>Client Side</span>")
558
  gr.Markdown("## Step 5: Run the FHE evaluation")
559
 
@@ -569,8 +634,12 @@ if __name__ == "__main__":
569
  outputs=[fhe_execution_time_box, error_box5],
570
  )
571
 
 
 
 
 
572
  gr.Markdown(
573
- "## Step 6: Get the data from the <span style='color:orange'>Server</span>"
574
  )
575
 
576
  error_box6 = gr.Textbox(label="Error", visible=False)
@@ -589,8 +658,7 @@ if __name__ == "__main__":
589
  outputs=[srv_resp_retrieve_data_box, error_box6],
590
  )
591
 
592
- with gr.TabItem("4. Data Decryption") as decryption_tab:
593
- gr.Markdown("<span style='color:orange'>Client Side</span>")
594
  gr.Markdown("## Step 7: Decrypt the output")
595
 
596
  decrypt_target_btn = gr.Button("Decrypt the output")
@@ -608,7 +676,7 @@ if __name__ == "__main__":
608
  outputs=[
609
  user_vect_box1,
610
  user_vect_box2,
611
- disease_box,
612
  error_box1,
613
  error_box2,
614
  error_box3,
 
1
  import subprocess
2
  import time
3
+ from typing import Dict, List, Tuple
 
4
 
5
  import gradio as gr
6
  import numpy as np
 
27
  subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
28
  time.sleep(3)
29
 
30
+ # pylint: disable=c-extension-no-member,invalid-name
31
+
32
+
33
+ def is_nan(inputs) -> bool:
34
+ """
35
+ Check if the input is NaN.
36
+
37
+ Args:
38
+ inputs (any): The input to be checked.
39
+
40
+ Returns:
41
+ bool: True if the input is NaN or empty, False otherwise.
42
+ """
43
  return inputs is None or (inputs is not None and len(inputs) < 1)
44
 
45
 
46
+ # def fill_in_fn(default_disease: str, *checkbox_symptoms: Tuple[str]) -> Dict:
47
+ # """
48
+ # Fill in the gr.CheckBoxGroup list with the predefined symptoms of a selected default disease.
49
 
50
+ # Args:
51
+ # default_disease (str): The default disease
52
+ # *checkbox_symptoms (Tuple[str]): Tuple of selected symptoms
53
+
54
+ # Returns:
55
+ # dict: The updated gr.CheckBoxesGroup.
56
+ # """
57
+ # df = pd.read_csv(TRAINING_FILENAME)
58
+ # df_filtred = df[df[TARGET_COLUMNS[1]] == default_disease]
59
+ # symptoms = pretty_print(df_filtred.columns[df_filtred.eq(1).any()].to_list())
60
+
61
+ # if any(lst for lst in checkbox_symptoms if lst):
62
+ # for sublist in checkbox_symptoms:
63
+ # symptoms.extend(sublist)
64
+
65
+ # return {box: symptoms for box in check_boxes}
66
 
67
+
68
+ def get_user_symptoms_from_checkboxgroup(checkbox_symptoms: List) -> np.array:
69
+ """
70
+ Convert the user symptoms into a binary vector representation.
71
+
72
+ Args:
73
+ checkbox_symptoms (list): A list of user symptoms.
74
+
75
+ Returns:
76
+ np.array: A binary vector representing the user's symptoms.
77
+
78
+ Raises:
79
+ KeyError: If a provided symptom is not recognized as a valid symptom.
80
+
81
+ """
82
+ symptoms_vector = {key: 0 for key in valid_columns}
83
  for pretty_symptom in checkbox_symptoms:
84
  original_symptom = "_".join((pretty_symptom.lower().split(" ")))
85
  if original_symptom not in symptoms_vector.keys():
 
96
  return user_symptoms_vect
97
 
98
 
99
+ def get_features_fn(*checked_symptoms: Tuple[str]) -> Dict:
100
+ """
101
+ Get vector features based on the selected symptoms.
 
 
 
 
 
 
 
 
102
 
103
+ Args:
104
+ checked_symptoms (Tuple[str]): User symptoms
105
 
106
+ Returns:
107
+ Dict: The encoded user vector symptoms.
108
+ """
109
  if not any(lst for lst in checked_symptoms if lst):
110
  return {
111
  error_box1: gr.update(
 
157
  with evaluation_key_path.open("wb") as f:
158
  f.write(serialized_evaluation_keys)
159
 
160
+ serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
161
 
162
  return {
163
  error_box2: gr.update(visible=False),
 
167
  }
168
 
169
 
170
+ def encrypt_fn(user_symptoms: np.ndarray, user_id: str) -> None:
171
+ """
172
+ Encrypt the user symptoms vector in the `Client Side`.
173
+
174
+ Args:
175
+ user_symptoms (List[str]): The vector symptoms provided by the user
176
+ user_id (user): The current user's ID
177
+ """
178
 
179
  if is_nan(user_id) or is_nan(user_symptoms):
180
  print("Error in encryption step: Provide your symptoms and generate the evaluation keys.")
 
210
  }
211
 
212
 
213
+ def send_input_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
214
  """Send the encrypted data and the evaluation key to the server.
215
 
216
  Args:
 
261
  ("files", open(evaluation_key_path, "rb")),
262
  ]
263
 
264
+ # Send the encrypted input and evaluation key to the server
265
  url = SERVER_URL + "send_input"
266
  with requests.post(
267
  url=url,
 
272
  return {error_box4: gr.update(visible=False), srv_resp_send_data_box: "Data sent"}
273
 
274
 
275
+ def run_fhe_fn(user_id: str) -> Dict:
276
+ """Send the encrypted input as well as the evaluation key to the server.
277
 
278
  Args:
279
  user_id (int): The current user's ID.
 
280
  """
281
  if is_nan(user_id): # or is_nan(user_symptoms):
282
  return {
 
291
  "user_id": user_id,
292
  }
293
 
294
+ # Trigger the FHE execution on the encrypted previously sent
295
 
296
  url = SERVER_URL + "run_fhe"
297
 
 
313
  }
314
 
315
 
316
+ def get_output_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
317
+ """Retreive the encrypted data from the server.
318
+
319
+ Args:
320
+ user_id (int): The current user's ID
321
+ user_symptoms (numpy.ndarray): The user symptoms
322
+ """
323
+
324
  if is_nan(user_id) or is_nan(user_symptoms):
325
  return {
326
  error_box6: gr.update(
 
330
  )
331
  }
332
 
333
+
334
+
335
  data = {
336
  "user_id": user_id,
337
  }
338
 
339
+ # Retrieve the encrypted output
340
  url = SERVER_URL + "get_output"
341
  with requests.post(
342
  url=url,
 
356
  return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}
357
 
358
 
359
+ def decrypt_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
360
+ """Dencrypt the data on the `Client Side`.
361
+
362
+ Args:
363
+ user_id (int): The current user's ID
364
+ user_symptoms (numpy.ndarray): The user symptoms
365
+
366
+ Returns:
367
+ Decrypted output
368
+ """
369
+
370
  if is_nan(user_id) or is_nan(user_symptoms):
371
  return {
372
  error_box7: gr.update(
 
407
  }
408
 
409
 
410
+
411
  def clear_all_btn():
412
  """Clear all the box outputs."""
413
 
414
  clean_directory()
415
 
416
  return {
417
+ # disease_box: None,
418
  user_id_box: None,
419
  user_vect_box1: None,
420
  user_vect_box2: None,
 
447
  """
448
 
449
  if __name__ == "__main__":
450
+
451
  print("Starting demo ...")
452
+
453
  clean_directory()
454
 
455
+ (X_train, X_test), (y_train, y_test) = load_data()
456
 
457
  valid_columns = X_train.columns.to_list()
458
 
 
478
  </p>
479
 
480
  <p align="center">
481
+ <img width="100%" height="30%" src="https://raw.githubusercontent.com/kcelia/Img/main/health_prediction_img.png">
482
  </p>
483
  """
484
  )
 
497
  check_boxes = []
498
  for i, category in enumerate(SYMPTOMS_LIST):
499
  with gr.Accordion(
500
+ pretty_print(category.keys()), open=False, elem_classes="feedback"
501
+ ) as accordion:
502
  check_box = gr.CheckboxGroup(
503
  pretty_print(category.values()),
504
  label=pretty_print(category.keys()),
 
509
  error_box1 = gr.Textbox(label="Error", visible=False)
510
 
511
  # Default disease, picked from the dataframe
512
+ # disease_box = gr.Dropdown(list(sorted(set(df_test["prognosis"]))),
513
+ # label="Disease:")
514
+ # disease_box.change(
515
+ # fn=fill_in_fn,
516
+ # inputs=[disease_box, *check_boxes],
517
+ # outputs=[*check_boxes],
518
+ # )
519
 
520
  # User symptom vector
521
+ user_vect_box1 = gr.Textbox(label="User Symptoms Vector:", interactive=False)
 
522
 
523
+ # Submit botton
524
+ submit_button = gr.Button("Submit")
 
525
 
526
  with gr.Row():
527
  # Clear botton
528
  clear_button = gr.Button("Reset")
529
 
530
  submit_button.click(
531
+ fn=get_features_fn,
532
  inputs=[*check_boxes],
533
  outputs=[user_vect_box1, error_box1],
534
  )
535
+
536
  with gr.TabItem("2. Data Encryption") as encryption_tab:
537
  gr.Markdown("<span style='color:orange'>Client Side</span>")
538
  gr.Markdown("## Step 2: Generate the keys")
 
548
  with gr.Column(scale=1, min_width=600):
549
  key_len_box = gr.Textbox(label="Evaluation Key Size:", interactive=False)
550
 
551
+ # Evaluation key (truncated)
552
+ with gr.Column(scale=2, min_width=600):
553
+ key_box = gr.Textbox(
554
+ label="Evaluation key (truncated):",
555
+ max_lines=3,
556
+ interactive=False,
557
+ )
 
558
 
559
  gen_key_btn.click(
560
  key_gen_fn,
 
618
  outputs=[error_box4, srv_resp_send_data_box],
619
  )
620
 
621
+ with gr.TabItem("3. FHE execution") as fhe_tab:
622
  gr.Markdown("<span style='color:orange'>Client Side</span>")
623
  gr.Markdown("## Step 5: Run the FHE evaluation")
624
 
 
634
  outputs=[fhe_execution_time_box, error_box5],
635
  )
636
 
637
+ with gr.TabItem("4. Data Decryption") as decryption_tab:
638
+
639
+ gr.Markdown("<span style='color:orange'>Client Side</span>")
640
+
641
  gr.Markdown(
642
+ "## Step 6: Get the data from the <span style='color:orange'>Server Side</span>"
643
  )
644
 
645
  error_box6 = gr.Textbox(label="Error", visible=False)
 
658
  outputs=[srv_resp_retrieve_data_box, error_box6],
659
  )
660
 
661
+
 
662
  gr.Markdown("## Step 7: Decrypt the output")
663
 
664
  decrypt_target_btn = gr.Button("Decrypt the output")
 
676
  outputs=[
677
  user_vect_box1,
678
  user_vect_box2,
679
+ # disease_box,
680
  error_box1,
681
  error_box2,
682
  error_box3,
utils.py CHANGED
@@ -113,7 +113,7 @@ def load_data() -> Tuple[pandas.DataFrame, pandas.DataFrame, numpy.ndarray]:
113
  y_test = df_test[TARGET_COLUMNS[0]]
114
  X_test = df_test.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
115
 
116
- return (df_train, X_train, X_test), (df_test, y_train, y_test)
117
 
118
 
119
  def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
 
113
  y_test = df_test[TARGET_COLUMNS[0]]
114
  X_test = df_test.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
115
 
116
+ return (X_train, X_test), (y_train, y_test)
117
 
118
 
119
  def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):