"""A local gradio app that predicts credit card approval using FHE."""

import os
import shutil
import subprocess
import time
from itertools import chain

import gradio as gr
import numpy
import requests

from settings import (
    REPO_DIR,
    SERVER_URL,
    FHE_KEYS,
    CLIENT_FILES,
    SERVER_FILES,
    DEPLOYMENT_PATH,
    INITIAL_INPUT_SHAPE,
    INPUT_INDEXES,
    START_POSITIONS,
)

from development.client_server_interface import MultiInputsFHEModelClient


# Start the FastAPI server used for the FHE computations in a subprocess
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
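# Give the server a few seconds to start before sending any request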
time.sleep(3)


def shorten_bytes_object(bytes_object, limit=500):
    """Shorten the input bytes object to a given length.

    Encrypted data is too large to display in the browser using Gradio. This function
    provides a shortened representation of it.

    Args:
        bytes_object (bytes): The input to shorten.
        limit (int): The length to consider. Defaults to 500.

    Returns:
        str: A shortened hexadecimal string representation of the input bytes object.
    """
    shift = 100
    return bytes_object[shift : limit + shift].hex()


def get_client(client_id, client_type):
    """Get the client API.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        MultiInputsFHEModelClient: The client API.
    """
    key_dir = FHE_KEYS / f"{client_type}_{client_id}"

    return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir)


def get_client_file_path(name, client_id, client_type):
    """Get the correct temporary file path for the client.

    Args:
        name (str): The desired file name (e.g. 'evaluation_key', 'encrypted_inputs' or
            'encrypted_output').
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        pathlib.Path: The file path.
    """
    return CLIENT_FILES / f"{name}_{client_type}_{client_id}"


def clean_temporary_files(n_keys=20):
    """Clean the keys and the encrypted files.

    A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
    limit is reached, the oldest files are deleted.

    Args:
        n_keys (int): The maximum number of keys and associated files to store. Defaults to 20.
    """
    key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)

    user_ids = []
    if len(key_dirs) > n_keys:
        n_keys_to_delete = len(key_dirs) - n_keys
        for key_dir in key_dirs[:n_keys_to_delete]:
            user_ids.append(key_dir.name)
            shutil.rmtree(key_dir)
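    # Delete the client and server files associated with the deleted keys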
    client_files = CLIENT_FILES.iterdir()
    server_files = SERVER_FILES.iterdir()

    for file in chain(client_files, server_files):
        for user_id in user_ids:
            if user_id in file.name:
                file.unlink()


def keygen(client_id, client_type):
    """Generate the private and evaluation keys for a given client.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
    """
    # Clean the temporary files left by previous clients
    clean_temporary_files()

    # Retrieve the client API
    client = get_client(client_id, client_type)

    # Generate the private key as well as the associated evaluation keys
    client.generate_private_and_evaluation_keys(force=True)
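    # Retrieve the serialized evaluation keys so they can be shared with the server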
    evaluation_key = client.get_serialized_evaluation_keys()

    evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)

    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(evaluation_key)


def send_input(client_id, client_type):
    """Send the encrypted inputs as well as the evaluation key to the server.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
    """
    # Get the paths of the evaluation key and the encrypted inputs
    evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
    encrypted_input_path = get_client_file_path("encrypted_inputs", client_id, client_type)

    # Define the data to send to the server
    data = {
        "client_id": client_id,
        "client_type": client_type,
    }
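    # Define the files to send: the encrypted inputs and the evaluation key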
    files = [
        ("files", open(encrypted_input_path, "rb")),
        ("files", open(evaluation_key_path, "rb")),
    ]

    url = SERVER_URL + "send_input"
    with requests.post(
        url=url,
        data=data,
        files=files,
    ) as response:
        return response.ok


def keygen_encrypt_send(inputs, client_type):
    """Generate the keys, encrypt the given inputs and send them to the server.

    Args:
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        str: A shortened hexadecimal representation of the encrypted inputs.
    """
    # Generate a random ID for the current client
    client_id = numpy.random.randint(0, 2**32)

    # Generate the private and evaluation keys
    keygen(client_id, client_type)

    # Retrieve the client API
    client = get_client(client_id, client_type)
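    # Quantize, encrypt and serialize the inputs at the position associated with this client type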
    encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        initial_input_shape=INITIAL_INPUT_SHAPE,
        start_position=START_POSITIONS[client_type],
    )
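    # Save the encrypted inputs on disk so they can be sent to the server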
    encrypted_inputs_path = get_client_file_path("encrypted_inputs", client_id, client_type)

    with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
        encrypted_inputs_file.write(encrypted_inputs)

    encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
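    # Send the encrypted inputs as well as the evaluation key to the server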
    send_input(client_id, client_type)

    return encrypted_inputs_short


def run_fhe(client_id):
    """Run the model on the encrypted inputs previously sent using FHE.

    Args:
        client_id (int): The client ID to consider.
    """
    data = {
        "client_id": client_id,
    }
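    # Trigger the FHE execution on the server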
    url = SERVER_URL + "run_fhe"
    with requests.post(
        url=url,
        data=data,
    ) as response:
        if response.ok:
            return response.json()
        else:
            raise gr.Error("Please wait for the inputs to be sent to the server.")


def get_output(client_id):
    """Retrieve the encrypted output from the server and save it on disk.

    Args:
        client_id (int): The client ID to consider.
    """
    data = {
        "client_id": client_id,
    }

    # Retrieve the encrypted output from the server
    url = SERVER_URL + "get_output"
    with requests.post(
        url=url,
        data=data,
    ) as response:
        if response.ok:
            encrypted_output = response.content
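            # Save the encrypted output on disk so that the client can decrypt it later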
            encrypted_output_path = get_client_file_path("encrypted_output", client_id, "user")

            with encrypted_output_path.open("wb") as encrypted_output_file:
                encrypted_output_file.write(encrypted_output)

            return None
        else:
            raise gr.Error("Please wait for the FHE execution to be completed.")


def decrypt_output(client_id, client_type):
    """Decrypt the result.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        output (numpy.ndarray): The decrypted output.
    """
    # Get the path where the encrypted output is saved
    encrypted_output_path = get_client_file_path("encrypted_output", client_id, client_type)

    if not encrypted_output_path.is_file():
        raise gr.Error("Please run the FHE execution first.")

    # Load the encrypted output
    with encrypted_output_path.open("rb") as encrypted_output_file:
        encrypted_output_proba = encrypted_output_file.read()

    # Retrieve the client API
    client = get_client(client_id, client_type)
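    # Deserialize, decrypt and post-process the encrypted output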
    output_proba = client.deserialize_decrypt_post_process(encrypted_output_proba)

    # Determine the predicted class from the output probabilities
    output = numpy.argmax(output_proba, axis=1)

    return output


demo = gr.Blocks()


print("Starting the demo...")
with demo:
    gr.Markdown(
        """
        <h1 align="center">Credit Card Approval Prediction Using Fully Homomorphic Encryption</h1>
        """
    )

    gr.Markdown("## Client side")

    gr.Markdown("### Step 1: Provide the inputs.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### User")

            choice_1 = gr.Dropdown(choices=["Yes", "No"], label="Choose", interactive=True)
            slide_1 = gr.Slider(2, 20, value=4, label="Count", info="Choose between 2 and 20")

        with gr.Column():
            gr.Markdown("### Bank")

            checkbox_1 = gr.CheckboxGroup(
                ["USA", "Japan", "Pakistan"], label="Countries", info="Where are they from?"
            )

        with gr.Column():
            gr.Markdown("### Third Party")

            radio_1 = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")

    gr.Markdown("### Step 2: Generate the keys, encrypt the inputs using FHE and send them to the server.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### User")
            encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
            keys_user = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input_user = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )

            user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

        with gr.Column():
            gr.Markdown("### Bank")
            encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
            keys_bank = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input_bank = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )

            bank_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

        with gr.Column():
            gr.Markdown("### Third Party")
            encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
            keys_third_party = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input_third_party = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )

            third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

    gr.Markdown("## Server side")
    gr.Markdown(
        "The encrypted values are received by the server. The server can then compute the prediction "
        "directly over them. Once the computation is finished, the server returns "
        "the encrypted result to the client."
    )

    gr.Markdown("### Step 3: Run the FHE execution.")
    execute_fhe_button = gr.Button("Run FHE execution.")
    fhe_execution_time = gr.Textbox(
        label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
    )

    gr.Markdown("## Client side")
    gr.Markdown(
        "The encrypted output is sent back to the client, who can finally decrypt it with the "
        "private key."
    )

    gr.Markdown("### Step 4: Receive the encrypted output from the server.")
    gr.Markdown(
        "The output displayed here is the encrypted result sent by the server, which has been "
        "decrypted using a different private key. This is only used to visually represent an "
        "encrypted output."
    )
    get_output_button = gr.Button("Receive the encrypted output from the server.")

    encrypted_output_representation = gr.Textbox(
        label="Encrypted output representation:", max_lines=1, interactive=False
    )

    gr.Markdown("### Step 5: Decrypt the output.")
    decrypt_button = gr.Button("Decrypt the output")

    prediction_output = gr.Textbox(
        label="Credit card approval decision:", max_lines=1, interactive=False
    )

    gr.Markdown(
        "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), an "
        "open-source set of Privacy-Preserving Machine Learning (PPML) tools by "
        "[Zama](https://zama.ai/). Try it yourself and don't forget to star it on GitHub ⭐."
    )

demo.launch(share=False)