zainmushtaq54 commited on
Commit
e462d07
·
verified ·
1 Parent(s): 76ddc61

Upload 10 files

Browse files
Files changed (10) hide show
  1. .gitattributes +4 -35
  2. .gitignore +7 -0
  3. README.md +61 -12
  4. app.py +771 -0
  5. dev.py +52 -0
  6. healthcare_prediction.jpg +0 -0
  7. requirements.txt +4 -0
  8. server.py +100 -0
  9. symptoms_categories.py +197 -0
  10. utils.py +144 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pkl filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
3
+ *.extension filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .ipynb_checkpoints
3
+
4
+ .venv
5
+ deployment_files/.*
6
+ deployment_files/client_dir/
7
+ deployment_files/server_dir/
README.md CHANGED
@@ -1,12 +1,61 @@
1
- ---
2
- title: Health
3
- emoji: 🏢
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.44.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Health Prediction On Encrypted Data Using Fully Homomorphic Encryption
3
+ emoji: 🩺😷
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: true
10
+ tags:
11
+ - FHE
12
+ - PPML
13
+ - privacy
14
+ - privacy preserving machine learning
15
+ - image processing
16
+ - homomorphic encryption
17
+ - security
18
+ python_version: 3.10.6
19
+ ---
20
+
21
+ # Healthcare prediction using FHE
22
+
23
+ ## Running the application on your machine
24
+
25
+ From this directory, i.e., `health_prediction`, you can proceed with the following steps.
26
+
27
+ ### Do once
28
+
29
+ First, create a virtual env and activate it:
30
+
31
+ <!--pytest-codeblocks:skip-->
32
+
33
+ ```bash
34
+ python3 -m venv .venv
35
+ source .venv/bin/activate
36
+ ```
37
+
38
+ Then, install required packages:
39
+
40
+ <!--pytest-codeblocks:skip-->
41
+
42
+ ```bash
43
+ pip3 install pip --upgrade
44
+ pip3 install -U pip wheel setuptools --ignore-installed
45
+ pip3 install -r requirements.txt --ignore-installed
46
+ ```
47
+
48
+ ## Run the following steps each time you relaunch the application
49
+
50
+ In a terminal, run:
51
+
52
+ <!--pytest-codeblocks:skip-->
53
+
54
+ ```bash
55
+ source .venv/bin/activate
56
+ python3 app.py
57
+ ```
58
+
59
+ ## Interacting with the application
60
+
61
+ Open the given URL link (search for a line like `Running on local URL: http://127.0.0.1:8888/`).
app.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import time
3
+ from typing import Dict, List, Tuple
4
+
5
+ import gradio as gr # pylint: disable=import-error
6
+ import numpy as np
7
+ import pandas as pd
8
+ import requests
9
+ from symptoms_categories import SYMPTOMS_LIST
10
+ from utils import (
11
+ CLIENT_DIR,
12
+ CURRENT_DIR,
13
+ DEPLOYMENT_DIR,
14
+ INPUT_BROWSER_LIMIT,
15
+ KEYS_DIR,
16
+ SERVER_URL,
17
+ TARGET_COLUMNS,
18
+ TRAINING_FILENAME,
19
+ clean_directory,
20
+ get_disease_name,
21
+ load_data,
22
+ pretty_print,
23
+ )
24
+
25
+ from concrete.ml.deployment import FHEModelClient
26
+
27
+ subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
28
+ time.sleep(3)
29
+
30
+ # pylint: disable=c-extension-no-member,invalid-name
31
+
32
+
33
+ def is_none(obj) -> bool:
34
+ """
35
+ Check if the object is None.
36
+
37
+ Args:
38
+ obj (any): The input to be checked.
39
+
40
+ Returns:
41
+ bool: True if the object is None or empty, False otherwise.
42
+ """
43
+ return obj is None or (obj is not None and len(obj) < 1)
44
+
45
+
46
+ def display_default_symptoms_fn(default_disease: str) -> Dict:
47
+ """
48
+ Displays the symptoms of a given existing disease.
49
+
50
+ Args:
51
+ default_disease (str): Disease
52
+ Returns:
53
+ Dict: The according symptoms
54
+ """
55
+ df = pd.read_csv(TRAINING_FILENAME)
56
+ df_filtred = df[df[TARGET_COLUMNS[1]] == default_disease]
57
+
58
+ return {
59
+ default_symptoms: gr.update(
60
+ visible=True,
61
+ value=pretty_print(
62
+ df_filtred.columns[df_filtred.eq(1).any()].to_list(), delimiter=", "
63
+ ),
64
+ )
65
+ }
66
+
67
+
68
+ def get_user_symptoms_from_checkboxgroup(checkbox_symptoms: List) -> np.array:
69
+ """
70
+ Convert the user symptoms into a binary vector representation.
71
+
72
+ Args:
73
+ checkbox_symptoms (List): A list of user symptoms.
74
+
75
+ Returns:
76
+ np.array: A binary vector representing the user's symptoms.
77
+
78
+ Raises:
79
+ KeyError: If a provided symptom is not recognized as a valid symptom.
80
+
81
+ """
82
+ symptoms_vector = {key: 0 for key in valid_symptoms}
83
+ for pretty_symptom in checkbox_symptoms:
84
+ original_symptom = "_".join((pretty_symptom.lower().split(" ")))
85
+ if original_symptom not in symptoms_vector.keys():
86
+ raise KeyError(
87
+ f"The symptom '{original_symptom}' you provided is not recognized as a valid "
88
+ f"symptom.\nHere is the list of valid symptoms: {symptoms_vector}"
89
+ )
90
+ symptoms_vector[original_symptom] = 1
91
+
92
+ user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :]
93
+
94
+ assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten())
95
+
96
+ return user_symptoms_vect
97
+
98
+
99
+ def get_features_fn(*checked_symptoms: Tuple[str]) -> Dict:
100
+ """
101
+ Get vector features based on the selected symptoms.
102
+
103
+ Args:
104
+ checked_symptoms (Tuple[str]): User symptoms
105
+
106
+ Returns:
107
+ Dict: The encoded user vector symptoms.
108
+ """
109
+ if not any(lst for lst in checked_symptoms if lst):
110
+ return {
111
+ error_box1: gr.update(visible=True, value="⚠️ Please provide your chief complaints."),
112
+ }
113
+
114
+ if len(pretty_print(checked_symptoms)) < 5:
115
+ print("Provide at least 5 symptoms.")
116
+ return {
117
+ error_box1: gr.update(visible=True, value="⚠️ Provide at least 5 symptoms"),
118
+ one_hot_vect: None,
119
+ }
120
+
121
+ return {
122
+ error_box1: gr.update(visible=False),
123
+ one_hot_vect: gr.update(
124
+ visible=False,
125
+ value=get_user_symptoms_from_checkboxgroup(pretty_print(checked_symptoms)),
126
+ ),
127
+ submit_btn: gr.update(value="Data submitted ✅"),
128
+ }
129
+
130
+
131
+ def key_gen_fn(user_symptoms: List[str]) -> Dict:
132
+ """
133
+ Generate keys for a given user.
134
+
135
+ Args:
136
+ user_symptoms (List[str]): The vector symptoms provided by the user.
137
+
138
+ Returns:
139
+ dict: A dictionary containing the generated keys and related information.
140
+
141
+ """
142
+ clean_directory()
143
+
144
+ if is_none(user_symptoms):
145
+ print("Error: Please submit your symptoms or select a default disease.")
146
+ return {
147
+ error_box2: gr.update(visible=True, value="⚠️ Please submit your symptoms first."),
148
+ }
149
+
150
+ # Generate a random user ID
151
+ user_id = np.random.randint(0, 2**32)
152
+ print(f"Your user ID is: {user_id}....")
153
+
154
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
155
+ client.load()
156
+
157
+ # Creates the private and evaluation keys on the client side
158
+ client.generate_private_and_evaluation_keys()
159
+
160
+ # Get the serialized evaluation keys
161
+ serialized_evaluation_keys = client.get_serialized_evaluation_keys()
162
+ assert isinstance(serialized_evaluation_keys, bytes)
163
+
164
+ # Save the evaluation key
165
+ evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
166
+ with evaluation_key_path.open("wb") as f:
167
+ f.write(serialized_evaluation_keys)
168
+
169
+ serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
170
+
171
+ return {
172
+ error_box2: gr.update(visible=False),
173
+ key_box: gr.update(visible=False, value=serialized_evaluation_keys_shorten_hex),
174
+ user_id_box: gr.update(visible=False, value=user_id),
175
+ key_len_box: gr.update(
176
+ visible=False, value=f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
177
+ ),
178
+ gen_key_btn: gr.update(value="Keys have been generated ✅")
179
+ }
180
+
181
+
182
+ def encrypt_fn(user_symptoms: np.ndarray, user_id: str) -> None:
183
+ """
184
+ Encrypt the user symptoms vector in the `Client Side`.
185
+
186
+ Args:
187
+ user_symptoms (List[str]): The vector symptoms provided by the user
188
+ user_id (user): The current user's ID
189
+ """
190
+
191
+ if is_none(user_id) or is_none(user_symptoms):
192
+ print("Error in encryption step: Provide your symptoms and generate the evaluation keys.")
193
+ return {
194
+ error_box3: gr.update(
195
+ visible=True,
196
+ value="⚠️ Please ensure that your symptoms have been submitted and "
197
+ "that you have generated the evaluation key.",
198
+ )
199
+ }
200
+
201
+ # Retrieve the client API
202
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
203
+ client.load()
204
+
205
+ user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
206
+ # quant_user_symptoms = client.model.quantize_input(user_symptoms)
207
+
208
+ encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
209
+ assert isinstance(encrypted_quantized_user_symptoms, bytes)
210
+ encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"
211
+
212
+ with encrypted_input_path.open("wb") as f:
213
+ f.write(encrypted_quantized_user_symptoms)
214
+
215
+ encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[
216
+ :INPUT_BROWSER_LIMIT
217
+ ]
218
+
219
+ return {
220
+ error_box3: gr.update(visible=False),
221
+ one_hot_vect_box: gr.update(visible=True, value=user_symptoms),
222
+ enc_vect_box: gr.update(visible=True, value=encrypted_quantized_user_symptoms_shorten_hex),
223
+ }
224
+
225
+
226
+ def send_input_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
227
+ """Send the encrypted data and the evaluation key to the server.
228
+
229
+ Args:
230
+ user_id (str): The current user's ID
231
+ user_symptoms (np.ndarray): The user symptoms
232
+ """
233
+
234
+ if is_none(user_id) or is_none(user_symptoms):
235
+ return {
236
+ error_box4: gr.update(
237
+ visible=True,
238
+ value="⚠️ Please check your connectivity \n"
239
+ "⚠️ Ensure that the symptoms have been submitted and the evaluation "
240
+ "key has been generated before sending the data to the server.",
241
+ )
242
+ }
243
+
244
+ evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
245
+ encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"
246
+
247
+ if not evaluation_key_path.is_file():
248
+ print(
249
+ "Error Encountered While Sending Data to the Server: "
250
+ f"The key has been generated correctly - {evaluation_key_path.is_file()=}"
251
+ )
252
+
253
+ return {
254
+ error_box4: gr.update(visible=True, value="⚠️ Please generate the private key first.")
255
+ }
256
+
257
+ if not encrypted_input_path.is_file():
258
+ print(
259
+ "Error Encountered While Sending Data to the Server: The data has not been encrypted "
260
+ f"correctly on the client side - {encrypted_input_path.is_file()=}"
261
+ )
262
+ return {
263
+ error_box4: gr.update(
264
+ visible=True,
265
+ value="⚠️ Please encrypt the data with the private key first.",
266
+ ),
267
+ }
268
+
269
+ # Define the data and files to post
270
+ data = {
271
+ "user_id": user_id,
272
+ "input": user_symptoms,
273
+ }
274
+
275
+ files = [
276
+ ("files", open(encrypted_input_path, "rb")),
277
+ ("files", open(evaluation_key_path, "rb")),
278
+ ]
279
+
280
+ # Send the encrypted input and evaluation key to the server
281
+ url = SERVER_URL + "send_input"
282
+ with requests.post(
283
+ url=url,
284
+ data=data,
285
+ files=files,
286
+ ) as response:
287
+ print(f"Sending Data: {response.ok=}")
288
+ return {
289
+ error_box4: gr.update(visible=False),
290
+ srv_resp_send_data_box: "Data sent",
291
+ }
292
+
293
+
294
+ def run_fhe_fn(user_id: str) -> Dict:
295
+ """Send the encrypted input and the evaluation key to the server.
296
+
297
+ Args:
298
+ user_id (int): The current user's ID.
299
+ """
300
+ if is_none(user_id):
301
+ return {
302
+ error_box5: gr.update(
303
+ visible=True,
304
+ value="⚠️ Please check your connectivity \n"
305
+ "⚠️ Ensure that the symptoms have been submitted, the evaluation "
306
+ "key has been generated and the server received the data "
307
+ "before processing the data.",
308
+ ),
309
+ fhe_execution_time_box: None,
310
+ }
311
+
312
+ data = {
313
+ "user_id": user_id,
314
+ }
315
+
316
+ url = SERVER_URL + "run_fhe"
317
+
318
+ with requests.post(
319
+ url=url,
320
+ data=data,
321
+ ) as response:
322
+ if not response.ok:
323
+ return {
324
+ error_box5: gr.update(
325
+ visible=True,
326
+ value=(
327
+ "⚠️ An error occurred on the Server Side. "
328
+ "Please check connectivity and data transmission."
329
+ ),
330
+ ),
331
+ fhe_execution_time_box: gr.update(visible=False),
332
+ }
333
+ else:
334
+ time.sleep(1)
335
+ print(f"response.ok: {response.ok}, {response.json()} - Computed")
336
+
337
+ return {
338
+ error_box5: gr.update(visible=False),
339
+ fhe_execution_time_box: gr.update(visible=True, value=f"{response.json():.2f} seconds"),
340
+ }
341
+
342
+
343
+ def get_output_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
344
+ """Retreive the encrypted data from the server.
345
+
346
+ Args:
347
+ user_id (str): The current user's ID
348
+ user_symptoms (np.ndarray): The user symptoms
349
+ """
350
+
351
+ if is_none(user_id) or is_none(user_symptoms):
352
+ return {
353
+ error_box6: gr.update(
354
+ visible=True,
355
+ value="⚠️ Please check your connectivity \n"
356
+ "⚠️ Ensure that the server has successfully processed and transmitted the data to the client.",
357
+ )
358
+ }
359
+
360
+ data = {
361
+ "user_id": user_id,
362
+ }
363
+
364
+ # Retrieve the encrypted output
365
+ url = SERVER_URL + "get_output"
366
+ with requests.post(
367
+ url=url,
368
+ data=data,
369
+ ) as response:
370
+ if response.ok:
371
+ print(f"Receive Data: {response.ok=}")
372
+
373
+ encrypted_output = response.content
374
+
375
+ # Save the encrypted output to bytes in a file as it is too large to pass through
376
+ # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
377
+ encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
378
+
379
+ with encrypted_output_path.open("wb") as f:
380
+ f.write(encrypted_output)
381
+ return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}
382
+
383
+
384
+ def decrypt_fn(
385
+ user_id: str, user_symptoms: np.ndarray, *checked_symptoms, threshold: int = 0.5
386
+ ) -> Dict:
387
+ """Dencrypt the data on the `Client Side`.
388
+
389
+ Args:
390
+ user_id (str): The current user's ID
391
+ user_symptoms (np.ndarray): The user symptoms
392
+ threshold (float): Probability confidence threshold
393
+
394
+ Returns:
395
+ Decrypted output
396
+ """
397
+
398
+ if is_none(user_id) or is_none(user_symptoms):
399
+ return {
400
+ error_box7: gr.update(
401
+ visible=True,
402
+ value="⚠️ Please check your connectivity \n"
403
+ "⚠️ Ensure that the client has successfully received the data from the server.",
404
+ )
405
+ }
406
+
407
+ # Get the encrypted output path
408
+ encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
409
+
410
+ if not encrypted_output_path.is_file():
411
+ print("Error in decryption step: Please run the FHE execution, first.")
412
+ return {
413
+ error_box7: gr.update(
414
+ visible=True,
415
+ value="⚠️ Please ensure that: \n"
416
+ "- the connectivity \n"
417
+ "- the symptoms have been submitted \n"
418
+ "- the evaluation key has been generated \n"
419
+ "- the server processed the encrypted data \n"
420
+ "- the Client received the data from the Server before decrypting the prediction",
421
+ ),
422
+ decrypt_box: None,
423
+ }
424
+
425
+ # Load the encrypted output as bytes
426
+ with encrypted_output_path.open("rb") as f:
427
+ encrypted_output = f.read()
428
+
429
+ # Retrieve the client API
430
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
431
+ client.load()
432
+
433
+ # Deserialize, decrypt and post-process the encrypted output
434
+ output = client.deserialize_decrypt_dequantize(encrypted_output)
435
+
436
+ top3_diseases = np.argsort(output.flatten())[-3:][::-1]
437
+ top3_proba = output[0][top3_diseases]
438
+
439
+ out = ""
440
+
441
+ if top3_proba[0] < threshold or abs(top3_proba[0] - top3_proba[1]) < 0.1:
442
+ out = (
443
+ "⚠️ The prediction appears uncertain; including more symptoms "
444
+ "may improve the results.\n\n"
445
+ )
446
+
447
+ out = (
448
+ f"{out}Given the symptoms you provided: "
449
+ f"{pretty_print(checked_symptoms, case_conversion=str.capitalize, delimiter=', ')}\n\n"
450
+ "Here are the top3 predictions:\n\n"
451
+ f"1. « {get_disease_name(top3_diseases[0])} » with a probability of {top3_proba[0]:.2%}\n"
452
+ f"2. « {get_disease_name(top3_diseases[1])} » with a probability of {top3_proba[1]:.2%}\n"
453
+ f"3. « {get_disease_name(top3_diseases[2])} » with a probability of {top3_proba[2]:.2%}\n"
454
+ )
455
+
456
+ return {
457
+ error_box7: gr.update(visible=False),
458
+ decrypt_box: out,
459
+ submit_btn: gr.update(value="Submit"),
460
+ }
461
+
462
+
463
+ def reset_fn():
464
+ """Reset the space and clear all the box outputs."""
465
+
466
+ clean_directory()
467
+
468
+ return {
469
+ one_hot_vect: None,
470
+ one_hot_vect_box: None,
471
+ enc_vect_box: gr.update(visible=True, value=None),
472
+ quant_vect_box: gr.update(visible=False, value=None),
473
+ user_id_box: gr.update(visible=False, value=None),
474
+ default_symptoms: gr.update(visible=True, value=None),
475
+ default_disease_box: gr.update(visible=True, value=None),
476
+ key_box: gr.update(visible=True, value=None),
477
+ key_len_box: gr.update(visible=False, value=None),
478
+ fhe_execution_time_box: gr.update(visible=True, value=None),
479
+ decrypt_box: None,
480
+ submit_btn: gr.update(value="Submit"),
481
+ error_box7: gr.update(visible=False),
482
+ error_box1: gr.update(visible=False),
483
+ error_box2: gr.update(visible=False),
484
+ error_box3: gr.update(visible=False),
485
+ error_box4: gr.update(visible=False),
486
+ error_box5: gr.update(visible=False),
487
+ error_box6: gr.update(visible=False),
488
+ srv_resp_send_data_box: None,
489
+ srv_resp_retrieve_data_box: None,
490
+ **{box: None for box in check_boxes},
491
+ }
492
+
493
+
494
+ if __name__ == "__main__":
495
+
496
+ print("Starting demo ...")
497
+
498
+ clean_directory()
499
+
500
+ (X_train, X_test), (y_train, y_test), valid_symptoms, diseases = load_data()
501
+
502
+ with gr.Blocks() as demo:
503
+
504
+ # Link + images
505
+ gr.Markdown()
506
+ gr.Markdown(
507
+ """
508
+ <p align="center">
509
+ <img width=200 src="https://user-images.githubusercontent.com/5758427/197816413-d9cddad3-ba38-4793-847d-120975e1da11.png">
510
+ </p>
511
+ """)
512
+ gr.Markdown()
513
+ gr.Markdown("""<h2 align="center">Health Prediction On Encrypted Data Using Fully Homomorphic Encryption</h2>""")
514
+ gr.Markdown()
515
+ gr.Markdown(
516
+ """
517
+ <p align="center">
518
+ <a href="https://github.com/zama-ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197972109-faaaff3e-10e2-4ab6-80f5-7531f7cfb08f.png">Concrete-ML</a>
519
+
520
+ <a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197976802-fddd34c5-f59a-48d0-9bff-7ad1b00cb1fb.png">Documentation</a>
521
+
522
+ <a href="https://zama.ai/community"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197977153-8c9c01a7-451a-4993-8e10-5a6ed5343d02.png">Community</a>
523
+
524
+ <a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197975044-bab9d199-e120-433b-b3be-abd73b211a54.png">@zama_fhe</a>
525
+ </p>
526
+ """)
527
+ gr.Markdown()
528
+ gr.Markdown(
529
+ """"
530
+ <p align="center">
531
+ <img width="65%" height="25%" src="https://raw.githubusercontent.com/kcelia/Img/main/healthcare_prediction.jpg">
532
+ </p>
533
+ """
534
+ )
535
+ gr.Markdown("## Notes")
536
+ gr.Markdown(
537
+ """
538
+ - The private key is used to encrypt and decrypt the data and shall never be shared.
539
+ - The evaluation key is a public key that the server needs to process encrypted data.
540
+ """
541
+ )
542
+
543
+ # ------------------------- Step 1 -------------------------
544
+ gr.Markdown("\n")
545
+ gr.Markdown("## Step 1: Select chief complaints")
546
+ gr.Markdown("<hr />")
547
+ gr.Markdown("<span style='color:grey'>Client Side</span>")
548
+ gr.Markdown("Select at least 5 chief complaints from the list below.")
549
+
550
+ # Step 1.1: Provide symptoms
551
+ check_boxes = []
552
+ with gr.Row():
553
+ with gr.Column():
554
+ for category in SYMPTOMS_LIST[:3]:
555
+ with gr.Accordion(pretty_print(category.keys()), open=False):
556
+ check_box = gr.CheckboxGroup(pretty_print(category.values()), show_label=0)
557
+ check_boxes.append(check_box)
558
+ with gr.Column():
559
+ for category in SYMPTOMS_LIST[3:6]:
560
+ with gr.Accordion(pretty_print(category.keys()), open=False):
561
+ check_box = gr.CheckboxGroup(pretty_print(category.values()), show_label=0)
562
+ check_boxes.append(check_box)
563
+ with gr.Column():
564
+ for category in SYMPTOMS_LIST[6:]:
565
+ with gr.Accordion(pretty_print(category.keys()), open=False):
566
+ check_box = gr.CheckboxGroup(pretty_print(category.values()), show_label=0)
567
+ check_boxes.append(check_box)
568
+
569
+ error_box1 = gr.Textbox(label="Error ❌", visible=False)
570
+
571
+ # Default disease, picked from the dataframe
572
+ gr.Markdown(
573
+ "You can choose an **existing disease** and explore its associated symptoms.",
574
+ visible=False,
575
+ )
576
+
577
+ with gr.Row():
578
+ with gr.Column(scale=2):
579
+ default_disease_box = gr.Dropdown(sorted(diseases), label="Diseases", visible=False)
580
+ with gr.Column(scale=5):
581
+ default_symptoms = gr.Textbox(label="Related Symptoms:", visible=False)
582
+ # User vector symptoms encoded in oneHot representation
583
+ one_hot_vect = gr.Textbox(visible=False)
584
+ # Submit botton
585
+ submit_btn = gr.Button("Submit")
586
+ # Clear botton
587
+ clear_button = gr.Button("Reset Space 🔁", visible=False)
588
+
589
+ default_disease_box.change(
590
+ fn=display_default_symptoms_fn, inputs=[default_disease_box], outputs=[default_symptoms]
591
+ )
592
+
593
+ submit_btn.click(
594
+ fn=get_features_fn,
595
+ inputs=[*check_boxes],
596
+ outputs=[one_hot_vect, error_box1, submit_btn],
597
+ )
598
+
599
+ # ------------------------- Step 2 -------------------------
600
+ gr.Markdown("\n")
601
+ gr.Markdown("## Step 2: Encrypt data")
602
+ gr.Markdown("<hr />")
603
+ gr.Markdown("<span style='color:grey'>Client Side</span>")
604
+ # Step 2.1: Key generation
605
+ gr.Markdown(
606
+ "### Key Generation\n\n"
607
+ "In FHE schemes, a secret (enc/dec)ryption keys are generated for encrypting and decrypting data owned by the client. \n\n"
608
+ "Additionally, a public evaluation key is generated, enabling external entities to perform homomorphic operations on encrypted data, without the need to decrypt them. \n\n"
609
+ "The evaluation key will be transmitted to the server for further processing."
610
+ )
611
+
612
+ gen_key_btn = gr.Button("Generate the private and evaluation keys.")
613
+ error_box2 = gr.Textbox(label="Error ❌", visible=False)
614
+ user_id_box = gr.Textbox(label="User ID:", visible=False)
615
+ key_len_box = gr.Textbox(label="Evaluation Key Size:", visible=False)
616
+ key_box = gr.Textbox(label="Evaluation key (truncated):", max_lines=3, visible=False)
617
+
618
+ gen_key_btn.click(
619
+ key_gen_fn,
620
+ inputs=one_hot_vect,
621
+ outputs=[
622
+ key_box,
623
+ user_id_box,
624
+ key_len_box,
625
+ error_box2,
626
+ gen_key_btn,
627
+ ],
628
+ )
629
+
630
+ # Step 2.2: Encrypt data locally
631
+ gr.Markdown("### Encrypt the data")
632
+ encrypt_btn = gr.Button("Encrypt the data using the private secret key")
633
+ error_box3 = gr.Textbox(label="Error ❌", visible=False)
634
+ quant_vect_box = gr.Textbox(label="Quantized Vector:", visible=False)
635
+
636
+ with gr.Row():
637
+ with gr.Column():
638
+ one_hot_vect_box = gr.Textbox(label="User Symptoms Vector:", max_lines=10)
639
+ with gr.Column():
640
+ enc_vect_box = gr.Textbox(label="Encrypted Vector:", max_lines=10)
641
+
642
+ encrypt_btn.click(
643
+ encrypt_fn,
644
+ inputs=[one_hot_vect, user_id_box],
645
+ outputs=[
646
+ one_hot_vect_box,
647
+ enc_vect_box,
648
+ error_box3,
649
+ ],
650
+ )
651
+ # Step 2.3: Send encrypted data to the server
652
+ gr.Markdown(
653
+ "### Send the encrypted data to the <span style='color:grey'>Server Side</span>"
654
+ )
655
+ error_box4 = gr.Textbox(label="Error ❌", visible=False)
656
+
657
+ # with gr.Row().style(equal_height=False):
658
+ with gr.Row():
659
+ with gr.Column(scale=4):
660
+ send_input_btn = gr.Button("Send data")
661
+ with gr.Column(scale=1):
662
+ srv_resp_send_data_box = gr.Checkbox(label="Data Sent", show_label=False)
663
+
664
+ send_input_btn.click(
665
+ send_input_fn,
666
+ inputs=[user_id_box, one_hot_vect],
667
+ outputs=[error_box4, srv_resp_send_data_box],
668
+ )
669
+
670
+ # ------------------------- Step 3 -------------------------
671
+ gr.Markdown("\n")
672
+ gr.Markdown("## Step 3: Run the FHE evaluation")
673
+ gr.Markdown("<hr />")
674
+ gr.Markdown("<span style='color:grey'>Server Side</span>")
675
+ gr.Markdown(
676
+ "Once the server receives the encrypted data, it can process and compute the output without ever decrypting the data just as it would on clear data.\n\n"
677
+ "This server employs a [Logistic Regression](https://github.com/zama-ai/concrete-ml/tree/release/1.1.x/use_case_examples/disease_prediction) model that has been trained on this [data-set](https://github.com/anujdutt9/Disease-Prediction-from-Symptoms/tree/master/dataset)."
678
+ )
679
+
680
+ run_fhe_btn = gr.Button("Run the FHE evaluation")
681
+ error_box5 = gr.Textbox(label="Error ❌", visible=False)
682
+ fhe_execution_time_box = gr.Textbox(label="Total FHE Execution Time:", visible=True)
683
+ run_fhe_btn.click(
684
+ run_fhe_fn,
685
+ inputs=[user_id_box],
686
+ outputs=[fhe_execution_time_box, error_box5],
687
+ )
688
+
689
+ # ------------------------- Step 4 -------------------------
690
+ gr.Markdown("\n")
691
+ gr.Markdown("## Step 4: Decrypt the data")
692
+ gr.Markdown("<hr />")
693
+ gr.Markdown("<span style='color:grey'>Client Side</span>")
694
+ gr.Markdown(
695
+ "### Get the encrypted data from the <span style='color:grey'>Server Side</span>"
696
+ )
697
+
698
+ error_box6 = gr.Textbox(label="Error ❌", visible=False)
699
+
700
+ # Step 4.1: Data transmission
701
+ # with gr.Row().style(equal_height=True):
702
+ with gr.Row():
703
+ with gr.Column(scale=4):
704
+ get_output_btn = gr.Button("Get data")
705
+ with gr.Column(scale=1):
706
+ srv_resp_retrieve_data_box = gr.Checkbox(label="Data Received", show_label=False)
707
+
708
+ get_output_btn.click(
709
+ get_output_fn,
710
+ inputs=[user_id_box, one_hot_vect],
711
+ outputs=[srv_resp_retrieve_data_box, error_box6],
712
+ )
713
+
714
+ # Step 4.1: Data transmission
715
+ gr.Markdown("### Decrypt the output")
716
+ decrypt_btn = gr.Button("Decrypt the output using the private secret key")
717
+ error_box7 = gr.Textbox(label="Error ❌", visible=False)
718
+ decrypt_box = gr.Textbox(label="Decrypted Output:")
719
+
720
+ decrypt_btn.click(
721
+ decrypt_fn,
722
+ inputs=[user_id_box, one_hot_vect, *check_boxes],
723
+ outputs=[decrypt_box, error_box7, submit_btn],
724
+ )
725
+
726
+ # ------------------------- End -------------------------
727
+
728
+ gr.Markdown(
729
+ """The app was built with [Concrete ML](https://github.com/zama-ai/concrete-ml), a Privacy-Preserving Machine Learning (PPML) open-source set of tools by Zama.
730
+ Try it yourself and don't forget to star on [Github](https://github.com/zama-ai/concrete-ml) ⭐.
731
+ """
732
+ )
733
+
734
+ gr.Markdown("\n\n")
735
+
736
+ gr.Markdown(
737
+ """**Please Note**: This space is intended solely for educational and demonstration purposes.
738
+ It should not be considered as a replacement for professional medical counsel, diagnosis, or therapy for any health or related issues.
739
+ Any questions or concerns about your individual health should be addressed to your doctor or another qualified healthcare provider.
740
+ """
741
+ )
742
+
743
+ clear_button.click(
744
+ reset_fn,
745
+ outputs=[
746
+ one_hot_vect_box,
747
+ one_hot_vect,
748
+ submit_btn,
749
+ error_box1,
750
+ error_box2,
751
+ error_box3,
752
+ error_box4,
753
+ error_box5,
754
+ error_box6,
755
+ error_box7,
756
+ default_disease_box,
757
+ default_symptoms,
758
+ user_id_box,
759
+ key_len_box,
760
+ key_box,
761
+ quant_vect_box,
762
+ enc_vect_box,
763
+ srv_resp_send_data_box,
764
+ srv_resp_retrieve_data_box,
765
+ fhe_execution_time_box,
766
+ decrypt_box,
767
+ *check_boxes,
768
+ ],
769
+ )
770
+
771
+ demo.launch()
dev.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generating deployment files."""
2
+
3
+ import shutil
4
+
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+
9
+ from concrete.ml.sklearn import LogisticRegression as ConcreteLogisticRegression
10
+ from concrete.ml.deployment import FHEModelDev
11
+
12
+
13
+ # Data files location
14
+ TRAINING_FILE_NAME = "./data/Training_preprocessed.csv"
15
+ TESTING_FILE_NAME = "./data/Testing_preprocessed.csv"
16
+
17
+ # Load data
18
+ df_train = pd.read_csv(TRAINING_FILE_NAME)
19
+ df_test = pd.read_csv(TESTING_FILE_NAME)
20
+
21
+ # Split the data into X_train, y_train, X_test_, y_test sets
22
+ TARGET_COLUMN = ["prognosis_encoded", "prognosis"]
23
+
24
+ y_train = df_train[TARGET_COLUMN[0]].values.flatten()
25
+ y_test = df_test[TARGET_COLUMN[0]].values.flatten()
26
+
27
+ X_train = df_train.drop(TARGET_COLUMN, axis=1)
28
+ X_test = df_test.drop(TARGET_COLUMN, axis=1)
29
+
30
+ # Concrete ML model
31
+
32
+ # Models parameters
33
+ optimal_param = {"C": 0.9, "n_bits": 13, "solver": "sag", "multi_class": "auto"}
34
+
35
+ clf = ConcreteLogisticRegression(**optimal_param)
36
+
37
+ # Fit the model
38
+ clf.fit(X_train, y_train)
39
+
40
+ # Compile the model
41
+ fhe_circuit = clf.compile(X_train)
42
+
43
+ fhe_circuit.client.keygen(force=False)
44
+
45
+ path_to_model = Path("./deployment_files/").resolve()
46
+
47
+ if path_to_model.exists():
48
+ shutil.rmtree(path_to_model)
49
+
50
+ dev = FHEModelDev(path_to_model, clf)
51
+
52
+ dev.save(via_mlir=True)
healthcare_prediction.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ concrete-ml==1.4.0
2
+ gradio
3
+ uvicorn>=0.21.0
4
+ fastapi>=0.93.0
server.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Server that will listen for GET and POST requests from the client."""
2
+
3
+ import time
4
+ from typing import List
5
+
6
+ from fastapi import FastAPI, File, Form, UploadFile
7
+ from fastapi.responses import JSONResponse, Response
8
+ from utils import DEPLOYMENT_DIR, SERVER_DIR # pylint: disable=no-name-in-module
9
+
10
+ from concrete.ml.deployment import FHEModelServer
11
+
12
+ # Load the FHE server
13
+ FHE_SERVER = FHEModelServer(DEPLOYMENT_DIR)
14
+
15
+ # Initialize an instance of FastAPI
16
+ app = FastAPI()
17
+
18
+ # Define the default route
19
+ @app.get("/")
20
+ def root():
21
+ """
22
+ Root endpoint of the health prediction API.
23
+
24
+ Returns:
25
+ dict: The welcome message.
26
+ """
27
+ return {"message": "Welcome to your disease prediction with FHE!"}
28
+
29
+
30
+ @app.post("/send_input")
31
+ def send_input(
32
+ user_id: str = Form(),
33
+ files: List[UploadFile] = File(),
34
+ ):
35
+ """Send the inputs to the server."""
36
+
37
+ print("\nSend the data to the server ............\n")
38
+
39
+ # Receive the Client's files (Evaluation key + Encrypted symptoms)
40
+ evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
41
+ encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"
42
+
43
+ # Save the files using the above paths
44
+ with encrypted_input_path.open("wb") as encrypted_input, evaluation_key_path.open(
45
+ "wb"
46
+ ) as evaluation_key:
47
+ encrypted_input.write(files[0].file.read())
48
+ evaluation_key.write(files[1].file.read())
49
+
50
+
51
+ @app.post("/run_fhe")
52
+ def run_fhe(
53
+ user_id: str = Form(),
54
+ ):
55
+ """Inference in FHE."""
56
+
57
+ print("\nRun in FHE in the server ............\n")
58
+ evaluation_key_path = SERVER_DIR / f"{user_id}_valuation_key"
59
+ encrypted_input_path = SERVER_DIR / f"{user_id}_encrypted_input"
60
+
61
+ # Read the files (Evaluation key + Encrypted symptoms) using the above paths
62
+ with encrypted_input_path.open("rb") as encrypted_output_file, evaluation_key_path.open(
63
+ "rb"
64
+ ) as evaluation_key_file:
65
+ encrypted_output = encrypted_output_file.read()
66
+ evaluation_key = evaluation_key_file.read()
67
+
68
+ # Run the FHE execution
69
+ start = time.time()
70
+ encrypted_output = FHE_SERVER.run(encrypted_output, evaluation_key)
71
+ assert isinstance(encrypted_output, bytes)
72
+ fhe_execution_time = round(time.time() - start, 2)
73
+
74
+ # Retrieve the encrypted output path
75
+ encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
76
+
77
+ # Write the file using the above path
78
+ with encrypted_output_path.open("wb") as f:
79
+ f.write(encrypted_output)
80
+
81
+ return JSONResponse(content=fhe_execution_time)
82
+
83
+
84
+ @app.post("/get_output")
85
+ def get_output(user_id: str = Form()):
86
+ """Retrieve the encrypted output from the server."""
87
+
88
+ print("\nGet the output from the server ............\n")
89
+
90
+ # Path where the encrypted output is saved
91
+ encrypted_output_path = SERVER_DIR / f"{user_id}_encrypted_output"
92
+
93
+ # Read the file using the above path
94
+ with encrypted_output_path.open("rb") as f:
95
+ encrypted_output = f.read()
96
+
97
+ time.sleep(1)
98
+
99
+ # Send the encrypted output
100
+ return Response(encrypted_output)
symptoms_categories.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In this file, we roughly split up a list of symptoms, taken from "./training.csv" file, avalaible
3
+ through: "https://github.com/anujdutt9/Disease-Prediction-from-Symptoms/tree/master/dataset"
4
+ into medical categories, in order to make the UI more plesant for the users.
5
+
6
+ Each variable contains a list of symptoms sthat can be pecific to a part of the body or to a list
7
+ of similar symptoms.
8
+ """
9
+
10
+
11
+ DIGESTIVE_SYSTEM_SYMPTOMS = {
12
+ "DIGESTIVE_SYSTEM_CONCERNS": [
13
+ "stomach_pain",
14
+ "acidity",
15
+ "vomiting",
16
+ "indigestion",
17
+ "constipation",
18
+ "abdominal_pain",
19
+ "diarrhea",
20
+ "nausea",
21
+ "distention_of_abdomen",
22
+ "stomach_bleeding",
23
+ "pain_during_bowel_movements",
24
+ "passage_of_gases",
25
+ "red_spots_over_body",
26
+ "swelling_of_stomach",
27
+ "bloody_stool",
28
+ "irritation_in_anus",
29
+ "pain_in_anal_region",
30
+ "abnormal_menstruation",
31
+ ]
32
+ }
33
+
34
+ DERMATOLOGICAL_SYMPTOMS = {
35
+ "DERMATOLOGICAL_CONCERNS": [
36
+ "itching",
37
+ "skin_rash",
38
+ "pus_filled_pimples",
39
+ "blackheads",
40
+ "scurving",
41
+ "skin_peeling",
42
+ "silver_like_dusting",
43
+ "small_dents_in_nails",
44
+ "inflammatory_nails",
45
+ "blister",
46
+ "red_sore_around_nose",
47
+ "bruising",
48
+ "yellow_crust_ooze",
49
+ "dischromic_patches",
50
+ "nodal_skin_eruptions",
51
+ "toxic_look_(typhus)",
52
+ "brittle_nails",
53
+ "yellowish_skin",
54
+ ]
55
+ }
56
+
57
+ ORL_SYMPTOMS = {
58
+ "ORL_CONCERNS": [
59
+ "loss_of_smell",
60
+ "continuous_sneezing",
61
+ "runny_nose",
62
+ "patches_in_throat",
63
+ "throat_irritation",
64
+ "sinus_pressure",
65
+ "enlarged_thyroid",
66
+ "loss_of_balance",
67
+ "unsteadiness",
68
+ "dizziness",
69
+ "spinning_movements",
70
+ ]
71
+ }
72
+
73
+ THORAX_SYMPTOMS = {
74
+ "THORAX_CONCERNS": [
75
+ "breathlessness",
76
+ "chest_pain",
77
+ "cough",
78
+ "rusty_sputum",
79
+ "phlegm",
80
+ "mucoid_sputum",
81
+ "congestion",
82
+ "blood_in_sputum",
83
+ "fast_heart_rate",
84
+ ]
85
+ }
86
+
87
+ OPHTHALMOLOGICAL_SYMPTOMS = {
88
+ "OPHTHALMOLOGICAL_CONCERNS": [
89
+ "sunken_eyes",
90
+ "redness_of_eyes",
91
+ "watering_from_eyes",
92
+ "blurred_and_distorted_vision",
93
+ "pain_behind_the_eyes",
94
+ "visual_disturbances",
95
+ ]
96
+ }
97
+
98
+ VASCULAR_LYMPHATIC_SYMPTOMS = {
99
+ "VASCULAR_AND_LYMPHATIC_CONCERNS": [
100
+ "cold_hands_and_feets",
101
+ "swollen_blood_vessels",
102
+ "swollen_legs",
103
+ "swelled_lymph_nodes",
104
+ "palpitations",
105
+ "prominent_veins_on_calf",
106
+ "yellowing_of_eyes",
107
+ "puffy_face_and_eyes",
108
+ "severe_fluid_overload",
109
+ "swollen_extremeties",
110
+ ]
111
+ }
112
+
113
+ UROLOGICAL_SYMPTOMS = {
114
+ "UROLOGICAL_CONCERNS": [
115
+ "burning_micturition",
116
+ "spotting_urination",
117
+ "yellow_urine",
118
+ "bladder_discomfort",
119
+ "foul_smell_of_urine",
120
+ "continuous_feel_of_urine",
121
+ "polyuria",
122
+ "dark_urine",
123
+ ]
124
+ }
125
+
126
+ MUSCULOSKELETAL_SYMPTOMS = {
127
+ "MUSCULOSKELETAL_CONCERNS": [
128
+ "joint_pain",
129
+ "muscle_wasting",
130
+ "muscle_pain",
131
+ "muscle_weakness",
132
+ "knee_pain",
133
+ "stiff_neck",
134
+ "swelling_joints",
135
+ "movement_stiffness",
136
+ "hip_joint_pain",
137
+ "painful_walking",
138
+ "weakness_of_one_body_side",
139
+ "neck_pain",
140
+ "back_pain",
141
+ "weakness_in_limbs",
142
+ "cramps",
143
+ ]
144
+ }
145
+
146
+ GENERAL_SYMPTOMS = {
147
+ "GENERAL_CONCERNS": [
148
+ "acute_liver_failure",
149
+ "anxiety",
150
+ "restlessness",
151
+ "lethargy",
152
+ "mood_swings",
153
+ "irritability",
154
+ "lack_of_concentration",
155
+ "fatigue",
156
+ "malaise",
157
+ "weight_gain",
158
+ "increased_appetite",
159
+ "weight_loss",
160
+ "loss_of_appetite",
161
+ "excess_body_fat",
162
+ "excessive_hunger",
163
+ "ulcers_on_tongue",
164
+ "shivering",
165
+ "chills",
166
+ "irregular_sugar_level",
167
+ "high_fever",
168
+ "slurred_speech",
169
+ "sweating",
170
+ "internal_itching",
171
+ "mild_fever",
172
+ "dehydration",
173
+ "headache",
174
+ "frequent_unprotected_sexual_intercourse_with_multiple_partners",
175
+ "drying_and_tingling_lips",
176
+ "altered_sensorium",
177
+ "family_history",
178
+ "receiving_blood_transfusion",
179
+ "receiving_unsterile_injections",
180
+ "chronic_alcohol_abuse",
181
+ ]
182
+ }
183
+
184
+ SYMPTOMS_LIST = [
185
+ # Column 1
186
+ DIGESTIVE_SYSTEM_SYMPTOMS,
187
+ UROLOGICAL_SYMPTOMS,
188
+ VASCULAR_LYMPHATIC_SYMPTOMS,
189
+ # Column 2
190
+ ORL_SYMPTOMS,
191
+ DERMATOLOGICAL_SYMPTOMS,
192
+ MUSCULOSKELETAL_SYMPTOMS,
193
+ # Column 3
194
+ OPHTHALMOLOGICAL_SYMPTOMS,
195
+ THORAX_SYMPTOMS,
196
+ GENERAL_SYMPTOMS,
197
+ ]
utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+ from typing import List, Tuple, Union
5
+
6
+ import numpy
7
+ import pandas
8
+
9
+ from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
10
+
11
+ # Max Input to be displayed on the HuggingFace space brower using Gradio
12
+ # Too large inputs, slow down the server: https://github.com/gradio-app/gradio/issues/1877
13
+ INPUT_BROWSER_LIMIT = 380
14
+
15
+ # Store the server's URL
16
+ SERVER_URL = "http://localhost:8000/"
17
+
18
+ CURRENT_DIR = Path(__file__).parent
19
+ DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files"
20
+ KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
21
+ CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
22
+ SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
23
+
24
+ ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
25
+
26
+ # Columns that define the target
27
+ TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]
28
+
29
+ TRAINING_FILENAME = "./data/Training_preprocessed.csv"
30
+ TESTING_FILENAME = "./data/Testing_preprocessed.csv"
31
+
32
+ # pylint: disable=invalid-name
33
+
34
+
35
+ def pretty_print(
36
+ inputs, case_conversion=str.title, which_replace: str = "_", to_what: str = " ", delimiter=None
37
+ ):
38
+ """
39
+ Prettify and sort the input as a list of string.
40
+
41
+ Args:
42
+ inputs (Any): The inputs to be prettified.
43
+
44
+ Returns:
45
+ List: The prettified and sorted list of inputs.
46
+
47
+ """
48
+ # Flatten the list if required
49
+ pretty_list = []
50
+ for item in inputs:
51
+ if isinstance(item, list):
52
+ pretty_list.extend(item)
53
+ else:
54
+ pretty_list.append(item)
55
+
56
+ # Sort
57
+ pretty_list = sorted(list(set(pretty_list)))
58
+ # Replace
59
+ pretty_list = [item.replace(which_replace, to_what) for item in pretty_list]
60
+ pretty_list = [case_conversion(item) for item in pretty_list]
61
+ if delimiter:
62
+ pretty_list = f"{delimiter.join(pretty_list)}."
63
+
64
+ return pretty_list
65
+
66
+
67
+ def clean_directory() -> None:
68
+ """
69
+ Clear direcgtories
70
+ """
71
+ print("Cleaning...\n")
72
+ for target_dir in ALL_DIRS:
73
+ if os.path.exists(target_dir) and os.path.isdir(target_dir):
74
+ shutil.rmtree(target_dir)
75
+ target_dir.mkdir(exist_ok=True, parents=True)
76
+
77
+
78
+ def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
79
+ """Return the disease name given its encoded label.
80
+
81
+ Args:
82
+ encoded_prediction (int): The encoded prediction
83
+ file_name (str): The data file path
84
+
85
+ Returns:
86
+ str: The according disease name
87
+ """
88
+ df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS).drop_duplicates()
89
+ disease_name, _ = df[df[TARGET_COLUMNS[0]] == encoded_prediction].values.flatten()
90
+ return disease_name
91
+
92
+
93
+ def load_data() -> Union[Tuple[pandas.DataFrame, numpy.ndarray], List]:
94
+ """
95
+ Return the data
96
+
97
+ Args:
98
+ None
99
+
100
+ Return:
101
+ The train, testing set and valid symptoms.
102
+ """
103
+ # Load data
104
+ df_train = pandas.read_csv(TRAINING_FILENAME)
105
+ df_test = pandas.read_csv(TESTING_FILENAME)
106
+
107
+ # Separate the traget from the training / testing set:
108
+ # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
109
+ # TARGET_COLUMNS[1] -> "prognosis" -> contains the name of the disease
110
+
111
+ y_train = df_train[TARGET_COLUMNS[0]]
112
+ X_train = df_train.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
113
+
114
+ y_test = df_test[TARGET_COLUMNS[0]]
115
+ X_test = df_test.drop(columns=TARGET_COLUMNS, axis=1, errors="ignore")
116
+
117
+ return (
118
+ (X_train, X_test),
119
+ (y_train, y_test),
120
+ X_train.columns.to_list(),
121
+ df_train[TARGET_COLUMNS[1]].unique().tolist(),
122
+ )
123
+
124
+
125
+ def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
126
+ """
127
+ Load a pre-trained serialized model
128
+
129
+ Args:
130
+ X_train (pandas.DataFrame): Training set
131
+ y_train (numpy.ndarray): Targets of the training set
132
+
133
+ Return:
134
+ The Concrete ML model and its circuit
135
+ """
136
+ # Parameters
137
+ concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
138
+ classifier = ConcreteXGBoostClassifier(**concrete_args)
139
+ # Train the model
140
+ classifier.fit(X_train, y_train)
141
+ # Compile the model
142
+ circuit = classifier.compile(X_train)
143
+
144
+ return classifier, circuit