Roman committed on
Commit
7b32412
•
1 Parent(s): b1501ef

chore: remove CML client/server API by only using Concrete Numpy (CN)

Browse files
app.py CHANGED
@@ -19,7 +19,7 @@ from common import (
19
  REPO_DIR,
20
  SERVER_URL,
21
  )
22
- from custom_client_server import CustomFHEClient
23
 
24
  # Uncomment here to have both the server and client in the same terminal
25
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
@@ -33,11 +33,11 @@ def decrypt_output_with_wrong_key(encrypted_image, filter_name):
33
  filter_path = FILTERS_PATH / f"{filter_name}/deployment"
34
 
35
  # Instantiate the client interface and generate a new private key
36
- wrong_client = CustomFHEClient(filter_path, WRONG_KEYS_PATH)
37
  wrong_client.generate_private_and_evaluation_keys(force=True)
38
 
39
- # Deserialize, decrypt and post-processing the encrypted output using the new private key
40
- output_image = wrong_client.deserialize_decrypt_dequantize(encrypted_image)
41
 
42
  return output_image
43
 
@@ -53,7 +53,7 @@ def shorten_bytes_object(bytes_object, limit=500):
53
  limit (int): The length to consider. Default to 500.
54
 
55
  Returns:
56
- Any: The fitted model.
57
 
58
  """
59
  # Define a shift for better display
@@ -69,9 +69,9 @@ def get_client(user_id, filter_name):
69
  filter_name (str): The filter chosen by the user
70
 
71
  Returns:
72
- CustomFHEClient: The client API.
73
  """
74
- return CustomFHEClient(
75
  FILTERS_PATH / f"{filter_name}/deployment", KEYS_PATH / f"{filter_name}_{user_id}"
76
  )
77
 
@@ -184,11 +184,8 @@ def encrypt(user_id, input_image, filter_name):
184
  # Retrieve the client API
185
  client = get_client(user_id, filter_name)
186
 
187
- # Pre-process the input image as Torch and Numpy don't follow the same shape format
188
- preprocessed_input_image = client.model.pre_processing(input_image)
189
-
190
- # Encrypt and serialize the image
191
- encrypted_image = client.quantize_encrypt_serialize(preprocessed_input_image)
192
 
193
  # Compute the input's size in Megabytes
194
  encrypted_input_size = len(encrypted_image) / 1000000
@@ -341,7 +338,7 @@ def decrypt_output(user_id, filter_name):
341
  client = get_client(user_id, filter_name)
342
 
343
  # Deserialize, decrypt and post-process the encrypted output
344
- output_image = client.deserialize_decrypt_dequantize(encrypted_output_image)
345
 
346
  return output_image, False, False
347
 
 
19
  REPO_DIR,
20
  SERVER_URL,
21
  )
22
+ from client_server_interface import FHEClient
23
 
24
  # Uncomment here to have both the server and client in the same terminal
25
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
 
33
  filter_path = FILTERS_PATH / f"{filter_name}/deployment"
34
 
35
  # Instantiate the client interface and generate a new private key
36
+ wrong_client = FHEClient(filter_path, WRONG_KEYS_PATH)
37
  wrong_client.generate_private_and_evaluation_keys(force=True)
38
 
39
+ # Deserialize, decrypt and post-process the encrypted output using the new private key
40
+ output_image = wrong_client.deserialize_decrypt_post_process(encrypted_image)
41
 
42
  return output_image
43
 
 
53
  limit (int): The length to consider. Default to 500.
54
 
55
  Returns:
56
+ str: Hexadecimal string shorten representation of the input byte object.
57
 
58
  """
59
  # Define a shift for better display
 
69
  filter_name (str): The filter chosen by the user
70
 
71
  Returns:
72
+ FHEClient: The client API.
73
  """
74
+ return FHEClient(
75
  FILTERS_PATH / f"{filter_name}/deployment", KEYS_PATH / f"{filter_name}_{user_id}"
76
  )
77
 
 
184
  # Retrieve the client API
185
  client = get_client(user_id, filter_name)
186
 
187
+ # Pre-process, encrypt and serialize the image
188
+ encrypted_image = client.pre_process_encrypt_serialize(input_image)
 
 
 
189
 
190
  # Compute the input's size in Megabytes
191
  encrypted_input_size = len(encrypted_image) / 1000000
 
338
  client = get_client(user_id, filter_name)
339
 
340
  # Deserialize, decrypt and post-process the encrypted output
341
+ output_image = client.deserialize_decrypt_post_process(encrypted_output_image)
342
 
343
  return output_image, False, False
344
 
client_server_interface.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Client-server interface custom implementation for filter models."
2
+
3
+ import zipfile
4
+ import json
5
+ from filters import Filter
6
+
7
+ import concrete.numpy as cnp
8
+
9
class FHEServer:
    """Server-side interface used to execute a filter's FHE circuit.

    Loads a compiled Concrete Numpy circuit from disk and runs it on
    encrypted inputs sent by a client.
    """

    def __init__(self, path_dir):
        """Load the compiled FHE circuit.

        Args:
            path_dir (Path): The path to the directory where the circuit is saved
                (must contain a "server.zip" file).
        """
        self.path_dir = path_dir

        # Deserialize the compiled FHE circuit from disk
        self.server = cnp.Server.load(self.path_dir / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run the filter on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The encrypted and serialized filter output.
        """
        specs = self.server.client_specs

        # Rebuild the circuit inputs from their serialized forms
        deserialized_encrypted_image = specs.unserialize_public_args(
            serialized_encrypted_image
        )
        deserialized_evaluation_keys = cnp.EvaluationKeys.unserialize(
            serialized_evaluation_keys
        )

        # Execute the filter in FHE
        result = self.server.run(deserialized_encrypted_image, deserialized_evaluation_keys)

        # Serialize the encrypted output image so it can be sent back to the client
        return specs.serialize_public_result(result)
48
+
49
+
50
class FHEDev:
    """Development interface used to export the filter's deployment artifacts."""

    def __init__(self, filter, path_dir):
        """Initialize the FHE interface.

        Args:
            filter (Filter): The filter to use in the FHE interface.
            path_dir (Path): The path to the directory where the artifacts are saved.
        """
        self.filter = filter
        self.path_dir = path_dir

        # Make sure the output directory exists before saving anything into it
        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes "serialized_processing.json", "server.zip" and "client.zip" into
        path_dir, bundling the processing parameters inside the client archive.
        """
        # The circuit is only available once the filter has been compiled
        assert self.filter.fhe_circuit is not None, (
            "The model must be compiled before saving it."
        )

        # Export to json the parameters needed for loading the filter in the
        # other interfaces
        json_path = self.path_dir / "serialized_processing.json"
        with open(json_path, "w", encoding="utf-8") as file:
            json.dump({"filter_name": self.filter.filter_name}, file)

        # Save the circuit for the server
        self.filter.fhe_circuit.server.save(self.path_dir / "server.zip")

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        self.filter.fhe_circuit.client.save(path_circuit_client)

        # Bundle the processing parameters with the client files so FHEClient
        # can rebuild the filter from the archive alone
        with zipfile.ZipFile(path_circuit_client, "a") as zip_file:
            zip_file.write(filename=json_path, arcname="serialized_processing.json")
90
+
91
+
92
class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a Filter."""

    def __init__(self, path_dir, key_dir):
        """Load the client circuit and rebuild the filter from its saved parameters.

        Args:
            path_dir (Path): The path to the directory where the circuit is saved.
            key_dir (Path): The path to the directory where the keys are stored.
        """
        self.path_dir = path_dir
        self.key_dir = key_dir

        # A missing deployment directory is a setup error: fail early and loudly
        assert path_dir.exists(), f"{path_dir} does not exist. Please specify a valid path."

        # Load the client circuit
        self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Load the processing parameters bundled inside the client archive
        with zipfile.ZipFile(self.path_dir / "client.zip") as client_zip:
            with client_zip.open("serialized_processing.json", mode="r") as file:
                serialized_processing = json.load(file)

        # Re-create the filter from its saved name
        self.filter = Filter(serialized_processing["filter_name"])

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Serialize and return the evaluation keys.

        Returns:
            bytes: The serialized evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def pre_process_encrypt_serialize(self, input_image):
        """Pre-process, encrypt and serialize the input image.

        Args:
            input_image (numpy.ndarray): The image to pre-process, encrypt and serialize.

        Returns:
            bytes: The pre-processed, encrypted and serialized image.
        """
        # Apply the filter's pre-processing, then encrypt the result
        encrypted_image = self.client.encrypt(self.filter.pre_processing(input_image))

        # Serialize the encrypted image so it can be sent to the server
        return self.client.specs.serialize_public_args(encrypted_image)

    def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
        """Deserialize, decrypt and post-process the output image.

        Args:
            serialized_encrypted_output_image (bytes): The serialized and encrypted
                output image.

        Returns:
            numpy.ndarray: The decrypted, deserialized and post-processed image.
        """
        # Rebuild the encrypted result from its serialized form
        encrypted_output_image = self.client.specs.unserialize_public_result(
            serialized_encrypted_output_image
        )

        # Decrypt, then apply the filter's post-processing
        output_image = self.client.decrypt(encrypted_output_image)
        return self.filter.post_processing(output_image)
compile.py CHANGED
@@ -3,9 +3,9 @@
3
  import json
4
  import shutil
5
  import onnx
 
6
  from common import AVAILABLE_FILTERS, FILTERS_PATH, KEYS_PATH
7
- from custom_client_server import CustomFHEClient
8
- from concrete.ml.deployment import FHEModelDev
9
 
10
  print("Starting compiling the filters.")
11
 
@@ -16,13 +16,13 @@ for filter_name in AVAILABLE_FILTERS:
16
  deployment_path = FILTERS_PATH / f"{filter_name}/deployment"
17
 
18
  # Retrieve the client associated to the current filter
19
- model = CustomFHEClient(deployment_path, KEYS_PATH).model
20
 
21
- # Load the onnx model
22
- onnx_model = onnx.load(FILTERS_PATH / f"{filter_name}/server.onnx")
23
 
24
- # Compile the model on a representative inputset, using the loaded onnx model
25
- model.compile(onnx_model=onnx_model)
26
 
27
  processing_json_path = deployment_path / "serialized_processing.json"
28
 
@@ -35,7 +35,7 @@ for filter_name in AVAILABLE_FILTERS:
35
  shutil.rmtree(deployment_path)
36
 
37
  # Save the development files needed for deployment
38
- fhe_dev = FHEModelDev(model=model, path_dir=deployment_path)
39
  fhe_dev.save()
40
 
41
  # Write the serialized_processing.json file in the deployment directory
 
3
  import json
4
  import shutil
5
  import onnx
6
+
7
  from common import AVAILABLE_FILTERS, FILTERS_PATH, KEYS_PATH
8
+ from client_server_interface import FHEClient, FHEDev
 
9
 
10
  print("Starting compiling the filters.")
11
 
 
16
  deployment_path = FILTERS_PATH / f"{filter_name}/deployment"
17
 
18
  # Retrieve the client associated to the current filter
19
+ filter = FHEClient(deployment_path, KEYS_PATH).filter
20
 
21
+ # Load the onnx graph
22
+ onnx_graph = onnx.load(FILTERS_PATH / f"{filter_name}/server.onnx")
23
 
24
+ # Compile the filter on a representative inputset, using the loaded onnx graph
25
+ filter.compile(onnx_graph=onnx_graph)
26
 
27
  processing_json_path = deployment_path / "serialized_processing.json"
28
 
 
35
  shutil.rmtree(deployment_path)
36
 
37
  # Save the development files needed for deployment
38
+ fhe_dev = FHEDev(filter, deployment_path)
39
  fhe_dev.save()
40
 
41
  # Write the serialized_processing.json file in the deployment directory
custom_client_server.py DELETED
@@ -1,35 +0,0 @@
1
- "Client-server interface custom implementation for filter models."
2
-
3
- import json
4
- import concrete.numpy as cnp
5
- from filters import Filter
6
-
7
- from concrete.ml.deployment import FHEModelClient
8
- from concrete.ml.version import __version__ as CML_VERSION
9
-
10
-
11
- class CustomFHEClient(FHEModelClient):
12
- """Client interface to encrypt and decrypt FHE data associated to a Filter."""
13
-
14
- def load(self):
15
- """Load the parameters along with the FHE specs."""
16
-
17
- # Load the client
18
- self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)
19
-
20
- # Load the filter's parameters from the json file
21
- with (self.path_dir / "serialized_processing.json").open("r", encoding="utf-8") as f:
22
- serialized_processing = json.load(f)
23
-
24
- # Make sure the version in serialized_model is the same as CML_VERSION
25
- assert serialized_processing["cml_version"] == CML_VERSION, (
26
- f"The version of Concrete ML library ({CML_VERSION}) is different "
27
- f"from the one used to save the model ({serialized_processing['cml_version']}). "
28
- "Please update to the proper Concrete ML version.",
29
- )
30
-
31
- # Initialize the filter model using its filter name
32
- filter_name = serialized_processing["model_post_processing_params"]["filter_name"]
33
- self.model = Filter(filter_name)
34
-
35
- return self.model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
filters.py CHANGED
@@ -2,17 +2,16 @@
2
 
3
  import numpy as np
4
  import torch
5
- from common import AVAILABLE_FILTERS, INPUT_SHAPE
6
- from concrete.numpy.compilation.compiler import Compiler
7
  from torch import nn
 
8
 
9
- from concrete.ml.common.debugging.custom_assert import assert_true
10
  from concrete.ml.common.utils import generate_proxy_function
11
  from concrete.ml.onnx.convert import get_equivalent_numpy_forward
12
  from concrete.ml.torch.numpy_module import NumpyModule
13
 
14
 
15
- class _TorchIdentity(nn.Module):
16
  """Torch identity model."""
17
 
18
  def forward(self, x):
@@ -27,7 +26,7 @@ class _TorchIdentity(nn.Module):
27
  return x
28
 
29
 
30
- class _TorchInverted(nn.Module):
31
  """Torch inverted model."""
32
 
33
  def forward(self, x):
@@ -42,7 +41,7 @@ class _TorchInverted(nn.Module):
42
  return 255 - x
43
 
44
 
45
- class _TorchRotate(nn.Module):
46
  """Torch rotated model."""
47
 
48
  def forward(self, x):
@@ -57,8 +56,8 @@ class _TorchRotate(nn.Module):
57
  return x.transpose(2, 3)
58
 
59
 
60
- class _TorchConv2D(nn.Module):
61
- """Torch model for applying a single 2D convolution operator on images."""
62
 
63
  def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
64
  """Initialize the filter.
@@ -74,7 +73,7 @@ class _TorchConv2D(nn.Module):
74
  self.threshold = threshold
75
 
76
  def forward(self, x):
77
- """Forward pass for filtering the image using a 2D kernel.
78
 
79
  Args:
80
  x (torch.Tensor): The input image.
@@ -133,20 +132,13 @@ class Filter:
133
  filter_name (str): The filter to consider.
134
  """
135
 
136
- assert_true(
137
- filter_name in AVAILABLE_FILTERS,
138
  f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
139
  f"but got {filter_name}",
140
  )
141
 
142
- # Define attributes needed in order to prevent the Concrete-ML client-server interface
143
- # from breaking
144
- self.post_processing_params = {"filter_name": filter_name}
145
- self.input_quantizers = []
146
- self.output_quantizers = []
147
-
148
  # Define attributes associated to the filter
149
- self.filter = filter_name
150
  self.onnx_model = None
151
  self.fhe_circuit = None
152
  self.divide = None
@@ -154,13 +146,13 @@ class Filter:
154
 
155
  # Instantiate the torch module associated to the given filter name
156
  if filter_name == "identity":
157
- self.torch_model = _TorchIdentity()
158
 
159
  elif filter_name == "inverted":
160
- self.torch_model = _TorchInverted()
161
 
162
  elif filter_name == "rotate":
163
- self.torch_model = _TorchRotate()
164
 
165
  elif filter_name == "black and white":
166
  # Define the grayscale weights (RGB order)
@@ -173,7 +165,7 @@ class Filter:
173
  # post-processing in order to retrieve the correct result
174
  kernel = [299, 587, 114]
175
 
176
- self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
177
 
178
  # Define the value used when for dividing the output values in post-processing
179
  self.divide = 1000
@@ -185,7 +177,7 @@ class Filter:
185
  elif filter_name == "blur":
186
  kernel = np.ones((3, 3))
187
 
188
- self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
189
 
190
  # Define the value used when for dividing the output values in post-processing
191
  self.divide = 9
@@ -197,7 +189,7 @@ class Filter:
197
  [0, -1, 0],
198
  ]
199
 
200
- self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
201
 
202
  elif filter_name == "ridge detection":
203
  kernel = [
@@ -208,18 +200,18 @@ class Filter:
208
 
209
  # Additionally to the convolution operator, the filter will subtract a given threshold
210
  # value to the result in order to better display the ridges
211
- self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1, threshold=900)
212
 
213
  # Indicate that the out_channels will need to be repeated, as Gradio requires all
214
  # images to have a RGB format, even for grayscaled ones. Ridge detection images are
215
  # ususally displayed as such
216
  self.repeat_out_channels = True
217
 
218
- def compile(self, onnx_model=None):
219
  """Compile the model on a representative inputset.
220
 
221
  Args:
222
- onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
223
  generated automatically using a NumpyModule. Default to None.
224
  """
225
  # Generate a random representative set of images used for compilation, following Torch's
@@ -232,17 +224,17 @@ class Filter:
232
  )
233
 
234
  # If no onnx model was given, generate a new one.
235
- if onnx_model is None:
236
  numpy_module = NumpyModule(
237
  self.torch_model,
238
  dummy_input=torch.from_numpy(inputset[0]),
239
  )
240
 
241
- onnx_model = numpy_module.onnx_model
242
 
243
  # Get the proxy function and parameter mappings for initializing the compiler
244
- self.onnx_model = onnx_model
245
- numpy_filter = get_equivalent_numpy_forward(onnx_model)
246
 
247
  numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])
248
 
@@ -256,20 +248,6 @@ class Filter:
256
 
257
  return self.fhe_circuit
258
 
259
- def quantize_input(self, input_image):
260
- """Quantize the input.
261
-
262
- Images are already quantized in this case, however we need to define this method in order
263
- to prevent the Concrete-ML client-server interface from breaking.
264
-
265
- Args:
266
- input_image (np.ndarray): The input to quantize.
267
-
268
- Returns:
269
- np.ndarray: The quantized input.
270
- """
271
- return input_image
272
-
273
  def pre_processing(self, input_image):
274
  """Apply pre-processing to the encrypted input images.
275
 
 
2
 
3
  import numpy as np
4
  import torch
 
 
5
  from torch import nn
6
+ from common import AVAILABLE_FILTERS, INPUT_SHAPE
7
 
8
+ from concrete.numpy.compilation.compiler import Compiler
9
  from concrete.ml.common.utils import generate_proxy_function
10
  from concrete.ml.onnx.convert import get_equivalent_numpy_forward
11
  from concrete.ml.torch.numpy_module import NumpyModule
12
 
13
 
14
+ class TorchIdentity(nn.Module):
15
  """Torch identity model."""
16
 
17
  def forward(self, x):
 
26
  return x
27
 
28
 
29
+ class TorchInverted(nn.Module):
30
  """Torch inverted model."""
31
 
32
  def forward(self, x):
 
41
  return 255 - x
42
 
43
 
44
+ class TorchRotate(nn.Module):
45
  """Torch rotated model."""
46
 
47
  def forward(self, x):
 
56
  return x.transpose(2, 3)
57
 
58
 
59
+ class TorchConv(nn.Module):
60
+ """Torch model for applying convolution operators on images."""
61
 
62
  def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
63
  """Initialize the filter.
 
73
  self.threshold = threshold
74
 
75
  def forward(self, x):
76
+ """Forward pass for filtering the image using a 1D or 2D kernel.
77
 
78
  Args:
79
  x (torch.Tensor): The input image.
 
132
  filter_name (str): The filter to consider.
133
  """
134
 
135
+ assert filter_name in AVAILABLE_FILTERS, (
 
136
  f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
137
  f"but got {filter_name}",
138
  )
139
 
 
 
 
 
 
 
140
  # Define attributes associated to the filter
141
+ self.filter_name = filter_name
142
  self.onnx_model = None
143
  self.fhe_circuit = None
144
  self.divide = None
 
146
 
147
  # Instantiate the torch module associated to the given filter name
148
  if filter_name == "identity":
149
+ self.torch_model = TorchIdentity()
150
 
151
  elif filter_name == "inverted":
152
+ self.torch_model = TorchInverted()
153
 
154
  elif filter_name == "rotate":
155
+ self.torch_model = TorchRotate()
156
 
157
  elif filter_name == "black and white":
158
  # Define the grayscale weights (RGB order)
 
165
  # post-processing in order to retrieve the correct result
166
  kernel = [299, 587, 114]
167
 
168
+ self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1)
169
 
170
  # Define the value used when for dividing the output values in post-processing
171
  self.divide = 1000
 
177
  elif filter_name == "blur":
178
  kernel = np.ones((3, 3))
179
 
180
+ self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)
181
 
182
  # Define the value used when for dividing the output values in post-processing
183
  self.divide = 9
 
189
  [0, -1, 0],
190
  ]
191
 
192
+ self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)
193
 
194
  elif filter_name == "ridge detection":
195
  kernel = [
 
200
 
201
  # Additionally to the convolution operator, the filter will subtract a given threshold
202
  # value to the result in order to better display the ridges
203
+ self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1, threshold=900)
204
 
205
  # Indicate that the out_channels will need to be repeated, as Gradio requires all
206
  # images to have a RGB format, even for grayscaled ones. Ridge detection images are
207
  # ususally displayed as such
208
  self.repeat_out_channels = True
209
 
210
+ def compile(self, onnx_graph=None):
211
  """Compile the model on a representative inputset.
212
 
213
  Args:
214
+ onnx_graph (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
215
  generated automatically using a NumpyModule. Default to None.
216
  """
217
  # Generate a random representative set of images used for compilation, following Torch's
 
224
  )
225
 
226
  # If no onnx model was given, generate a new one.
227
+ if onnx_graph is None:
228
  numpy_module = NumpyModule(
229
  self.torch_model,
230
  dummy_input=torch.from_numpy(inputset[0]),
231
  )
232
 
233
+ onnx_graph = numpy_module.onnx_model
234
 
235
  # Get the proxy function and parameter mappings for initializing the compiler
236
+ self.onnx_graph = onnx_graph
237
+ numpy_filter = get_equivalent_numpy_forward(onnx_graph)
238
 
239
  numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])
240
 
 
248
 
249
  return self.fhe_circuit
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  def pre_processing(self, input_image):
252
  """Apply pre-processing to the encrypted input images.
253
 
filters/black and white/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c6a88c9717d1ec81035e715ad84803b2351756c1a1a6fb51786e2e07b8cbe84
3
- size 388
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa7aa78811be0810d523a8d94a1aa27e24e40d29eced0395fd3bc743568b62b8
3
+ size 550
filters/black and white/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "black and white"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "black and white"}
filters/black and white/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3efb2af650a4b4690a8096048b084a03f9a505985fc8dee4da7d3367ea040918
3
  size 4364
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fe0885c010b076062a9b5887d0ee5ebf07f7b879633324af1b14e58a2fefeec
3
  size 4364
filters/blur/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f724173d51c70053038bd40ce911ce4ef9ea50a70b077129b86abb482ad4a21e
3
- size 391
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fafcd6cd32109e17bad3ee5945b54c64525f3db5d3e893a5618cfb765a8748e
3
+ size 542
filters/blur/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "blur"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "blur"}
filters/blur/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1193734d14b02195075fc402e3d84e11b8b7216a83ceff9a0becb16f1b3fbcf0
3
  size 7263
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49b6c8a391e67ba424f156ef6049175e5c49b13d5b92052fddf05214741175c6
3
  size 7263
filters/identity/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e396e33163faf6dbf2a8eca318908efa92dea5d0d8c24a46439a925497543431
3
- size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19a7d7831af7f4a7a55a734a12e772ec41058502138e15925e229c89fcd8b195
3
+ size 533
filters/identity/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "identity"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "identity"}
filters/identity/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:da52e793a997ded3b0c383f246da31d51317e2461ff1955f9f01014258272f9b
3
  size 2559
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2891ffa3e35d14d40a79b533fb331d557c82b4a8fe20568aa095aa7a22164a9
3
  size 2559
filters/inverted/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e396e33163faf6dbf2a8eca318908efa92dea5d0d8c24a46439a925497543431
3
- size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67169abe3c33f7c7f377cd7e3b17031dd43054432a9d1b39f5469417156b5f2d
3
+ size 533
filters/inverted/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "inverted"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "inverted"}
filters/inverted/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fe95edd2998cee4ff7e40fde889f7a85bbf69218a4f1b517565de79d82517c4f
3
  size 4179
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:781488531b0049ecd05d3cc0e0eb95a9350553848bf218e726e97dce2b3ebd42
3
  size 4179
filters/ridge detection/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c1bf9931bbf568d5b74fd3e1bab8fcd48780c88b2f36289b26976cb1ebf4c665
3
- size 397
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05b54c87d88297316aeb864d7292c9a4c930d486e7d0b7232bdf77e9b76a7692
3
+ size 559
filters/ridge detection/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "ridge detection"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "ridge detection"}
filters/ridge detection/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a31acdba1d94ec7fec833fc8e0b0de7a6b345c9dfef5d0b1a7cbdf30613fdc44
3
  size 5043
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:238980dd76c8155164b84d0096d11a8cbba25c933f4335fc7369e77f2328bd26
3
  size 5043
filters/rotate/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e396e33163faf6dbf2a8eca318908efa92dea5d0d8c24a46439a925497543431
3
- size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7a3c2ae45ef9887682e3e89d4138b2ce74b8e560b858f3adc0461f98f223f3f
3
+ size 531
filters/rotate/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "rotate"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "rotate"}
filters/rotate/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fa91ba0d7021fcc6237cd628b6f03151b05625f377149d3d0eedd4a124407646
3
  size 4431
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a92a49387f05f4548cb4910506e66a6a2fa591b8f27818934d4283c8c2981a99
3
  size 4431
filters/sharpen/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c5dee467fb63804731e998a11323a460d2fc8a9f08f4480d9d9c6deb4431447
3
- size 396
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cf53584e83a91cb975e9f078ef63e777f11453582b664f65685b3a6da89f17e
3
+ size 550
filters/sharpen/deployment/serialized_processing.json CHANGED
@@ -1 +1 @@
1
- {"model_type": "Filter", "model_post_processing_params": {"filter_name": "sharpen"}, "input_quantizers": [], "output_quantizers": [], "cml_version": "0.6.1"}
 
1
+ {"filter_name": "sharpen"}
filters/sharpen/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:87285146e97b8787261a7aa15db77819ac9e10a8c165708792db682d6a5072c7
3
  size 7311
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4710fbe92afdd4f8f6beff7eca302a46ee4fde5b85fe8aa7d6ab832080ae5a2e
3
  size 7311
generate_dev_files.py CHANGED
@@ -4,7 +4,7 @@ import shutil
4
  import onnx
5
  from common import AVAILABLE_FILTERS, FILTERS_PATH
6
  from filters import Filter
7
- from concrete.ml.deployment import FHEModelDev
8
 
9
  print("Generating deployment files for all available filters")
10
 
@@ -28,10 +28,10 @@ for filter_name in AVAILABLE_FILTERS:
28
  shutil.rmtree(deployment_path)
29
 
30
  # Save the files needed for deployment
31
- fhe_dev_filter = FHEModelDev(deployment_path, filter)
32
  fhe_dev_filter.save()
33
 
34
  # Save the ONNX model
35
- onnx.save(filter.onnx_model, filter_path / "server.onnx")
36
 
37
  print("Done !")
 
4
  import onnx
5
  from common import AVAILABLE_FILTERS, FILTERS_PATH
6
  from filters import Filter
7
+ from client_server_interface import FHEDev
8
 
9
  print("Generating deployment files for all available filters")
10
 
 
28
  shutil.rmtree(deployment_path)
29
 
30
  # Save the files needed for deployment
31
+ fhe_dev_filter = FHEDev(filter, deployment_path)
32
  fhe_dev_filter.save()
33
 
34
  # Save the ONNX model
35
+ onnx.save(filter.onnx_graph, filter_path / "server.onnx")
36
 
37
  print("Done !")
server.py CHANGED
@@ -2,12 +2,11 @@
2
 
3
  import time
4
  from typing import List
5
-
6
- from common import FILTERS_PATH, SERVER_TMP_PATH
7
  from fastapi import FastAPI, File, Form, UploadFile
8
  from fastapi.responses import JSONResponse, Response
9
- from pydantic import BaseModel
10
- from concrete.ml.deployment import FHEModelServer
 
11
 
12
 
13
  def get_server_file_path(name, user_id, filter_name):
@@ -24,10 +23,6 @@ def get_server_file_path(name, user_id, filter_name):
24
  return SERVER_TMP_PATH / f"{name}_{filter_name}_{user_id}"
25
 
26
 
27
- class FilterRequest(BaseModel):
28
- filter: str
29
-
30
-
31
  # Initialize an instance of FastAPI
32
  app = FastAPI()
33
 
@@ -74,7 +69,7 @@ def run_fhe(
74
  evaluation_key = evaluation_key_file.read()
75
 
76
  # Load the FHE server
77
- fhe_server = FHEModelServer(FILTERS_PATH / f"{filter}/deployment")
78
 
79
  # Run the FHE execution
80
  start = time.time()
 
2
 
3
  import time
4
  from typing import List
 
 
5
  from fastapi import FastAPI, File, Form, UploadFile
6
  from fastapi.responses import JSONResponse, Response
7
+
8
+ from common import FILTERS_PATH, SERVER_TMP_PATH
9
+ from client_server_interface import FHEServer
10
 
11
 
12
  def get_server_file_path(name, user_id, filter_name):
 
23
  return SERVER_TMP_PATH / f"{name}_{filter_name}_{user_id}"
24
 
25
 
 
 
 
 
26
  # Initialize an instance of FastAPI
27
  app = FastAPI()
28
 
 
69
  evaluation_key = evaluation_key_file.read()
70
 
71
  # Load the FHE server
72
+ fhe_server = FHEServer(FILTERS_PATH / f"{filter}/deployment")
73
 
74
  # Run the FHE execution
75
  start = time.time()