import subprocess import time from typing import Dict, List, Tuple import gradio as gr # pylint: disable=import-error import numpy as np import pandas as pd import requests from symptoms_categories import SYMPTOMS_LIST from utils import ( CLIENT_DIR, CURRENT_DIR, DEPLOYMENT_DIR, INPUT_BROWSER_LIMIT, KEYS_DIR, SERVER_URL, TARGET_COLUMNS, TRAINING_FILENAME, clean_directory, get_disease_name, load_data, pretty_print, ) from concrete.ml.deployment import FHEModelClient subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR) time.sleep(3) # pylint: disable=c-extension-no-member,invalid-name def is_none(obj) -> bool: """ Check if the object is None. Args: obj (any): The input to be checked. Returns: bool: True if the object is None or empty, False otherwise. """ return obj is None or (obj is not None and len(obj) < 1) def display_default_symptoms_fn(default_disease: str) -> Dict: """ Displays the symptoms of a given existing disease. Args: default_disease (str): Disease Returns: Dict: The according symptoms """ df = pd.read_csv(TRAINING_FILENAME) df_filtred = df[df[TARGET_COLUMNS[1]] == default_disease] return { default_symptoms: gr.update( visible=True, value=pretty_print( df_filtred.columns[df_filtred.eq(1).any()].to_list(), delimiter=", " ), ) } def get_user_symptoms_from_checkboxgroup(checkbox_symptoms: List) -> np.array: """ Convert the user symptoms into a binary vector representation. Args: checkbox_symptoms (List): A list of user symptoms. Returns: np.array: A binary vector representing the user's symptoms. Raises: KeyError: If a provided symptom is not recognized as a valid symptom. """ symptoms_vector = {key: 0 for key in valid_symptoms} for pretty_symptom in checkbox_symptoms: original_symptom = "_".join((pretty_symptom.lower().split(" "))) if original_symptom not in symptoms_vector.keys(): raise KeyError( f"The symptom '{original_symptom}' you provided is not recognized as a valid " f"symptom.\nHere is the list of valid symptoms: {symptoms_vector}" ) symptoms_vector[original_symptom] = 1 user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :] assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten()) return user_symptoms_vect def get_features_fn(*checked_symptoms: Tuple[str]) -> Dict: """ Get vector features based on the selected symptoms. Args: checked_symptoms (Tuple[str]): User symptoms Returns: Dict: The encoded user vector symptoms. """ if not any(lst for lst in checked_symptoms if lst): return { error_box1: gr.update(visible=True, value="⚠️ Please provide your chief complaints."), } if len(pretty_print(checked_symptoms)) < 5: print("Provide at least 5 symptoms.") return { error_box1: gr.update(visible=True, value="⚠️ Provide at least 5 symptoms"), one_hot_vect: None, } return { error_box1: gr.update(visible=False), one_hot_vect: gr.update( visible=False, value=get_user_symptoms_from_checkboxgroup(pretty_print(checked_symptoms)), ), submit_btn: gr.update(value="Data submitted ✅"), } def key_gen_fn(user_symptoms: List[str]) -> Dict: """ Generate keys for a given user. Args: user_symptoms (List[str]): The vector symptoms provided by the user. Returns: dict: A dictionary containing the generated keys and related information. """ clean_directory() if is_none(user_symptoms): print("Error: Please submit your symptoms or select a default disease.") return { error_box2: gr.update(visible=True, value="⚠️ Please submit your symptoms first."), } # Generate a random user ID user_id = np.random.randint(0, 2**32) print(f"Your user ID is: {user_id}....") client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}") client.load() # Creates the private and evaluation keys on the client side client.generate_private_and_evaluation_keys() # Get the serialized evaluation keys serialized_evaluation_keys = client.get_serialized_evaluation_keys() assert isinstance(serialized_evaluation_keys, bytes) # Save the evaluation key evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key" with evaluation_key_path.open("wb") as f: f.write(serialized_evaluation_keys) serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT] return { error_box2: gr.update(visible=False), key_box: gr.update(visible=False, value=serialized_evaluation_keys_shorten_hex), user_id_box: gr.update(visible=True, value=user_id), key_len_box: gr.update( visible=False, value=f"{len(serialized_evaluation_keys) / (10**6):.2f} MB" ), } def encrypt_fn(user_symptoms: np.ndarray, user_id: str) -> None: """ Encrypt the user symptoms vector in the `Client Side`. Args: user_symptoms (List[str]): The vector symptoms provided by the user user_id (user): The current user's ID """ if is_none(user_id) or is_none(user_symptoms): print("Error in encryption step: Provide your symptoms and generate the evaluation keys.") return { error_box3: gr.update( visible=True, value="⚠️ Please ensure that your symptoms have been submitted and " "that you have generated the evaluation key.", ) } # Retrieve the client API client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}") client.load() user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1) # quant_user_symptoms = client.model.quantize_input(user_symptoms) encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms) assert isinstance(encrypted_quantized_user_symptoms, bytes) encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input" with encrypted_input_path.open("wb") as f: f.write(encrypted_quantized_user_symptoms) encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[ :INPUT_BROWSER_LIMIT ] return { error_box3: gr.update(visible=False), one_hot_vect_box: gr.update(visible=True, value=user_symptoms), enc_vect_box: gr.update(visible=True, value=encrypted_quantized_user_symptoms_shorten_hex), } def send_input_fn(user_id: str, user_symptoms: np.ndarray) -> Dict: """Send the encrypted data and the evaluation key to the server. Args: user_id (str): The current user's ID user_symptoms (np.ndarray): The user symptoms """ if is_none(user_id) or is_none(user_symptoms): return { error_box4: gr.update( visible=True, value="⚠️ Please check your connectivity \n" "⚠️ Ensure that the symptoms have been submitted and the evaluation " "key has been generated before sending the data to the server.", ) } evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key" encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_input" if not evaluation_key_path.is_file(): print( "Error Encountered While Sending Data to the Server: " f"The key has been generated correctly - {evaluation_key_path.is_file()=}" ) return { error_box4: gr.update(visible=True, value="⚠️ Please generate the private key first.") } if not encrypted_input_path.is_file(): print( "Error Encountered While Sending Data to the Server: The data has not been encrypted " f"correctly on the client side - {encrypted_input_path.is_file()=}" ) return { error_box4: gr.update( visible=True, value="⚠️ Please encrypt the data with the private key first.", ), } # Define the data and files to post data = { "user_id": user_id, "input": user_symptoms, } files = [ ("files", open(encrypted_input_path, "rb")), ("files", open(evaluation_key_path, "rb")), ] # Send the encrypted input and evaluation key to the server url = SERVER_URL + "send_input" with requests.post( url=url, data=data, files=files, ) as response: print(f"Sending Data: {response.ok=}") return { error_box4: gr.update(visible=False), srv_resp_send_data_box: "Data sent", } def run_fhe_fn(user_id: str) -> Dict: """Send the encrypted input and the evaluation key to the server. Args: user_id (int): The current user's ID. """ if is_none(user_id): return { error_box5: gr.update( visible=True, value="⚠️ Please check your connectivity \n" "⚠️ Ensure that the symptoms have been submitted, the evaluation " "key has been generated and the server received the data " "before processing the data.", ), fhe_execution_time_box: None, } data = { "user_id": user_id, } url = SERVER_URL + "run_fhe" with requests.post( url=url, data=data, ) as response: if not response.ok: return { error_box5: gr.update( visible=True, value=( "⚠️ An error occurred on the Server Side. " "Please check connectivity and data transmission." ), ), fhe_execution_time_box: gr.update(visible=False), } else: time.sleep(1) print(f"response.ok: {response.ok}, {response.json()} - Computed") return { error_box5: gr.update(visible=False), fhe_execution_time_box: gr.update(visible=True, value=f"{response.json():.2f} seconds"), } def get_output_fn(user_id: str, user_symptoms: np.ndarray) -> Dict: """Retreive the encrypted data from the server. Args: user_id (str): The current user's ID user_symptoms (np.ndarray): The user symptoms """ if is_none(user_id) or is_none(user_symptoms): return { error_box6: gr.update( visible=True, value="⚠️ Please check your connectivity \n" "⚠️ Ensure that the server has successfully processed and transmitted the data to the client.", ) } data = { "user_id": user_id, } # Retrieve the encrypted output url = SERVER_URL + "get_output" with requests.post( url=url, data=data, ) as response: if response.ok: print(f"Receive Data: {response.ok=}") encrypted_output = response.content # Save the encrypted output to bytes in a file as it is too large to pass through # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877) encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output" with encrypted_output_path.open("wb") as f: f.write(encrypted_output) return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"} def decrypt_fn( user_id: str, user_symptoms: np.ndarray, *checked_symptoms, threshold: int = 0.5 ) -> Dict: """Dencrypt the data on the `Client Side`. Args: user_id (str): The current user's ID user_symptoms (np.ndarray): The user symptoms threshold (float): Probability confidence threshold Returns: Decrypted output """ if is_none(user_id) or is_none(user_symptoms): return { error_box7: gr.update( visible=True, value="⚠️ Please check your connectivity \n" "⚠️ Ensure that the client has successfully received the data from the server.", ) } # Get the encrypted output path encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output" if not encrypted_output_path.is_file(): print("Error in decryption step: Please run the FHE execution, first.") return { error_box7: gr.update( visible=True, value="⚠️ Please ensure that: \n" "- the connectivity \n" "- the symptoms have been submitted \n" "- the evaluation key has been generated \n" "- the server processed the encrypted data \n" "- the Client received the data from the Server before decrypting the prediction", ), decrypt_box: None, } # Load the encrypted output as bytes with encrypted_output_path.open("rb") as f: encrypted_output = f.read() # Retrieve the client API client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}") client.load() # Deserialize, decrypt and post-process the encrypted output output = client.deserialize_decrypt_dequantize(encrypted_output) top3_diseases = np.argsort(output.flatten())[-3:][::-1] top3_proba = output[0][top3_diseases] out = "" if top3_proba[0] < threshold or abs(top3_proba[0] - top3_proba[1]) < 0.1: out = ( "⚠️ The prediction appears uncertain; including more symptoms " "may improve the results.\n\n" ) out = ( f"{out}Given the symptoms you provided: " f"{pretty_print(checked_symptoms, case_conversion=str.capitalize, delimiter=', ')}\n\n" "Here are the top3 predictions:\n\n" f"1. « {get_disease_name(top3_diseases[0])} » with a probability of {top3_proba[0]:.2%}\n" f"2. « {get_disease_name(top3_diseases[1])} » with a probability of {top3_proba[1]:.2%}\n" f"3. « {get_disease_name(top3_diseases[2])} » with a probability of {top3_proba[2]:.2%}\n" ) return { error_box7: gr.update(visible=False), decrypt_box: out, submit_btn: gr.update(value="Submit"), } def reset_fn(): """Reset the space and clear all the box outputs.""" clean_directory() return { one_hot_vect: None, one_hot_vect_box: None, enc_vect_box: gr.update(visible=True, value=None), quant_vect_box: gr.update(visible=False, value=None), user_id_box: gr.update(visible=False, value=None), default_symptoms: gr.update(visible=True, value=None), default_disease_box: gr.update(visible=True, value=None), key_box: gr.update(visible=True, value=None), key_len_box: gr.update(visible=False, value=None), fhe_execution_time_box: gr.update(visible=True, value=None), decrypt_box: None, submit_btn: gr.update(value="Submit"), error_box7: gr.update(visible=False), error_box1: gr.update(visible=False), error_box2: gr.update(visible=False), error_box3: gr.update(visible=False), error_box4: gr.update(visible=False), error_box5: gr.update(visible=False), error_box6: gr.update(visible=False), srv_resp_send_data_box: None, srv_resp_retrieve_data_box: None, **{box: None for box in check_boxes}, } if __name__ == "__main__": print("Starting demo ...") clean_directory() (X_train, X_test), (y_train, y_test), valid_symptoms, diseases = load_data() with gr.Blocks() as demo: # Link + images gr.Markdown( """
Concrete-ML — Documentation — Community — @zama_fhe
""" ) gr.Markdown("## Notes") gr.Markdown( """ - The private key is used to encrypt and decrypt the data and shall never be shared. - The evaluation key is a public key that the server needs to process encrypted data. """ ) # ------------------------- Step 1 ------------------------- gr.Markdown("\n") gr.Markdown("## Step 1: Select chief complaints") gr.Markdown("