diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..25adca7b9eb27f9d10f3c11258498cd4b54e4117 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.venv +.playground/ +.artifacts +.fhe_keys +server_tmp/ +client_tmp/ +.artifacts diff --git a/README.md b/README.md index 7b4235e841616fe1442c601c760c1deb42a54673..70ade190f8ab3617cf8c1c6a0a89dff802a2aefb 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,76 @@ --- -title: Encrypted Image Filtering -emoji: 🐨 +title: Image Filtering on Encrypted Images using FHE +emoji: 🥷💬 colorFrom: yellow -colorTo: pink +colorTo: yellow sdk: gradio -sdk_version: 3.16.1 +sdk_version: 3.2 app_file: app.py -pinned: false +pinned: true +tags: [FHE, PPML, privacy, privacy preserving machine learning, homomorphic encryption, + security] +python_version: 3.8.15 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +# Image filtering using FHE + +## Running the application on your machine + +In this directory, ie `image_filtering`, you can do the following steps. + +### Do once + +First, create a virtual env and activate it: + + + +```bash +python3 -m venv .venv +source .venv/bin/activate +``` + +Then, install required packages: + + + +```bash +pip3 install -U pip wheel setuptools --ignore-installed +pip3 install -r requirements.txt --ignore-installed +``` + +If not on Linux, or if you want to compile the FHE filters by yourself: + + + +```bash +python3 compile.py +``` + +Check it finish well (with a "Done!"). + +It is also possible to manually add some new filters in `filters.py`. Yet, in order to be able to use +them interactively in the app, you first need to update the `AVAILABLE_FILTERS` list found in `common.py` +and then compile them by running : + + + +```bash +python3 generate_dev_filters.py +``` + +## Run the following steps each time you relaunch the application + +In a terminal, run: + + + +```bash +source .venv/bin/activate +python3 app.py +``` + +## Interacting with the application + +Open the given URL link (search for a line like `Running on local URL: http://127.0.0.1:8888/`). diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..91e17a4cdd9dd93b9851dcd1332d41811b6a4ae9 --- /dev/null +++ b/app.py @@ -0,0 +1,463 @@ +"""A local gradio app that filters images using FHE.""" + +import os +import shutil +import subprocess +import time + +import gradio as gr +import numpy +import requests +from common import ( + AVAILABLE_FILTERS, + CLIENT_TMP_PATH, + EXAMPLES, + FILTERS_PATH, + INPUT_SHAPE, + KEYS_PATH, + REPO_DIR, + SERVER_URL, +) +from custom_client_server import CustomFHEClient + +# Uncomment here to have both the server and client in the same terminal +subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR) +time.sleep(3) + + +def shorten_bytes_object(bytes_object, limit=500): + """Shorten the input bytes object to a given length. + + Encrypted data is too large for displaying it in the browser using Gradio. This function + provides a shorten representation of it. + + Args: + bytes_object (bytes): The input to shorten + limit (int): The length to consider. Default to 500. + + Returns: + Any: The fitted model. + + """ + # Define a shift for better display + shift = 100 + return bytes_object[shift : limit + shift].hex() + + +def get_client(user_id, image_filter): + """Get the client API. + + Args: + user_id (int): The current user's ID. + image_filter (str): The filter chosen by the user + + Returns: + CustomFHEClient: The client API. + """ + return CustomFHEClient( + FILTERS_PATH / f"{image_filter}/deployment", KEYS_PATH / f"{image_filter}_{user_id}" + ) + + +def get_client_file_path(name, user_id, image_filter): + """Get the correct temporary file path for the client. + + Args: + name (str): The desired file name. + user_id (int): The current user's ID. + image_filter (str): The filter chosen by the user + + Returns: + pathlib.Path: The file path. + """ + return CLIENT_TMP_PATH / f"{name}_{image_filter}_{user_id}" + + +def clean_temporary_files(n_keys=20): + """Clean keys and encrypted images. + + A maximum of n_keys keys are allowed to be stored. Once this limit is reached, the oldest are + deleted. + + Args: + n_keys (int): The maximum number of keys to be stored. Default to 20. + + """ + # Get the oldest files in the key directory + list_files = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime) + + # If more than n_keys keys are found, remove the oldest + user_ids = [] + if len(list_files) > n_keys: + n_files_to_delete = len(list_files) - n_keys + for p in list_files[:n_files_to_delete]: + user_ids.append(p.name) + shutil.rmtree(p) + + # Get all the encrypted objects in the temporary folder + list_files_tmp = CLIENT_TMP_PATH.iterdir() + + # Delete all files related to the current user + for file in list_files_tmp: + for user_id in user_ids: + if file.name.endswith(f"{user_id}.npy"): + file.unlink() + + +def keygen(image_filter): + """Generate the private key associated to a filter. + + Args: + image_filter (str): The current filter to consider. + + Returns: + (user_id, True) (Tuple[int, bool]): The current user's ID and a boolean used for visual display. + + """ + # Clean temporary files + clean_temporary_files() + + # Create an ID for the current user + user_id = numpy.random.randint(0, 2**32) + + # Retrieve the client API + # Currently, the key generation needs to be done after choosing a filter + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2258 + client = get_client(user_id, image_filter) + + # Generate a private key + client.generate_private_and_evaluation_keys(force=True) + + # Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this + # evaluation key is empty. However, for software reasons, it is still needed for proper FHE + # execution + evaluation_key = client.get_serialized_evaluation_keys() + + # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio + # buttons (see https://github.com/gradio-app/gradio/issues/1877) + evaluation_key_path = get_client_file_path("evaluation_key", user_id, image_filter) + + with evaluation_key_path.open("wb") as evaluation_key_file: + evaluation_key_file.write(evaluation_key) + + return (user_id, True) + + +def encrypt(user_id, input_image, image_filter): + """Encrypt the given image for a specific user and filter. + + Args: + user_id (int): The current user's ID. + input_image (numpy.ndarray): The image to encrypt. + image_filter (str): The current filter to consider. + + Returns: + (input_image, encrypted_image_short) (Tuple[bytes]): The encrypted image and one of its + representation. + + """ + if user_id == "": + raise gr.Error("Please generate the private key first.") + + # Retrieve the client API + client = get_client(user_id, image_filter) + + # Pre-process, encrypt and serialize the image + encrypted_image = client.pre_process_encrypt_serialize(input_image) + + # Save encrypted_image to bytes in a file, since too large to pass through regular Gradio + # buttons, https://github.com/gradio-app/gradio/issues/1877 + encrypted_image_path = get_client_file_path("encrypted_image", user_id, image_filter) + + with encrypted_image_path.open("wb") as encrypted_image_file: + encrypted_image_file.write(encrypted_image) + + # Create a truncated version of the encrypted image for display + encrypted_image_short = shorten_bytes_object(encrypted_image) + + return (input_image, encrypted_image_short) + + +def send_input(user_id, image_filter): + """Send the encrypted input image as well as the evaluation key to the server. + + Args: + user_id (int): The current user's ID. + image_filter (str): The current filter to consider. + """ + # Get the evaluation key path + evaluation_key_path = get_client_file_path("evaluation_key", user_id, image_filter) + + if user_id == "" or not evaluation_key_path.is_file(): + raise gr.Error("Please generate the private key first.") + + encrypted_input_path = get_client_file_path("encrypted_image", user_id, image_filter) + + if not encrypted_input_path.is_file(): + raise gr.Error("Please generate the private key and then encrypt an image first.") + + # Define the data and files to post + data = { + "user_id": user_id, + "filter": image_filter, + } + + files = [ + ("files", open(encrypted_input_path, "rb")), + ("files", open(evaluation_key_path, "rb")), + ] + + # Send the encrypted input image and evaluation key to the server + url = SERVER_URL + "send_input" + with requests.post( + url=url, + data=data, + files=files, + ) as response: + return response.ok + + +def run_fhe(user_id, image_filter): + """Apply the filter on the encrypted image previously sent using FHE. + + Args: + user_id (int): The current user's ID. + image_filter (str): The current filter to consider. + """ + data = { + "user_id": user_id, + "filter": image_filter, + } + + # Trigger the FHE execution on the encrypted image previously sent + url = SERVER_URL + "run_fhe" + with requests.post( + url=url, + data=data, + ) as response: + if response.ok: + return response.json() + else: + raise gr.Error("Please wait for the input image to be sent to the server.") + + +def get_output(user_id, image_filter): + """Retrieve the encrypted output image. + + Args: + user_id (int): The current user's ID. + image_filter (str): The current filter to consider. + + Returns: + encrypted_output_image_short (bytes): A representation of the encrypted result. + + """ + data = { + "user_id": user_id, + "filter": image_filter, + } + + # Retrieve the encrypted output image + url = SERVER_URL + "get_output" + with requests.post( + url=url, + data=data, + ) as response: + if response.ok: + # Save the encrypted output to bytes in a file as it is too large to pass through regular + # Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877) + encrypted_output_path = get_client_file_path("encrypted_output", user_id, image_filter) + + with encrypted_output_path.open("wb") as encrypted_output_file: + encrypted_output_file.write(response.content) + + # Create a truncated version of the encrypted output for display + encrypted_output_image_short = shorten_bytes_object(response.content) + + return encrypted_output_image_short + else: + raise gr.Error("Please wait for the FHE execution to be completed.") + + +def decrypt_output(user_id, image_filter): + """Decrypt the result. + + Args: + user_id (int): The current user's ID. + image_filter (str): The current filter to consider. + + Returns: + (output_image, False, False) ((Tuple[numpy.ndarray, bool, bool]): The decrypted output, as + well as two booleans used for resetting Gradio checkboxes + + """ + if user_id == "": + raise gr.Error("Please generate the private key first.") + + # Get the encrypted output path + encrypted_output_path = get_client_file_path("encrypted_output", user_id, image_filter) + + if not encrypted_output_path.is_file(): + raise gr.Error("Please run the FHE execution first.") + + # Load the encrypted output as bytes + with encrypted_output_path.open("rb") as encrypted_output_file: + encrypted_output_image = encrypted_output_file.read() + + # Retrieve the client API + client = get_client(user_id, image_filter) + + # Deserialize, decrypt and post-process the encrypted output + output_image = client.deserialize_decrypt_post_process(encrypted_output_image) + + return output_image, False, False + + +demo = gr.Blocks() + + +print("Starting the demo...") +with demo: + gr.Markdown( + """ +
+
++
+ """ + ) + + gr.Markdown("## Client side") + gr.Markdown( + f"Step 1. Upload an image. It will automatically be resized to shape ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]})." + "The image is however displayed using its original resolution." + ) + with gr.Row(): + input_image = gr.Image( + label="Upload an image here.", shape=INPUT_SHAPE, source="upload", interactive=True + ) + + examples = gr.Examples( + examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use." + ) + + gr.Markdown("Step 2. Choose your filter") + image_filter = gr.Dropdown( + choices=AVAILABLE_FILTERS, value="inverted", label="Choose your filter", interactive=True + ) + + gr.Markdown("### Notes") + gr.Markdown( + """ + - The private key is used to encrypt and decrypt the data and shall never be shared. + - No public key are required for these filter operators. + """ + ) + + with gr.Row(): + keygen_button = gr.Button("Step 3. Generate the private key.") + + keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False) + + with gr.Row(): + encrypt_button = gr.Button("Step 4. Encrypt the image using FHE.") + + user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False) + + # Display an image representation + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265 + encrypted_image = gr.Textbox( + label="Encrypted image representation:", max_lines=2, interactive=False + ) + + gr.Markdown("## Server side") + gr.Markdown( + "The encrypted value is received by the server. The server can then compute the filter " + "directly over encrypted values. Once the computation is finished, the server returns " + "the encrypted results to the client." + ) + + with gr.Row(): + send_input_button = gr.Button("Step 5. Send the encrypted image to the server.") + + send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False) + + with gr.Row(): + execute_fhe_button = gr.Button("Step 6. Run FHE execution") + + fhe_execution_time = gr.Textbox( + label="Total FHE execution time (in seconds).", max_lines=1, interactive=False + ) + + with gr.Row(): + get_output_button = gr.Button("Step 7. Receive the encrypted output image from the server.") + + # Display an image representation + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265 + encrypted_output_image = gr.Textbox( + label="Encrypted output image representation:", max_lines=2, interactive=False + ) + + gr.Markdown("## Client side") + gr.Markdown( + "The encrypted output is sent back to client, who can finally decrypt it with its " + "private key. Only the client is aware of the original image and its transformed version." + ) + + decrypt_button = gr.Button("Step 8. Decrypt the output") + + # Final input vs output display + with gr.Row(): + original_image = gr.Image( + input_image.value, + label=f"Input image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):", + interactive=False, + ) + original_image.style(height=256, width=256) + + output_image = gr.Image( + label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):", interactive=False + ) + output_image.style(height=256, width=256) + + # Button to generate the private key + keygen_button.click( + keygen, + inputs=[image_filter], + outputs=[user_id, keygen_checkbox], + ) + + # Button to encrypt inputs on the client side + encrypt_button.click( + encrypt, + inputs=[user_id, input_image, image_filter], + outputs=[original_image, encrypted_image], + ) + + # Button to send the encodings to the server using post method + send_input_button.click( + send_input, inputs=[user_id, image_filter], outputs=[send_input_checkbox] + ) + + # Button to send the encodings to the server using post method + execute_fhe_button.click(run_fhe, inputs=[user_id, image_filter], outputs=[fhe_execution_time]) + + # Button to send the encodings to the server using post method + get_output_button.click( + get_output, inputs=[user_id, image_filter], outputs=[encrypted_output_image] + ) + + # Button to decrypt the output on the client side + decrypt_button.click( + decrypt_output, + inputs=[user_id, image_filter], + outputs=[output_image, keygen_checkbox, send_input_checkbox], + ) + + gr.Markdown( + "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a " + "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). " + "Try it yourself and don't forget to star on Github ⭐." + ) + +demo.launch(share=False) diff --git a/common.py b/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2357a18fa77598eedf35612e72ab12e8a53cf95c --- /dev/null +++ b/common.py @@ -0,0 +1,54 @@ +"All the constants used in this repo." + +from pathlib import Path + +import numpy as np +from PIL import Image + +# The repository's directory +REPO_DIR = Path(__file__).parent + +# The repository's main directories +FILTERS_PATH = REPO_DIR / "filters" +KEYS_PATH = REPO_DIR / ".fhe_keys" +CLIENT_TMP_PATH = REPO_DIR / "client_tmp" +SERVER_TMP_PATH = REPO_DIR / "server_tmp" + +# Create the directories if it does not exist yet +KEYS_PATH.mkdir(exist_ok=True) +CLIENT_TMP_PATH.mkdir(exist_ok=True) +SERVER_TMP_PATH.mkdir(exist_ok=True) + +# All the filters currently available in the app +AVAILABLE_FILTERS = [ + "identity", + "inverted", + "rotate", + "black and white", + "blur", + "sharpen", + "ridge detection", +] + +# The input image's shape. Images with larger input shapes will be cropped and/or resized to this +INPUT_SHAPE = (100, 100) + +# Generate random images as an inputset for compilation +np.random.seed(42) +INPUTSET = tuple( + np.random.randint(0, 255, size=(INPUT_SHAPE + (3,)), dtype=np.int64) for _ in range(10) +) + + +def load_image(image_path): + image = Image.open(image_path).convert("RGB").resize(INPUT_SHAPE) + image = np.asarray(image, dtype="int64") + return image + + +_INPUTSET_DIR = REPO_DIR / "input_examples" + +# List of all image examples suggested in the app +EXAMPLES = [str(image) for image in _INPUTSET_DIR.glob("**/*")] + +SERVER_URL = "http://localhost:8000/" diff --git a/compile.py b/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..04ab50f49d12f10b1fd7f3b560ab11f078f10bca --- /dev/null +++ b/compile.py @@ -0,0 +1,47 @@ +"A script to manually compile all filters." + +import json +import shutil + +import numpy as np +import onnx +from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET, KEYS_PATH +from custom_client_server import CustomFHEClient, CustomFHEDev + +print("Starting compiling the filters.") + +for image_filter in AVAILABLE_FILTERS: + print("\nCompiling filter:", image_filter) + + # Load the onnx model + onnx_model = onnx.load(FILTERS_PATH / f"{image_filter}/server.onnx") + + deployment_path = FILTERS_PATH / f"{image_filter}/deployment" + + # Retrieve the client API related to the current filter + model = CustomFHEClient(deployment_path, KEYS_PATH).model + + image_shape = INPUT_SHAPE + (3,) + + # Compile the model using the loaded onnx model + model.compile(INPUTSET, onnx_model=onnx_model) + + processing_json_path = deployment_path / "serialized_processing.json" + + # Load the serialized_processing.json file + with open(processing_json_path, "r") as f: + serialized_processing = json.load(f) + + # Delete the deployment folder and its content if it exist + if deployment_path.is_dir(): + shutil.rmtree(deployment_path) + + # Save the files needed for deployment + fhe_api = CustomFHEDev(model=model, path_dir=deployment_path) + fhe_api.save() + + # Write the serialized_processing.json file to the deployment folder + with open(processing_json_path, "w") as f: + json.dump(serialized_processing, f) + +print("Done!") diff --git a/custom_client_server.py b/custom_client_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b597859c26bd281159c187ec2d0f702c9244bdbe --- /dev/null +++ b/custom_client_server.py @@ -0,0 +1,204 @@ +"Client-server interface implementation for custom models." + +from pathlib import Path +from typing import Any + +import concrete.numpy as cnp +import numpy as np +from filters import Filter + +from concrete.ml.common.debugging.custom_assert import assert_true + + +class CustomFHEDev: + """Dev API to save the custom model and then load and run the FHE circuit.""" + + model: Any = None + + def __init__(self, path_dir: str, model: Any = None): + """Initialize the FHE API. + + Args: + path_dir (str): the path to the directory where the circuit is saved + model (Any): the model to use for the FHE API + """ + + self.path_dir = Path(path_dir) + self.model = model + + # Create the directory path if it does not exist yet + Path(self.path_dir).mkdir(parents=True, exist_ok=True) + + def save(self): + """Export all needed artifacts for the client and server. + + Raises: + Exception: path_dir is not empty + """ + # Check if the path_dir is empty with pathlib + listdir = list(Path(self.path_dir).glob("**/*")) + if len(listdir) > 0: + raise Exception( + f"path_dir: {self.path_dir} is not empty." + "Please delete it before saving a new model." + ) + + assert_true( + hasattr(self.model, "fhe_circuit"), + "The model must be compiled and have a fhe_circuit object", + ) + + # Model must be compiled with jit=False + # In a jit model, everything is in memory so it is not serializable. + assert_true( + not self.model.fhe_circuit.configuration.jit, + "The model must be compiled with the configuration option jit=False.", + ) + + # Export the parameters + self.model.to_json(path_dir=self.path_dir, file_name="serialized_processing") + + # Save the circuit for the server + path_circuit_server = self.path_dir / "server.zip" + self.model.fhe_circuit.server.save(path_circuit_server) + + # Save the circuit for the client + path_circuit_client = self.path_dir / "client.zip" + self.model.fhe_circuit.client.save(path_circuit_client) + + +class CustomFHEClient: + """Client API to encrypt and decrypt FHE data.""" + + client: cnp.Client + + def __init__(self, path_dir: str, key_dir: str = None): + """Initialize the FHE API. + + Args: + path_dir (str): the path to the directory where the circuit is saved + key_dir (str): the path to the directory where the keys are stored + """ + self.path_dir = Path(path_dir) + self.key_dir = Path(key_dir) + + # If path_dir does not exist, raise an error + assert_true( + Path(path_dir).exists(), f"{path_dir} does not exist. Please specify a valid path." + ) + + # Load + self.load() + + def load(self): # pylint: disable=no-value-for-parameter + """Load the parameters along with the FHE specs.""" + + # Load the client + self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir) + + # Load the model + self.model = Filter.from_json(self.path_dir / "serialized_processing.json") + + def generate_private_and_evaluation_keys(self, force=False): + """Generate the private and evaluation keys. + + Args: + force (bool): if True, regenerate the keys even if they already exist + """ + self.client.keygen(force) + + def get_serialized_evaluation_keys(self) -> cnp.EvaluationKeys: + """Get the serialized evaluation keys. + + Returns: + cnp.EvaluationKeys: the evaluation keys + """ + return self.client.evaluation_keys.serialize() + + def pre_process_encrypt_serialize(self, x: np.ndarray) -> cnp.PublicArguments: + """Encrypt and serialize the values. + + Args: + x (numpy.ndarray): the values to encrypt and serialize + + Returns: + cnp.PublicArguments: the encrypted and serialized values + """ + # Pre-process the values + x = self.model.pre_processing(x) + + # Encrypt the values + enc_x = self.client.encrypt(x) + + # Serialize the encrypted values to be sent to the server + serialized_enc_x = self.client.specs.serialize_public_args(enc_x) + return serialized_enc_x + + def deserialize_decrypt_post_process( + self, serialized_encrypted_output: cnp.PublicArguments + ) -> np.ndarray: + """Deserialize, decrypt and post-process the values. + + Args: + serialized_encrypted_output (cnp.PublicArguments): the serialized and encrypted output + + Returns: + numpy.ndarray: the decrypted values + """ + # Deserialize the encrypted values + deserialized_encrypted_output = self.client.specs.unserialize_public_result( + serialized_encrypted_output + ) + + # Decrypt the values + deserialized_decrypted_output = self.client.decrypt(deserialized_encrypted_output) + + # Apply the model post processing + deserialized_decrypted_output = self.model.post_processing(deserialized_decrypted_output) + return deserialized_decrypted_output + + +class CustomFHEServer: + """Server API to load and run the FHE circuit.""" + + server: cnp.Server + + def __init__(self, path_dir: str): + """Initialize the FHE API. + + Args: + path_dir (str): the path to the directory where the circuit is saved + """ + + self.path_dir = Path(path_dir) + + # Load the FHE circuit + self.load() + + def load(self): + """Load the circuit.""" + self.server = cnp.Server.load(self.path_dir / "server.zip") + + def run( + self, + serialized_encrypted_data: cnp.PublicArguments, + serialized_evaluation_keys: cnp.EvaluationKeys, + ) -> cnp.PublicResult: + """Run the model on the server over encrypted data. + + Args: + serialized_encrypted_data (cnp.PublicArguments): the encrypted and serialized data + serialized_evaluation_keys (cnp.EvaluationKeys): the serialized evaluation keys + + Returns: + cnp.PublicResult: the result of the model + """ + assert_true(self.server is not None, "Model has not been loaded.") + + deserialized_encrypted_data = self.server.client_specs.unserialize_public_args( + serialized_encrypted_data + ) + deserialized_evaluation_keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys) + result = self.server.run(deserialized_encrypted_data, deserialized_evaluation_keys) + serialized_result = self.server.client_specs.serialize_public_result(result) + return serialized_result diff --git a/filters.py b/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..36ed0a71d10465187aaeee11a800f13adeba67d6 --- /dev/null +++ b/filters.py @@ -0,0 +1,359 @@ +"Filter definitions, with pre-processing, post-processing and compilation methods." + +import json + +import numpy as np +import torch +from common import AVAILABLE_FILTERS +from concrete.numpy.compilation.compiler import Compiler +from torch import nn + +from concrete.ml.common.debugging.custom_assert import assert_true +from concrete.ml.common.utils import generate_proxy_function +from concrete.ml.onnx.convert import get_equivalent_numpy_forward +from concrete.ml.torch.numpy_module import NumpyModule +from concrete.ml.version import __version__ as CML_VERSION + +# Add a "black and white" filter +# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2277 + + +class _TorchIdentity(nn.Module): + """Torch identity model.""" + + def forward(self, x): + """Identity forward pass. + + Args: + x (torch.Tensor): The input image. + + Returns: + x (torch.Tensor): The input image. + """ + return x + + +class _TorchInverted(nn.Module): + """Torch inverted model.""" + + def forward(self, x): + """Forward pass for inverting an image's colors. + + Args: + x (torch.Tensor): The input image. + + Returns: + torch.Tensor: The (color) inverted image. + """ + return 255 - x + + +class _TorchRotate(nn.Module): + """Torch rotated model.""" + + def forward(self, x): + """Forward pass for rotating an image. + + Args: + x (torch.Tensor): The input image. + + Returns: + torch.Tensor: The rotated image. + """ + return x.transpose(2, 3) + + +class _TorchConv2D(nn.Module): + """Torch model for applying a single 2D convolution operator on images.""" + + def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1): + """Initializing the filter + + Args: + kernel (np.ndarray): The convolution kernel to consider. + """ + super().__init__() + self.kernel = kernel + self.n_out_channels = n_out_channels + self.n_in_channels = n_in_channels + self.groups = groups + + def forward(self, x): + """Forward pass for filtering the image using a 2D kernel. + + Args: + x (torch.Tensor): The input image. + + Returns: + torch.Tensor: The filtered image. + + """ + # Define the convolution parameters + stride = 1 + kernel_shape = self.kernel.shape + + # Ensure the kernel has a proper shape + # If the kernel has a 1D shape, a (1, 1) kernel is used for each in_channels + if len(kernel_shape) == 1: + kernel = self.kernel.reshape( + self.n_out_channels, + self.n_in_channels // self.groups, + 1, + 1, + ) + + # Else, if the kernel has a 2D shape, a single (Kw, Kh) kernel is used on all in_channels + elif len(kernel_shape) == 2: + kernel = self.kernel.expand( + self.n_out_channels, + self.n_in_channels // self.groups, + kernel_shape[0], + kernel_shape[1], + ) + else: + raise ValueError( + "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape " + f"{kernel_shape}" + ) + + return nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups) + + +class Filter: + """Filter class used in the app.""" + + def __init__(self, image_filter="inverted"): + """Initializing the filter class using a given filter. + + Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing). + + Args: + image_filter (str): The filter to consider. Default to "inverted". + """ + + assert_true( + image_filter in AVAILABLE_FILTERS, + f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, " + f"but got {image_filter}", + ) + + self.filter = image_filter + self.divide = None + self.repeat_out_channels = False + + if image_filter == "identity": + self.torch_model = _TorchIdentity() + + elif image_filter == "inverted": + self.torch_model = _TorchInverted() + + elif image_filter == "rotate": + self.torch_model = _TorchRotate() + + elif image_filter == "black and white": + # Define the grayscale weights (RGB order) + # These weights were used in PAL and NTSC video systems and can be found at + # https://en.wikipedia.org/wiki/Grayscale + # There are initially supposed to be float weights (0.299, 0.587, 0.114), with + # 0.299 + 0.587 + 0.114 = 1 + # However, since FHE computations require weights to be integers, we first multiply + # these by a factor of 1000. The output image's values are then divided by 1000 in + # post-processing in order to retrieve the correct result + kernel = torch.tensor([299, 587, 114]) + + self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1) + + # Division value for post-processing + self.divide = 1000 + + # Grayscaled image needs to be put in RGB format for Gradio display + self.repeat_out_channels = True + + elif image_filter == "blur": + kernel = torch.ones((3, 3), dtype=torch.int64) + + self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3) + + # Division value for post-processing + self.divide = 9 + + elif image_filter == "sharpen": + kernel = torch.tensor( + [ + [0, -1, 0], + [-1, 5, -1], + [0, -1, 0], + ] + ) + + self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3) + + elif image_filter == "ridge detection": + # Make the filter properly grayscaled, as it is commonly used + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265 + + kernel = torch.tensor( + [ + [-1, -1, -1], + [-1, 9, -1], + [-1, -1, -1], + ] + ) + + self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1) + + # Ridge detection is usually displayed as a grayscaled image, which needs to be put in + # RGB format for Gradio display + self.repeat_out_channels = True + + self.onnx_model = None + self.fhe_circuit = None + + def compile(self, inputset, onnx_model=None): + """Compile the model using an inputset. + + Args: + inputset (List[np.ndarray]): The set of images to use for compilation + onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be + generated automatically using a NumpyModule. Default to None. + """ + # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow + # the same shape conventions. + inputset = tuple( + np.expand_dims(input.transpose(2, 0, 1), axis=0).astype(np.int64) for input in inputset + ) + + # If no onnx model was given, generate a new one. + if onnx_model is None: + numpy_module = NumpyModule( + self.torch_model, + dummy_input=torch.from_numpy(inputset[0]), + ) + + onnx_model = numpy_module.onnx_model + + # Get the proxy function and parameter mappings for initializing the compiler + self.onnx_model = onnx_model + numpy_filter = get_equivalent_numpy_forward(onnx_model) + + numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"]) + + compiler = Compiler( + numpy_filter_proxy, + {parameters_mapping["inputs"]: "encrypted"}, + ) + + # Compile the filter + self.fhe_circuit = compiler.compile(inputset) + + return self.fhe_circuit + + def pre_processing(self, input_image): + """Processing that needs to be applied before encryption. + + Args: + input_image (np.ndarray): The image to pre-process + + Returns: + input_image (np.ndarray): The pre-processed image + """ + # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow + # the same shape conventions. + input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64) + + return input_image + + def post_processing(self, output_image): + """Processing that needs to be applied after decryption. + + Args: + input_image (np.ndarray): The decrypted image to post-process + + Returns: + input_image (np.ndarray): The post-processed image + """ + # Apply a division if needed + if self.divide is not None: + output_image //= self.divide + + # Clip the image's values to proper RGB standards as filters don't handle such constraints + output_image = output_image.clip(0, 255) + + # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow + # the same shape conventions. + output_image = output_image.transpose(0, 2, 3, 1).squeeze(0) + + # Grayscaled image needs to be put in RGB format for Gradio display + if self.repeat_out_channels: + output_image = output_image.repeat(3, axis=2) + + return output_image + + @classmethod + def from_json(cls, json_path): + """Instantiate a filter using a json file. + + Args: + json_path (Union[str, pathlib.Path]): Path to the json file. + + Returns: + model (Filter): The instantiated filter class. + """ + # Load the parameters from the json file + with open(json_path, "r", encoding="utf-8") as f: + serialized_processing = json.load(f) + + # Make sure the version in serialized_model is the same as CML_VERSION + assert_true( + serialized_processing["cml_version"] == CML_VERSION, + f"The version of Concrete ML library ({CML_VERSION}) is different " + f"from the one used to save the model ({serialized_processing['cml_version']}). " + "Please update to the proper Concrete ML version.", + ) + + # Initialize the model + model = cls(image_filter=serialized_processing["model_filter"]) + + return model + + def to_json(self, path_dir, file_name="serialized_processing"): + """Export the parameters to a json file. + + Args: + path_dir (Union[str, pathlib.Path]): The path to consider when saving the file. + file_name (str): The file name + """ + # Serialize the parameters + serialized_processing = { + "model_filter": self.filter, + } + serialized_processing = self._clean_dict_types_for_json(serialized_processing) + + # Add the version of the current CML library + serialized_processing["cml_version"] = CML_VERSION + + # Save the json file + with open(path_dir / f"{file_name}.json", "w", encoding="utf-8") as f: + json.dump(serialized_processing, f) + + def _clean_dict_types_for_json(self, d: dict) -> dict: + """Clean all values in the dict to be json serializable. + + Args: + d (Dict): The dict to clean + + Returns: + Dict: The cleaned dict + """ + key_to_delete = [] + for key, value in d.items(): + if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict): + d[key] = [self._clean_dict_types_for_json(v) for v in value] + elif isinstance(value, dict): + d[key] = self._clean_dict_types_for_json(value) + elif isinstance(value, (np.generic, np.ndarray)): + d[key] = d[key].tolist() + + for key in key_to_delete: + d.pop(key) + return d diff --git a/filters/black and white/deployment/client.zip b/filters/black and white/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..5d2cfb05ccc1a4639557af355d411dc60f92396a --- /dev/null +++ b/filters/black and white/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b20b09e3118c2dc24004c4f1c4bc1465cf4b0ed0e1c907fffb7695b3db6bbace +size 388 diff --git a/filters/black and white/deployment/serialized_processing.json b/filters/black and white/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..976800699b01718109a5043c33b013ab3a3fae4c --- /dev/null +++ b/filters/black and white/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "black and white", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/black and white/deployment/server.zip b/filters/black and white/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..984db3d1d0b0c595e7b8982749096f83b2854acc --- /dev/null +++ b/filters/black and white/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fb1f2ff4aa7a1a56cf5d0f8d63d34ee912c06b347fe5e97088c79ad0ba6e902 +size 4870 diff --git a/filters/black and white/server.onnx b/filters/black and white/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..093ee1651db658de3c17bc9b8883ff0cce2e68b4 --- /dev/null +++ b/filters/black and white/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f4774c394a6fec8cc43dae14ce627837aa998fcc78ba4ab67ad1c5bf92dd3ee +size 336 diff --git a/filters/black_and_white/deployment/client.zip b/filters/black_and_white/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..ee8a626022ae9166e998dd543aa51993f2c4d8a8 --- /dev/null +++ b/filters/black_and_white/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:000285a62f642b20eda541c6697e33de3d725c254ff5c2098e3157fc73cd017b +size 388 diff --git a/filters/black_and_white/deployment/serialized_processing.json b/filters/black_and_white/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..a27e01ffaf4686e193c3b8a1dca70ee76a8adc44 --- /dev/null +++ b/filters/black_and_white/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "black_and_white", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/black_and_white/deployment/server.zip b/filters/black_and_white/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..c791f8c2302675f3c239e832d89dcddadc76529c --- /dev/null +++ b/filters/black_and_white/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9867657ff1e7b2c8eb3c72f28be8b8e8ee0b355762b99f34a25a2c9de0cb104c +size 4762 diff --git a/filters/black_and_white/server.onnx b/filters/black_and_white/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..093ee1651db658de3c17bc9b8883ff0cce2e68b4 --- /dev/null +++ b/filters/black_and_white/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f4774c394a6fec8cc43dae14ce627837aa998fcc78ba4ab67ad1c5bf92dd3ee +size 336 diff --git a/filters/blur/deployment/client.zip b/filters/blur/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..5276b773ce0e7515ed208304ee345bfc5da4889c --- /dev/null +++ b/filters/blur/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8846612ef61a81a0f18b96a4bcca90bffde5400f9e689ac208f40673e3581aca +size 391 diff --git a/filters/blur/deployment/serialized_processing.json b/filters/blur/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..14bdc60c17a0181e85aa0a3b0b0f29e39273685c --- /dev/null +++ b/filters/blur/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "blur", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/blur/deployment/server.zip b/filters/blur/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..7e7cf8c9aeaa206514b0aec9067cfff49d2d9b4c --- /dev/null +++ b/filters/blur/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2652c963dbd9b82788671b9f133e70131f9da5ecf0a3123c9aa323ff69ee77a3 +size 8651 diff --git a/filters/blur/server.onnx b/filters/blur/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..39c604f354840fcc3701cd6cdfaee9a2642219cf --- /dev/null +++ b/filters/blur/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8fd3d313ec3a9d565a0621921768317f66e53596ad950ca2be6b1efbcf984bd +size 532 diff --git a/filters/identity/deployment/client.zip b/filters/identity/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..558958357d8477317d4c1843cf97184efa38cdb9 --- /dev/null +++ b/filters/identity/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7320407c56796bf0fe4d719f5e5826650f83c8424cb15779ac8c5b5ef0722fd +size 378 diff --git a/filters/identity/deployment/serialized_processing.json b/filters/identity/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..f1406498a353ae32c845fb748a714c24c5ac040c --- /dev/null +++ b/filters/identity/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "identity", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/identity/deployment/server.zip b/filters/identity/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..1cc8af29f7aaa34242112407332092c59a651cb8 --- /dev/null +++ b/filters/identity/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:246a6063277532e2ef246df2d0ce7f0c5d38fbfaa85c8a0d649cada63e7b0bb9 +size 2637 diff --git a/filters/identity/server.onnx b/filters/identity/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa75f644955ae09910ec3fd4a811e0d2004d134f --- /dev/null +++ b/filters/identity/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71a8a398ea2edac9b0dfd41232c74549d1b8c159d391a4b3d42e2b4b731da02b +size 155 diff --git a/filters/inverted/deployment/client.zip b/filters/inverted/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..558958357d8477317d4c1843cf97184efa38cdb9 --- /dev/null +++ b/filters/inverted/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7320407c56796bf0fe4d719f5e5826650f83c8424cb15779ac8c5b5ef0722fd +size 378 diff --git a/filters/inverted/deployment/serialized_processing.json b/filters/inverted/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..2e283802de695027642aa82a65bc5b35756722a8 --- /dev/null +++ b/filters/inverted/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "inverted", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/inverted/deployment/server.zip b/filters/inverted/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..32b2014545ac2dab5ddd4f0ec811c799a79b1545 --- /dev/null +++ b/filters/inverted/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7f54eb012a12e29a927fbfeb2d7c811533ebaa1d50527e14019c940a7c86f52 +size 5136 diff --git a/filters/inverted/server.onnx b/filters/inverted/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1491bf5a61750f492b543d41cfe63737400d7866 --- /dev/null +++ b/filters/inverted/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2fead709dbc8f0ab1f19ff1878e0ac9ce110b2b3ced261d7a87d32e0fc58b61 +size 211 diff --git a/filters/ridge detection/deployment/client.zip b/filters/ridge detection/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..c9b0b42ab04fcddc13fe5fabde22f41a13e43f62 --- /dev/null +++ b/filters/ridge detection/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88f760b83837021929bccf86aaffefed2f9e5e97c3638346d32238e1027cb7a2 +size 397 diff --git a/filters/ridge detection/deployment/serialized_processing.json b/filters/ridge detection/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..8aca5d73d3ca127cfd319066a70858be9a2a8c9e --- /dev/null +++ b/filters/ridge detection/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "ridge detection", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/ridge detection/deployment/server.zip b/filters/ridge detection/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..45c23e41d165241b8f01a52a0d1f2513cef8c713 --- /dev/null +++ b/filters/ridge detection/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bef0e5a94d7d50c8ac658b9e1a411c81051dba914a1b19f0e2badc53a2f36fdc +size 5020 diff --git a/filters/ridge detection/server.onnx b/filters/ridge detection/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..061cfef1370975519a35672c1d71b83cf5ec5534 --- /dev/null +++ b/filters/ridge detection/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48821745ed7a9b25b5ba8ae0dc3da35739985bf5dd1dac5b3a9c207adbbf1c45 +size 532 diff --git a/filters/ridge_detection/deployment/client.zip b/filters/ridge_detection/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..949f5edba27c63e2807bee639463f21451d4a806 --- /dev/null +++ b/filters/ridge_detection/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d241694b8c01dce2ad8a5ce2dbe12190e40d6912e88d086dbc0e047aba4dfafb +size 397 diff --git a/filters/ridge_detection/deployment/serialized_processing.json b/filters/ridge_detection/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..520aacffe04f376ea907b657e8c418dbca199b19 --- /dev/null +++ b/filters/ridge_detection/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "ridge_detection", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/ridge_detection/deployment/server.zip b/filters/ridge_detection/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..212ee357f22001a8d5bbe617c3d52df523a2e580 --- /dev/null +++ b/filters/ridge_detection/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3605e14d8533e3c57edf30a7da32d4441fcb68228a8ebd028015338b8b5d5f70 +size 4884 diff --git a/filters/ridge_detection/server.onnx b/filters/ridge_detection/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0be0b5bb34fc387c1eeaf92665f46701d9a2d496 --- /dev/null +++ b/filters/ridge_detection/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e05d56c4988abd621aee6dea4efe2dfdaf1d09dfb78bb7bf7b6bb3a00d3e80b +size 532 diff --git a/filters/rotate/deployment/client.zip b/filters/rotate/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..507daffab3c04ed05318a1aad16538d0b290a346 --- /dev/null +++ b/filters/rotate/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f43305830edade90df38b59070c255948810dc9a8c58eda16157e0424b9bffe +size 383 diff --git a/filters/rotate/deployment/serialized_processing.json b/filters/rotate/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..1832ccad346940503fa68c374cdf09e51996952f --- /dev/null +++ b/filters/rotate/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "rotate", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/rotate/deployment/server.zip b/filters/rotate/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..913656b866f19a0ec37d81fbdfc707a02243ec6f --- /dev/null +++ b/filters/rotate/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b382f12dafa436e4e5e4adc0346fa81b52d5fad709a19f1b2cad52001a97c984 +size 5366 diff --git a/filters/rotate/server.onnx b/filters/rotate/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..def14d1ae8389dd1a052aab547240b881813e854 --- /dev/null +++ b/filters/rotate/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa03ea9a684b65c29c2cc0e6ab20f6b6349f35c4bd70921d264e74298a758de1 +size 178 diff --git a/filters/sharpen/deployment/client.zip b/filters/sharpen/deployment/client.zip new file mode 100644 index 0000000000000000000000000000000000000000..972b32d46b53d7f574ea36d8e432b2d91f1c39aa --- /dev/null +++ b/filters/sharpen/deployment/client.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d666044a75f5e7d4642145181ea239de6076f8ae424d971c7139e3467a758793 +size 396 diff --git a/filters/sharpen/deployment/serialized_processing.json b/filters/sharpen/deployment/serialized_processing.json new file mode 100644 index 0000000000000000000000000000000000000000..418780f9ce354df43ce443bf0d0a2faec48c6ec5 --- /dev/null +++ b/filters/sharpen/deployment/serialized_processing.json @@ -0,0 +1 @@ +{"model_filter": "sharpen", "cml_version": "0.6.0-rc0"} \ No newline at end of file diff --git a/filters/sharpen/deployment/server.zip b/filters/sharpen/deployment/server.zip new file mode 100644 index 0000000000000000000000000000000000000000..df2cc7efb3d8cd8e74441bd82c14e064b5aa59b0 --- /dev/null +++ b/filters/sharpen/deployment/server.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6922f5fa0d0a0584636ce755dbd921bc36fca082f2a0facb74669f4b24b67368 +size 8720 diff --git a/filters/sharpen/server.onnx b/filters/sharpen/server.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b5a278fc2f94455bae1b00ebcf32ee9bb9a858e3 --- /dev/null +++ b/filters/sharpen/server.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7958a3c9be1b578486ec1708701340263ce3ad70b7cd3ff281230797f67de0d +size 532 diff --git a/generate_dev_files.py b/generate_dev_files.py new file mode 100644 index 0000000000000000000000000000000000000000..093f9ae01288d9d2a0e9c7d3bed86ce0c707f838 --- /dev/null +++ b/generate_dev_files.py @@ -0,0 +1,40 @@ +"A script to generate all development files necessary for the image filtering demo." + +import shutil +from pathlib import Path + +import numpy as np +import onnx +from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET +from custom_client_server import CustomFHEDev +from filters import Filter + +print("Generating deployment files for all available filters") + +for image_filter in AVAILABLE_FILTERS: + print("Filter:", image_filter, "\n") + + # Create the filter instance + filter = Filter(image_filter) + + image_shape = INPUT_SHAPE + (3,) + + # Compile the filter on the inputset + filter.compile(INPUTSET) + + filter_path = FILTERS_PATH / image_filter + + deployment_path = filter_path / "deployment" + + # Delete the deployment folder and its content if it exist + if deployment_path.is_dir(): + shutil.rmtree(deployment_path) + + # Save the files needed for deployment + fhe_dev_filter = CustomFHEDev(deployment_path, filter) + fhe_dev_filter.save() + + # Save the ONNX model + onnx.save(filter.onnx_model, filter_path / "server.onnx") + +print("Done !") diff --git a/input_examples/arc.jpg b/input_examples/arc.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b2520e557522f4794c4ee5e2e8348a04923a3f3 Binary files /dev/null and b/input_examples/arc.jpg differ diff --git a/input_examples/book.jpg b/input_examples/book.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f487c6ccde701caaa63fea38ff3b85c4436073e1 Binary files /dev/null and b/input_examples/book.jpg differ diff --git a/input_examples/computer.jpg b/input_examples/computer.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b096eaa94001fb42891e2a3dd4fa889d678c05b3 Binary files /dev/null and b/input_examples/computer.jpg differ diff --git a/input_examples/tree.jpg b/input_examples/tree.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d6b197da91551e3e174d4de42a07d344665bc95b Binary files /dev/null and b/input_examples/tree.jpg differ diff --git a/input_examples/zama_math.jpg b/input_examples/zama_math.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9b62acef99ae46f526670bdd33e53d8eccf0ec71 Binary files /dev/null and b/input_examples/zama_math.jpg differ diff --git a/input_examples/zebra.jpg b/input_examples/zebra.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4581306990a19e3e27a766beb55e5386c9604be1 Binary files /dev/null and b/input_examples/zebra.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b05d6d95ad87961bb3f1f30c1bd02dc29b5beff --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +concrete-ml==0.6.0 +gradio==3.11.0 +uvicorn==0.20.0 +fastapi==0.87.0 +jupyter==1.0.0 diff --git a/server.py b/server.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ec566f2f15b9c593f0557463a989d9135bce7a --- /dev/null +++ b/server.py @@ -0,0 +1,97 @@ +"""Server that will listen for GET and POST requests from the client.""" + +import time +from typing import List + +from common import FILTERS_PATH, SERVER_TMP_PATH +from custom_client_server import CustomFHEServer +from fastapi import FastAPI, File, Form, UploadFile +from fastapi.responses import JSONResponse, Response +from pydantic import BaseModel + + +def get_server_file_path(name, user_id, image_filter): + """Get the correct temporary file path for the server. + + Args: + name (str): The desired file name. + user_id (int): The current user's ID. + image_filter (str): The filter chosen by the user + + Returns: + pathlib.Path: The file path. + """ + return SERVER_TMP_PATH / f"{name}_{image_filter}_{user_id}" + + +class FilterRequest(BaseModel): + filter: str + + +# Initialize an instance of FastAPI +app = FastAPI() + +# Define the default route +@app.get("/") +def root(): + return {"message": "Welcome to Your Image FHE Filter Server!"} + + +@app.post("/send_input") +def send_input( + user_id: str = Form(), + filter: str = Form(), + files: List[UploadFile] = File(), +): + encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id) + evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id) + + with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open( + "wb" + ) as evaluation_key: + encrypted_image.write(files[0].file.read()) + evaluation_key.write(files[1].file.read()) + + +@app.post("/run_fhe") +def run_fhe( + user_id: str = Form(), + filter: str = Form(), +): + + encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id) + evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id) + + with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open( + "rb" + ) as evaluation_key_file: + encrypted_image = encrypted_image_file.read() + evaluation_key = evaluation_key_file.read() + + # Load the model + fhe_model = CustomFHEServer(FILTERS_PATH / f"{filter}/deployment") + + # Run the FHE execution + start = time.time() + encrypted_output_image = fhe_model.run(encrypted_image, evaluation_key) + fhe_execution_time = round(time.time() - start, 2) + + encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id) + + with encrypted_output_path.open("wb") as encrypted_output: + encrypted_output.write(encrypted_output_image) + + return JSONResponse(content=fhe_execution_time) + + +@app.post("/get_output") +def get_output( + user_id: str = Form(), + filter: str = Form(), +): + encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id) + + with encrypted_output_path.open("rb") as encrypted_output_file: + encrypted_output = encrypted_output_file.read() + + return Response(encrypted_output)