import subprocess
import time
from typing import Dict, List, Tuple

import gradio as gr  # pylint: disable=import-error
import numpy as np
import pandas as pd
import requests
from symptoms_categories import SYMPTOMS_LIST
from utils import (
    CLIENT_DIR,
    CURRENT_DIR,
    DEPLOYMENT_DIR,
    INPUT_BROWSER_LIMIT,
    KEYS_DIR,
    SERVER_URL,
    TARGET_COLUMNS,
    TRAINING_FILENAME,
    clean_directory,
    get_disease_name,
    load_data,
    pretty_print,
)

from concrete.ml.deployment import FHEModelClient

# Launch the FHE server (`server:app`) as a background uvicorn process at import time, then
# pause for 3 seconds, presumably so that it accepts requests before the UI starts.
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
time.sleep(3)

# pylint: disable=c-extension-no-member,invalid-name

# NOTE: the callbacks below read the Gradio components (`error_box1`, `one_hot_vector`, ...) and
# `valid_symptoms` as module-level globals. Those names are only assigned in the `__main__`
# block, so the callbacks only work when this file is run as a script.


def is_none(obj) -> bool:
    """
    Check whether the object is None or empty.

    Args:
        obj (any): The input to be checked. When not None, it must support `len()`.

    Returns:
        bool: True if the object is None or empty, False otherwise.
    """
    return obj is None or len(obj) < 1


def display_default_symptoms_fn(default_disease: str) -> Dict:
    """
    Display the symptoms associated with an existing disease of the training set.

    Args:
        default_disease (str): Disease name, as found in the `TARGET_COLUMNS[1]` column of the
            training CSV.

    Returns:
        Dict: A Gradio update that shows the symptoms of `default_disease`.
    """
    df = pd.read_csv(TRAINING_FILENAME)
    df_filtered = df[df[TARGET_COLUMNS[1]] == default_disease]

    # Keep every column that equals 1 in at least one row of the selected disease
    related_symptoms = df_filtered.columns[df_filtered.eq(1).any()].to_list()

    return {
        default_symptoms: gr.update(
            visible=True,
            value=pretty_print(related_symptoms, delimiter=", "),
        )
    }


def get_user_symptoms_from_checkboxgroup(checkbox_symptoms: List) -> np.ndarray:
    """
    Convert the user symptoms into a binary vector representation.

    Args:
        checkbox_symptoms (List): A list of user symptoms, in their displayed ("pretty") form.

    Returns:
        np.ndarray: A (1, n) float vector with 1 for each selected symptom and 0 elsewhere,
            ordered like `valid_symptoms`.

    Raises:
        KeyError: If a provided symptom is not recognized as a valid symptom.
    """
    # `valid_symptoms` is the module-level global assigned from `load_data()` in `__main__`
    symptoms_vector = {key: 0 for key in valid_symptoms}

    for pretty_symptom in checkbox_symptoms:
        # Undo the display formatting, e.g. "Skin rash" -> "skin_rash"
        original_symptom = "_".join(pretty_symptom.lower().split(" "))
        if original_symptom not in symptoms_vector:
            raise KeyError(
                f"The symptom '{original_symptom}' you provided is not recognized as a valid "
                f"symptom.\nHere is the list of valid symptoms: {list(symptoms_vector)}"
            )
        symptoms_vector[original_symptom] = 1

    user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :]

    assert all(value in (0, 1) for value in user_symptoms_vect.flatten())

    return user_symptoms_vect


def get_features_fn(*checked_symptoms: Tuple[str]) -> Dict:
    """
    Get vector features based on the selected symptoms.

    Args:
        checked_symptoms (Tuple[str]): The selections of every symptom CheckboxGroup, one
            argument per group.

    Returns:
        Dict: Gradio updates for the error box and, on success, the encoded user vector and the
            submit button.
    """
    # Every argument is the selection of one CheckboxGroup; all empty means nothing was ticked
    if not any(checked_symptoms):
        return {
            error_box1: gr.update(visible=True, value="⚠️ Please provide your chief complaints."),
        }

    # Flatten the per-group selections once: used both for validation and encoding
    selected_symptoms = pretty_print(checked_symptoms)

    if len(selected_symptoms) < 5:
        print("Provide at least 5 symptoms.")
        return {
            error_box1: gr.update(visible=True, value="⚠️ Provide at least 5 symptoms"),
            one_hot_vector: None,
        }

    return {
        error_box1: gr.update(visible=False),
        one_hot_vector: gr.update(
            visible=False,
            value=get_user_symptoms_from_checkboxgroup(selected_symptoms),
        ),
        submit_btn: gr.update(value="Data submitted ✅"),
    }


def key_gen_fn(user_symptoms: List[str]) -> Dict:
    """
    Generate the private and evaluation keys for a new, randomly identified user.

    The serialized evaluation key is written to `KEYS_DIR/<user_id>/evaluation_key` so that it
    can later be sent to the server.

    Args:
        user_symptoms (List[str]): The encoded user symptoms (value of the hidden one-hot
            Textbox). Only used to check that symptoms were submitted.

    Returns:
        dict: Gradio updates for the truncated key, the user ID, the key size and the error box.
    """
    clean_directory()

    if is_none(user_symptoms):
        print("Error: Please submit your symptoms or select a default disease.")
        return {
            error_box2: gr.update(visible=True, value="⚠️ Please submit your symptoms first."),
        }

    # Generate a random user ID; it also names the per-user key directory
    user_id = np.random.randint(0, 2**32)
    print(f"Your user ID is: {user_id}....")

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Create the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key so that `send_input_fn` can upload it
    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    with evaluation_key_path.open("wb") as f:
        f.write(serialized_evaluation_keys)

    # Only a prefix of the hex dump is displayed, to stay within the browser input limit
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]

    return {
        error_box2: gr.update(visible=False),
        key_box: gr.update(visible=False, value=serialized_evaluation_keys_shorten_hex),
        user_id_box: gr.update(visible=True, value=user_id),
        key_len_box: gr.update(
            visible=False, value=f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
        ),
    }


def encrypt_fn(user_symptoms: np.ndarray, user_id: str) -> Dict:
    """
    Quantize, encrypt and serialize the user symptoms vector on the `Client Side`.

    The ciphertext is written to `KEYS_DIR/<user_id>/encrypted_input`.

    Args:
        user_symptoms (np.ndarray): The encoded user symptoms. In practice this is presumably the
            text of the hidden one-hot Textbox, which is parsed back into a (1, n) int array.
        user_id (str): The current user's ID

    Returns:
        Dict: Gradio updates for the error box, the clear vector and the truncated ciphertext.
    """
    if is_none(user_id) or is_none(user_symptoms):
        print("Error in encryption step: Provide your symptoms and generate the evaluation keys.")
        return {
            error_box3: gr.update(
                visible=True,
                value="⚠️ Please ensure that your symptoms have been submitted and "
                "that you have generated the evaluation key.",
            )
        }

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Recover the integer vector from its printed form, e.g. "[[0. 1. 0.]]": drop the two
    # bracket characters on each side and use "." as the separator between values.
    # NOTE(review): assumes the printed form lists every value (no "..." summarization).
    user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)

    encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
    assert isinstance(encrypted_quantized_user_symptoms, bytes)

    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"
    with encrypted_input_path.open("wb") as f:
        f.write(encrypted_quantized_user_symptoms)

    encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[
        :INPUT_BROWSER_LIMIT
    ]

    return {
        error_box3: gr.update(visible=False),
        one_hot_vector_box: gr.update(visible=True, value=user_symptoms),
        enc_vect_box: gr.update(visible=True, value=encrypted_quantized_user_symptoms_shorten_hex),
    }


def send_input_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
    """
    Send the encrypted data and the evaluation key to the server.

    Args:
        user_id (str): The current user's ID
        user_symptoms (np.ndarray): The user symptoms

    Returns:
        Dict: Gradio updates for the error box and, on success, the "Data Sent" checkbox.
    """
    if is_none(user_id) or is_none(user_symptoms):
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the symptoms have been submitted and the evaluation "
                "key has been generated before sending the data to the server.",
            )
        }

    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input"

    if not evaluation_key_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: "
            f"The key has not been generated correctly - {evaluation_key_path.is_file()=}"
        )
        return {
            error_box4: gr.update(visible=True, value="⚠️ Please generate the private key first.")
        }

    if not encrypted_input_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: The data has not been encrypted "
            f"correctly on the client side - {encrypted_input_path.is_file()=}"
        )
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ Please encrypt the data with the private key first.",
            ),
        }

    # Define the data to post
    data = {
        "user_id": user_id,
        "input": user_symptoms,
    }

    # Send the encrypted input and evaluation key to the server. The files are opened in a
    # `with` statement so that their handles are closed once the upload is done.
    url = SERVER_URL + "send_input"
    with open(encrypted_input_path, "rb") as encrypted_input_file, open(
        evaluation_key_path, "rb"
    ) as evaluation_key_file:
        files = [
            ("files", encrypted_input_file),
            ("files", evaluation_key_file),
        ]
        with requests.post(url=url, data=data, files=files) as response:
            print(f"Sending Data: {response.ok=}")
            is_sent = response.ok

    # Only report success when the server actually accepted the upload
    if not is_sent:
        return {
            error_box4: gr.update(
                visible=True,
                value=(
                    "⚠️ An error occurred on the Server Side. "
                    "Please check connectivity and data transmission."
                ),
            ),
        }

    return {
        error_box4: gr.update(visible=False),
        srv_resp_send_data_box: "Data sent",
    }


def run_fhe_fn(user_id: str) -> Dict:
    """
    Ask the server to run the FHE evaluation on the data previously sent for this user.

    Args:
        user_id (str): The current user's ID.

    Returns:
        Dict: Gradio updates for the error box and the execution-time box, which displays the
            number returned by the server formatted as seconds.
    """
    if is_none(user_id):
        return {
            error_box5: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the symptoms have been submitted, the evaluation "
                "key has been generated and the server received the data "
                "before processing the data.",
            ),
            fhe_execution_time_box: None,
        }

    data = {
        "user_id": user_id,
    }

    url = SERVER_URL + "run_fhe"
    with requests.post(url=url, data=data) as response:
        if not response.ok:
            return {
                error_box5: gr.update(
                    visible=True,
                    value=(
                        "⚠️ An error occurred on the Server Side. "
                        "Please check connectivity and data transmission."
                    ),
                ),
                fhe_execution_time_box: gr.update(visible=False),
            }

        # Parse the JSON body once; the server replies with the execution time
        fhe_execution_time = response.json()
        print(f"response.ok: {response.ok}, {fhe_execution_time} - Computed")

    return {
        error_box5: gr.update(visible=False),
        fhe_execution_time_box: gr.update(visible=True, value=f"{fhe_execution_time:.2f} seconds"),
    }


def get_output_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
    """
    Retrieve the encrypted output from the server and store it on the client side.

    The ciphertext is written to `CLIENT_DIR/<user_id>_encrypted_output`.

    Args:
        user_id (str): The current user's ID
        user_symptoms (np.ndarray): The user symptoms (only checked for presence)

    Returns:
        Dict: Gradio updates for the error box and, on success, the "Data Received" checkbox.
    """
    connectivity_error = (
        "⚠️ Please check your connectivity \n"
        "⚠️ Ensure that the server has successfully processed and transmitted the data to the client."
    )

    if is_none(user_id) or is_none(user_symptoms):
        return {error_box6: gr.update(visible=True, value=connectivity_error)}

    data = {
        "user_id": user_id,
    }

    # Retrieve the encrypted output
    url = SERVER_URL + "get_output"
    with requests.post(url=url, data=data) as response:
        if not response.ok:
            # Do not claim the data was received when the server request failed
            return {error_box6: gr.update(visible=True, value=connectivity_error)}

        print(f"Receive Data: {response.ok=}")
        encrypted_output = response.content

    # Save the encrypted output to bytes in a file as it is too large to pass through
    # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
    with encrypted_output_path.open("wb") as f:
        f.write(encrypted_output)

    return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}


def decrypt_fn(
    user_id: str, user_symptoms: np.ndarray, *checked_symptoms, threshold: float = 0.5
) -> Dict:
    """
    Decrypt the output on the `Client Side` and format the top-3 predictions.

    Args:
        user_id (str): The current user's ID
        user_symptoms (np.ndarray): The user symptoms (only checked for presence)
        checked_symptoms: The selections of every symptom CheckboxGroup, echoed in the output
        threshold (float): Probability confidence threshold used to flag uncertain predictions

    Returns:
        Dict: Gradio updates for the error box, the decrypted output and the submit button.
    """
    if is_none(user_id) or is_none(user_symptoms):
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the client has successfully received the data from the server.",
            )
        }

    # Get the encrypted output path
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"

    if not encrypted_output_path.is_file():
        print("Error in decryption step: Please run the FHE execution, first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please ensure that: \n"
                "- the connectivity \n"
                "- the symptoms have been submitted \n"
                "- the evaluation key has been generated \n"
                "- the server processed the encrypted data \n"
                "- the Client received the data from the Server before decrypting the prediction",
            ),
            decrypt_box: None,
        }

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as f:
        encrypted_output = f.read()

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Deserialize, decrypt and post-process the encrypted output
    output = client.deserialize_decrypt_dequantize(encrypted_output)

    # Indices of the three highest scores, in decreasing order
    top3_diseases = np.argsort(output.flatten())[-3:][::-1]
    top3_proba = output[0][top3_diseases]

    # Flag the prediction as uncertain when the best score is below the threshold, the
    # top-3 scores sum below it, or the two best scores are closer than it.
    is_uncertain = (
        (top3_proba[0] < threshold)
        or (np.sum(top3_proba) < threshold)
        or (abs(top3_proba[0] - top3_proba[1]) < threshold)
    )
    warning = (
        "⚠️ The prediction appears uncertain; including more symptoms may improve the results.\n\n"
        if is_uncertain
        else ""
    )

    symptoms_text = pretty_print(checked_symptoms, case_conversion=str.capitalize, delimiter=", ")

    out = (
        f"{warning}"
        f"Given the symptoms you provided: {symptoms_text}\n\n"
        "Here are the top3 predictions:\n\n"
        f"1. « {get_disease_name(top3_diseases[0])} » with a probability of {top3_proba[0]:.2%}\n"
        f"2. « {get_disease_name(top3_diseases[1])} » with a probability of {top3_proba[1]:.2%}\n"
        f"3. « {get_disease_name(top3_diseases[2])} » with a probability of {top3_proba[2]:.2%}\n"
    )

    return {
        error_box7: gr.update(visible=False),
        decrypt_box: out,
        submit_btn: gr.update(value="Submit"),
    }


def reset_fn():
    """Reset the space: clean the working directories and clear every output box."""
    clean_directory()

    return {
        one_hot_vector_box: None,
        submit_btn: gr.update(value="Submit"),
        user_id_box: gr.update(visible=False, value=None),
        one_hot_vector: None,
        default_symptoms: gr.update(visible=True, value=None),
        disease_box: gr.update(visible=True, value=None),
        quant_vect_box: gr.update(visible=False, value=None),
        enc_vect_box: gr.update(visible=True, value=None),
        key_box: gr.update(visible=True, value=None),
        key_len_box: gr.update(visible=False, value=None),
        fhe_execution_time_box: gr.update(visible=True, value=None),
        decrypt_box: None,
        error_box7: gr.update(visible=False),
        error_box1: gr.update(visible=False),
        error_box2: gr.update(visible=False),
        error_box3: gr.update(visible=False),
        error_box4: gr.update(visible=False),
        error_box5: gr.update(visible=False),
        error_box6: gr.update(visible=False),
        srv_resp_send_data_box: None,
        srv_resp_retrieve_data_box: None,
        **{box: None for box in check_boxes},
    }


if __name__ == "__main__":

    print("Starting demo ...")

    clean_directory()

    (X_train, X_test), (y_train, y_test), valid_symptoms, diseases = load_data()

    with gr.Blocks() as demo:

        # Link + images
        # NOTE(review): no links or images remain in this header, and the "\n" separators
        # under each step title look like stripped HTML; restore them if intended.
        gr.Markdown(
            """

Health Prediction On Encrypted Data Using Fully Homomorphic Encryption.

Concrete-ML Documentation Community @zama_fhe

"""
        )

        gr.Markdown("## Notes")
        gr.Markdown(
            """
- The private key is used to encrypt and decrypt the data and shall never be shared.
- The evaluation key is a public key that the server needs to process encrypted data.
"""
        )

        # ------------------------- Step 1 -------------------------
        gr.Markdown("\n")
        gr.Markdown("## Step 1: Select chief complaints")
        gr.Markdown("\n")
        gr.Markdown("Client Side")
        gr.Markdown("Select at least 5 chief complaints from the list below.")

        # Step 1.1: Provide symptoms
        # The symptom categories are spread over three columns; `check_boxes` keeps one
        # CheckboxGroup per category, in `SYMPTOMS_LIST` order.
        check_boxes = []
        with gr.Row():
            for column_categories in (SYMPTOMS_LIST[:3], SYMPTOMS_LIST[3:6], SYMPTOMS_LIST[6:]):
                with gr.Column():
                    for category in column_categories:
                        with gr.Accordion(pretty_print(category.keys()), open=False):
                            check_box = gr.CheckboxGroup(
                                pretty_print(category.values()), show_label=0
                            )
                            check_boxes.append(check_box)

        error_box1 = gr.Textbox(label="Error ❌", visible=False)

        # Default disease, picked from the dataframe
        gr.Markdown(
            "You can choose an **existing disease** and explore its associated symptoms.",
            visible=False,
        )

        with gr.Row():
            with gr.Column(scale=2):
                disease_box = gr.Dropdown(sorted(diseases), label="Diseases", visible=False)
            with gr.Column(scale=5):
                default_symptoms = gr.Textbox(label="Related Symptoms:", visible=False)

        # User vector symptoms encoded in oneHot representation
        one_hot_vector = gr.Textbox(visible=False)

        # Submit button
        submit_btn = gr.Button("Submit")

        # Clear button
        clear_button = gr.Button("Reset Space 🔁", visible=False)

        disease_box.change(
            fn=display_default_symptoms_fn, inputs=[disease_box], outputs=[default_symptoms]
        )

        submit_btn.click(
            fn=get_features_fn,
            inputs=[*check_boxes],
            outputs=[one_hot_vector, error_box1, submit_btn],
        )

        # ------------------------- Step 2 -------------------------
        gr.Markdown("\n")
        gr.Markdown("## Step 2: Encrypt data")
        gr.Markdown("\n")
        gr.Markdown("Client Side")

        # Step 2.1: Key generation
        gr.Markdown(
            "### Key Generation\n\n"
            "In FHE schemes, a secret (enc/dec)ryption keys are generated for encrypting and decrypting data owned by the client. \n\n"
            "Additionally, a public evaluation key is generated, enabling external entities to perform homomorphic operations on encrypted data, without the need to decrypt them. \n\n"
            "The evaluation key will be transmitted to the server for further processing."
        )

        gen_key_btn = gr.Button("Generate the evaluation key")
        error_box2 = gr.Textbox(label="Error ❌", visible=False)
        user_id_box = gr.Textbox(label="User ID:", visible=True)
        key_len_box = gr.Textbox(label="Evaluation Key Size:", visible=False)
        key_box = gr.Textbox(label="Evaluation key (truncated):", max_lines=3, visible=False)

        gen_key_btn.click(
            key_gen_fn,
            inputs=one_hot_vector,
            outputs=[
                key_box,
                user_id_box,
                key_len_box,
                error_box2,
            ],
        )

        # Step 2.2: Encrypt data locally
        gr.Markdown("### Encrypt the data")
        encrypt_btn = gr.Button("Encrypt the data using the private secret key")
        error_box3 = gr.Textbox(label="Error ❌", visible=False)
        quant_vect_box = gr.Textbox(label="Quantized Vector:", visible=False)

        with gr.Row():
            with gr.Column():
                one_hot_vector_box = gr.Textbox(label="User Symptoms Vector:", max_lines=10)
            with gr.Column():
                enc_vect_box = gr.Textbox(label="Encrypted Vector:", max_lines=10)

        encrypt_btn.click(
            encrypt_fn,
            inputs=[one_hot_vector, user_id_box],
            outputs=[
                one_hot_vector_box,
                enc_vect_box,
                error_box3,
            ],
        )

        # Step 2.3: Send encrypted data to the server
        gr.Markdown("### Send the encrypted data to the Server Side")
        error_box4 = gr.Textbox(label="Error ❌", visible=False)

        with gr.Row().style(equal_height=False):
            with gr.Column(scale=4):
                send_input_btn = gr.Button("Send data")
            with gr.Column(scale=1):
                srv_resp_send_data_box = gr.Checkbox(label="Data Sent", show_label=False)

        send_input_btn.click(
            send_input_fn,
            inputs=[user_id_box, one_hot_vector],
            outputs=[error_box4, srv_resp_send_data_box],
        )

        # ------------------------- Step 3 -------------------------
        gr.Markdown("\n")
        gr.Markdown("## Step 3: Run the FHE evaluation")
        gr.Markdown("\n")
        gr.Markdown("Server Side")
        gr.Markdown(
            "Once the server receives the encrypted data, it can process and compute the output without ever decrypting the data just as it would on clear data.\n\n"
            "This server employs a [Logistic Regression]() model that has been trained on this [data-set](https://github.com/anujdutt9/Disease-Prediction-from-Symptoms/tree/master/dataset)."
        )

        run_fhe_btn = gr.Button("Run the FHE evaluation")
        error_box5 = gr.Textbox(label="Error ❌", visible=False)
        fhe_execution_time_box = gr.Textbox(label="Total FHE Execution Time:", visible=True)

        run_fhe_btn.click(
            run_fhe_fn,
            inputs=[user_id_box],
            outputs=[fhe_execution_time_box, error_box5],
        )

        # ------------------------- Step 4 -------------------------
        gr.Markdown("\n")
        gr.Markdown("## Step 4: Decrypt the data")
        gr.Markdown("\n")
        gr.Markdown("Client Side")
        gr.Markdown("### Get the encrypted data from the Server Side")

        error_box6 = gr.Textbox(label="Error ❌", visible=False)

        # Step 4.1: Data transmission
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=4):
                get_output_btn = gr.Button("Get data")
            with gr.Column(scale=1):
                srv_resp_retrieve_data_box = gr.Checkbox(label="Data Received", show_label=False)

        get_output_btn.click(
            get_output_fn,
            inputs=[user_id_box, one_hot_vector],
            outputs=[srv_resp_retrieve_data_box, error_box6],
        )

        # Step 4.2: Decrypt the output
        gr.Markdown("### Decrypt the output")
        decrypt_btn = gr.Button("Decrypt the output using the private secret key")
        error_box7 = gr.Textbox(label="Error ❌", visible=False)
        decrypt_box = gr.Textbox(label="Decrypted Output:")

        decrypt_btn.click(
            decrypt_fn,
            inputs=[user_id_box, one_hot_vector, *check_boxes],
            outputs=[decrypt_box, error_box7, submit_btn],
        )

        # ------------------------- End -------------------------
        gr.Markdown(
            "The app was built with [Concrete ML](https://github.com/zama-ai/concrete-ml), "
            "a Privacy-Preserving Machine Learning (PPML) open-source set of tools by Zama. "
            "Try it yourself and don't forget to star on "
            "[Github](https://github.com/zama-ai/concrete-ml) ⭐. "
        )

        gr.Markdown("\n\n")

        gr.Markdown(
            "**Please Note**: This space is intended solely for educational and demonstration "
            "purposes. It should not be considered as a replacement for professional medical "
            "counsel, diagnosis, or therapy for any health or related issues. Any questions or "
            "concerns about your individual health should be addressed to your doctor or another "
            "qualified healthcare provider.\n"
        )

        clear_button.click(
            reset_fn,
            outputs=[
                one_hot_vector_box,
                one_hot_vector,
                submit_btn,
                error_box1,
                error_box2,
                error_box3,
                error_box4,
                error_box5,
                error_box6,
                error_box7,
                disease_box,
                default_symptoms,
                user_id_box,
                key_len_box,
                key_box,
                quant_vect_box,
                enc_vect_box,
                srv_resp_send_data_box,
                srv_resp_retrieve_data_box,
                fhe_execution_time_box,
                decrypt_box,
                *check_boxes,
            ],
        )

    demo.launch()