kcelia commited on
Commit
3d845fb
1 Parent(s): 254c61d

chore: update

Browse files
Files changed (1) hide show
  1. app.py +227 -141
app.py CHANGED
@@ -1,5 +1,6 @@
1
- import pickle as pkl
2
  import shutil
 
3
  from pathlib import Path
4
  from time import time
5
  from typing import List, Tuple, Union
@@ -7,33 +8,35 @@ from typing import List, Tuple, Union
7
  import gradio as gr
8
  import numpy as np
9
  import pandas as pd
10
- from sklearn import metrics, preprocessing
11
- from sklearn.ensemble import RandomForestClassifier as SklearnRandomForestClassifier
12
- from sklearn.model_selection import train_test_split
13
 
14
- from concrete.ml.common.serialization.loaders import load, loads
15
  from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
16
  from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
17
 
18
- path_to_model = Path("./client_folder").resolve()
19
-
20
- import subprocess
21
-
22
- from preprocessing import ( # pylint: disable=wrong-import-position, no-name-in-module
23
- map_prediction,
24
- pretty_print,
25
- )
26
- from symptoms_categories import SYMPTOMS_LIST
27
 
28
- ENCRYPTED_DATA_BROWSER_LIMIT = 500
29
- # This repository's directory
30
  REPO_DIR = Path(__file__).parent
 
 
 
 
31
 
32
- print(f"{REPO_DIR=}")
33
  # subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
34
  # time.sleep(3)
35
 
36
 
 
 
 
 
 
 
 
 
 
37
  def load_data():
38
  # Load data
39
  df_train = pd.read_csv("./data/Training_preprocessed.csv")
@@ -61,75 +64,8 @@ def load_model(X_train, y_train):
61
  return classifier, circuit
62
 
63
 
64
- def key_gen():
65
-
66
- # Key serialization
67
- user_id = np.random.randint(0, 2**32)
68
-
69
- client = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
70
- client.load()
71
-
72
- # The client first need to create the private and evaluation keys.
73
-
74
- client.generate_private_and_evaluation_keys()
75
-
76
- # Get the serialized evaluation keys
77
- serialized_evaluation_keys = client.get_serialized_evaluation_keys()
78
- assert isinstance(serialized_evaluation_keys, bytes)
79
-
80
- np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
81
-
82
- serialized_evaluation_keys_shorten = list(serialized_evaluation_keys)[:200]
83
- serialized_evaluation_keys_shorten_hex = "".join(
84
- f"{i:02x}" for i in serialized_evaluation_keys_shorten
85
- )
86
- # Evaluation keys can be quite large files but only have to be shared once with the server.
87
-
88
- # Check the size of the evaluation keys (in MB)
89
- return [
90
- serialized_evaluation_keys_shorten_hex,
91
- user_id,
92
- f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
93
- ]
94
-
95
-
96
- def encode_quantize_encrypt(user_symptoms, user_id):
97
- # check if the key has been generated
98
- client = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
99
- client.load()
100
-
101
- user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
102
-
103
- quant_user_symptoms = client.model.quantize_input(user_symptoms)
104
- encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
105
-
106
- # print(client.model.predict(vect_x, fhe="simulate"), client.model.predict(vect_x, fhe="execute"))
107
- # pred_s = client.model.fhe_circuit.simulate(quant_vect)
108
- # pred_fhe = client.model.fhe_circuit.encrypt_run_decrypt(quant_vect) #
109
- # non alpha -> \X1124, base64 ou en exa
110
-
111
- # Compute size
112
-
113
- np.save(f".fhe_keys/{user_id}/encrypted_quant_vect.npy", encrypted_quantized_user_symptoms)
114
-
115
- encrypted_quantized_encoding_shorten = list(encrypted_quantized_user_symptoms)[:200]
116
- encrypted_quantized_encoding_shorten_hex = "".join(
117
- f"{i:02x}" for i in encrypted_quantized_encoding_shorten
118
- )
119
-
120
- return user_symptoms, quant_user_symptoms, encrypted_quantized_encoding_shorten_hex
121
-
122
-
123
- def decrypt_prediction(encrypted_quantized_vect, user_id):
124
- fhe_api = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
125
- fhe_api.load()
126
- fhe_api.generate_private_and_evaluation_keys(force=False)
127
- predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_quantized_vect)
128
- return predictions
129
-
130
-
131
  def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array:
132
- symptoms_vector = {key: 0 for key in valid_columns}
133
 
134
  for symptom_box in user_symptoms:
135
  for pretty_symptom in symptom_box:
@@ -148,7 +84,7 @@ def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array:
148
  return user_symptoms_vect
149
 
150
 
151
- def get_user_vect_symptoms_from_default_disease(disease):
152
 
153
  user_symptom_vector = df_test[df_test["prognosis"] == disease].iloc[0].values
154
 
@@ -165,45 +101,40 @@ def get_user_symptoms_from_default_disease(disease):
165
  return pretty_print(columns_with_1)
166
 
167
 
168
- def get_user_symptoms_vector_btn(selected_default_disease, *selected_symptoms):
169
-
170
- if any(lst for lst in selected_symptoms if lst) and (
171
- selected_default_disease is not None and len(selected_default_disease) > 0
172
- ):
173
- # If the user has already selected a disease and added more symptoms, raise an error
174
- if set(pretty_print(selected_symptoms)) - set(
175
- get_user_symptoms_from_default_disease(selected_default_disease)
176
- ):
177
- return {
178
- user_vector_textbox: gr.update(value="An error occurs"),
179
- error_box: gr.update(
180
- visible=True, value="Enter a default disease or select your own symptoms"
181
- ),
182
- }
183
- # If the user has not selected a default disease or symptoms, an error is raised.
184
- if not any(lst for lst in selected_symptoms if lst) and (
185
- selected_default_disease is None
186
- or (selected_default_disease is not None and len(selected_default_disease) < 1)
187
  ):
188
  return {
189
- user_vector_textbox: gr.update(value="An error occurs"),
190
- error_box: gr.update(
191
  visible=True, value="Enter a default disease or select your own symptoms"
192
  ),
193
  }
194
  # Case 1: The user has checked his own symptoms
195
  if any(lst for lst in selected_symptoms if lst):
196
  return {
 
197
  user_vector_textbox: get_user_vect_symptoms_from_checkboxgroup(*selected_symptoms),
198
  }
199
 
200
  # Case 2: The user has selected a default disease
201
  if selected_default_disease is not None and len(selected_default_disease) > 0:
202
  return {
203
- user_vector_textbox: get_user_vect_symptoms_from_default_disease(
204
- selected_default_disease
205
- ),
206
- error_box: gr.update(visible=False),
207
  **{
208
  box: get_user_symptoms_from_default_disease(selected_default_disease)
209
  for box in check_boxes
@@ -211,24 +142,166 @@ def get_user_symptoms_vector_btn(selected_default_disease, *selected_symptoms):
211
  }
212
 
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  def clear_all_btn():
215
  return {
 
216
  user_id_textbox: None,
217
  eval_key_textbox: None,
218
- eval_key_len_textbox: None,
219
  user_vector_textbox: None,
220
- box_default: None,
221
- error_box: gr.update(visible=False),
 
 
 
222
  **{box: None for box in check_boxes},
223
  }
224
 
225
 
226
  if __name__ == "__main__":
227
  print("Starting demo ...")
 
228
 
229
  (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
230
 
231
- valid_columns = X_train.columns.to_list()
232
 
233
  # Load the model
234
  with open("ConcreteXGBoostClassifier.pkl", "r", encoding="utf-8") as file:
@@ -285,6 +358,8 @@ if __name__ == "__main__":
285
  )
286
  check_boxes.append(check_box)
287
 
 
 
288
  # User symptom vector
289
  with gr.Row():
290
  user_vector_textbox = gr.Textbox(
@@ -292,7 +367,6 @@ if __name__ == "__main__":
292
  interactive=False,
293
  max_lines=100,
294
  )
295
- error_box = gr.Textbox(label="Error", visible=False)
296
 
297
  with gr.Row():
298
  # Submit botton
@@ -300,20 +374,22 @@ if __name__ == "__main__":
300
  submit_button = gr.Button("Submit")
301
  # Clear botton
302
  with gr.Column():
303
- clear_button = gr.Button("Clear", style="background-color: yellow;")
304
 
305
  # Click submit botton
306
 
307
  submit_button.click(
308
- fn=get_user_symptoms_vector_btn,
309
  inputs=[box_default, *check_boxes],
310
- outputs=[user_vector_textbox, error_box, *check_boxes],
311
  )
312
 
313
  gr.Markdown("# Step 2: Generate the keys")
314
  gr.Markdown("Client side")
315
 
316
- gen_key = gr.Button("Generate the keys and send public part to server")
 
 
317
 
318
  with gr.Row():
319
  # User ID
@@ -338,25 +414,18 @@ if __name__ == "__main__":
338
  interactive=False,
339
  )
340
 
341
- gen_key.click(key_gen, outputs=[eval_key_textbox, user_id_textbox, eval_key_len_textbox])
342
-
343
- clear_button.click(
344
- clear_all_btn,
345
- outputs=[
346
- user_id_textbox,
347
- user_vector_textbox,
348
- eval_key_textbox,
349
- eval_key_len_textbox,
350
- box_default,
351
- error_box,
352
- *check_boxes,
353
- ],
354
  )
355
 
356
  gr.Markdown("# Step 3: Encode the message with the private key")
357
  gr.Markdown("Client side")
358
 
359
- encode_msg = gr.Button("Generate the keys and send public part to server")
 
 
360
 
361
  with gr.Row():
362
 
@@ -377,10 +446,10 @@ if __name__ == "__main__":
377
  label="Encrypted vector:", max_lines=4, interactive=False
378
  )
379
 
380
- encode_msg.click(
381
- encode_quantize_encrypt,
382
  inputs=[user_vector_textbox, user_id_textbox],
383
- outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox],
384
  )
385
 
386
  gr.Markdown("# Step 4: Run the FHE evaluation")
@@ -396,10 +465,27 @@ if __name__ == "__main__":
396
  label="Encrypted vector:", max_lines=4, interactive=False
397
  )
398
 
399
- decrypt_target_botton.click(
400
- decrypt_prediction,
401
- inputs=[encrypted_vect_textbox, user_id_textbox],
402
- outputs=[decrypt_target_textbox],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  )
404
 
405
  demo.launch()
 
1
+ import os
2
  import shutil
3
+ import subprocess
4
  from pathlib import Path
5
  from time import time
6
  from typing import List, Tuple, Union
 
8
  import gradio as gr
9
  import numpy as np
10
  import pandas as pd
11
+ from preprocessing import pretty_print
12
+ from symptoms_categories import SYMPTOMS_LIST
 
13
 
14
+ from concrete.ml.common.serialization.loaders import load
15
  from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
16
  from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
17
 
18
+ INPUT_BROWSER_LIMIT = 635
 
 
 
 
 
 
 
 
19
 
20
+ # This repository's main necessary folders
 
21
  REPO_DIR = Path(__file__).parent
22
+ MODEL_PATH = REPO_DIR / "client_folder"
23
+ KEYS_PATH = REPO_DIR / ".fhe_keys"
24
+ CLIENT_PATH = MODEL_PATH / "client.zip"
25
+ SERVER_PATH = MODEL_PATH / "server.zip"
26
 
 
27
  # subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
28
  # time.sleep(3)
29
 
30
 
31
+ def clean_directory():
32
+ target_dir = ".fhe_keys"
33
+ if os.path.exists(target_dir) and os.path.isdir(target_dir):
34
+ shutil.rmtree(target_dir)
35
+ print("The .fhe_keys directory and its contents have been successfully removed.")
36
+ else:
37
+ print("The .keys directory does not exist.")
38
+
39
+
40
  def load_data():
41
  # Load data
42
  df_train = pd.read_csv("./data/Training_preprocessed.csv")
 
64
  return classifier, circuit
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array:
68
+ symptoms_vector = {key: 0 for key in VALID_COLUMNS}
69
 
70
  for symptom_box in user_symptoms:
71
  for pretty_symptom in symptom_box:
 
84
  return user_symptoms_vect
85
 
86
 
87
+ def get_user_vector_from_default_disease(disease):
88
 
89
  user_symptom_vector = df_test[df_test["prognosis"] == disease].iloc[0].values
90
 
 
101
  return pretty_print(columns_with_1)
102
 
103
 
104
+ def get_user_symptoms_vector_fn(selected_default_disease, *selected_symptoms):
105
+
106
+ # Display an error box, if:
107
+ # 1. The user has already selected a default disease and added more symptoms, or
108
+ # 2. The the user has not selected a default disease or symptoms
109
+ if (
110
+ any(lst for lst in selected_symptoms if lst)
111
+ and (selected_default_disease is not None and len(selected_default_disease) > 0)
112
+ and set(pretty_print(selected_symptoms))
113
+ - set(get_user_symptoms_from_default_disease(selected_default_disease))
114
+ ) or (
115
+ not any(lst for lst in selected_symptoms if lst)
116
+ and (
117
+ selected_default_disease is None
118
+ or (selected_default_disease is not None and len(selected_default_disease) < 1)
119
+ )
 
 
 
120
  ):
121
  return {
122
+ error_box_1: gr.update(
 
123
  visible=True, value="Enter a default disease or select your own symptoms"
124
  ),
125
  }
126
  # Case 1: The user has checked his own symptoms
127
  if any(lst for lst in selected_symptoms if lst):
128
  return {
129
+ error_box_1: gr.update(visible=False),
130
  user_vector_textbox: get_user_vect_symptoms_from_checkboxgroup(*selected_symptoms),
131
  }
132
 
133
  # Case 2: The user has selected a default disease
134
  if selected_default_disease is not None and len(selected_default_disease) > 0:
135
  return {
136
+ user_vector_textbox: get_user_vector_from_default_disease(selected_default_disease),
137
+ error_box_1: gr.update(visible=False),
 
 
138
  **{
139
  box: get_user_symptoms_from_default_disease(selected_default_disease)
140
  for box in check_boxes
 
142
  }
143
 
144
 
145
+ def key_gen_fn(user_symptoms):
146
+
147
+ print("Cleaning directory ...")
148
+ clean_directory()
149
+
150
+ if user_symptoms is None or (user_symptoms is not None and len(user_symptoms) < 1):
151
+ print("Please submit your symptoms first")
152
+ return {
153
+ error_box_2: gr.update(visible=True, value="Please submit your symptoms first"),
154
+ }
155
+
156
+ # Key serialization
157
+ user_id = np.random.randint(0, 2**32)
158
+
159
+ client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}")
160
+ client.load()
161
+
162
+ # The client first need to create the private and evaluation keys.
163
+
164
+ client.generate_private_and_evaluation_keys()
165
+
166
+ # Get the serialized evaluation keys
167
+ serialized_evaluation_keys = client.get_serialized_evaluation_keys()
168
+ assert isinstance(serialized_evaluation_keys, bytes)
169
+
170
+ # np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
171
+ evaluation_key_path = KEYS_PATH / f"{user_id}/evaluation_key"
172
+ with evaluation_key_path.open("wb") as evaluation_key_file:
173
+ evaluation_key_file.write(serialized_evaluation_keys)
174
+
175
+ serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
176
+
177
+ return {
178
+ error_box_2: gr.update(visible=False),
179
+ eval_key_textbox: serialized_evaluation_keys_shorten_hex,
180
+ user_id_textbox: user_id,
181
+ eval_key_len_textbox: f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
182
+ }
183
+
184
+
185
+ def encrypt_fn(user_symptoms, user_id):
186
+
187
+ if not user_symptoms or not user_symptoms:
188
+ return {
189
+ error_box_3: gr.update(
190
+ visible=True, value="Please ensure that the evaluation key has been generated!"
191
+ )
192
+ }
193
+
194
+ # Retrieve the client API
195
+
196
+ client = FHEModelClient(path_dir=MODEL_PATH, key_dir=KEYS_PATH / f"{user_id}")
197
+ client.load()
198
+
199
+ user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
200
+
201
+ quant_user_symptoms = client.model.quantize_input(user_symptoms)
202
+ encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
203
+
204
+ encrypted_input_path = KEYS_PATH / f"{user_id}/encrypted_symptoms"
205
+
206
+ with encrypted_input_path.open("wb") as f:
207
+ f.write(encrypted_quantized_user_symptoms)
208
+
209
+ # print(client.model.predict(vect_x, fhe="simulate"), client.model.predict(vect_x, fhe="execute"))
210
+ # pred_s = client.model.fhe_circuit.simulate(quant_vect)
211
+ # pred_fhe = client.model.fhe_circuit.encrypt_run_decrypt(quant_vect) #
212
+ # non alpha -> \X1124, base64 ou en exa
213
+
214
+ # Compute size
215
+
216
+ # np.save(f".fhe_keys/{user_id}/encrypted_quant_vect.npy", encrypted_quantized_user_symptoms)
217
+
218
+ encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[
219
+ :INPUT_BROWSER_LIMIT
220
+ ]
221
+
222
+ return {
223
+ error_box_3: gr.update(visible=False),
224
+ vect_textbox: user_symptoms,
225
+ quant_vect_textbox: quant_user_symptoms,
226
+ encrypted_vect_textbox: encrypted_quantized_user_symptoms_shorten_hex,
227
+ }
228
+
229
+
230
+ # def send_input(user_id, user_symptoms):
231
+ # """Send the encrypted input image as well as the evaluation key to the server.
232
+
233
+ # Args:
234
+ # user_id (int): The current user's ID.
235
+ # filter_name (str): The current filter to consider.
236
+ # """
237
+ # # Get the evaluation key path
238
+
239
+
240
+ # evaluation_key_path = get_client_file_path("evaluation_key", user_id, filter_name)
241
+
242
+ # if user_id == "" or not evaluation_key_path.is_file():
243
+ # raise gr.Error("Please generate the private key first.")
244
+
245
+ # encrypted_input_path = get_client_file_path("encrypted_image", user_id, filter_name)
246
+ # encrypted_symptoms_path = KEYS_PATH / f"{user_id}" / "encrypted_symtoms"
247
+
248
+ # if not encrypted_input_path.is_file():
249
+ # raise gr.Error("Please generate the private key and then encrypt an image first.")
250
+
251
+ # # Define the data and files to post
252
+ # data = {
253
+ # "user_id": user_id,
254
+ # "filter": filter_name,
255
+ # }
256
+
257
+ # files = [
258
+ # ("files", open(encrypted_input_path, "rb")),
259
+ # ("files", open(evaluation_key_path, "rb")),
260
+ # ]
261
+
262
+ # # Send the encrypted input image and evaluation key to the server
263
+ # url = SERVER_URL + "send_input"
264
+ # with requests.post(
265
+ # url=url,
266
+ # data=data,
267
+ # files=files,
268
+ # ) as response:
269
+ # return response.ok
270
+
271
+
272
+ # def decrypt_prediction(encrypted_quantized_vect, user_id):
273
+ # fhe_api = FHEModelClient(path_dir=REPO_DIR, key_dir=f".fhe_keys/{user_id}")
274
+ # fhe_api.load()
275
+ # fhe_api.generate_private_and_evaluation_keys(force=False)
276
+ # predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_quantized_vect)
277
+ # return predictions
278
+
279
+
280
+
281
+
282
  def clear_all_btn():
283
  return {
284
+ box_default: None,
285
  user_id_textbox: None,
286
  eval_key_textbox: None,
287
+ quant_vect_textbox: None,
288
  user_vector_textbox: None,
289
+ eval_key_len_textbox: None,
290
+ encrypted_vect_textbox: None,
291
+ error_box_1: gr.update(visible=False),
292
+ error_box_2: gr.update(visible=False),
293
+ error_box_3: gr.update(visible=False),
294
  **{box: None for box in check_boxes},
295
  }
296
 
297
 
298
  if __name__ == "__main__":
299
  print("Starting demo ...")
300
+
301
 
302
  (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
303
 
304
+ VALID_COLUMNS = X_train.columns.to_list()
305
 
306
  # Load the model
307
  with open("ConcreteXGBoostClassifier.pkl", "r", encoding="utf-8") as file:
 
358
  )
359
  check_boxes.append(check_box)
360
 
361
+ error_box_1 = gr.Textbox(label="Error", visible=False)
362
+
363
  # User symptom vector
364
  with gr.Row():
365
  user_vector_textbox = gr.Textbox(
 
367
  interactive=False,
368
  max_lines=100,
369
  )
 
370
 
371
  with gr.Row():
372
  # Submit botton
 
374
  submit_button = gr.Button("Submit")
375
  # Clear botton
376
  with gr.Column():
377
+ clear_button = gr.Button("Clear")
378
 
379
  # Click submit botton
380
 
381
  submit_button.click(
382
+ fn=get_user_symptoms_vector_fn,
383
  inputs=[box_default, *check_boxes],
384
+ outputs=[user_vector_textbox, error_box_1, *check_boxes],
385
  )
386
 
387
  gr.Markdown("# Step 2: Generate the keys")
388
  gr.Markdown("Client side")
389
 
390
+ gen_key_btn = gr.Button("Generate the keys and send public part to server")
391
+
392
+ error_box_2 = gr.Textbox(label="Error", visible=False)
393
 
394
  with gr.Row():
395
  # User ID
 
414
  interactive=False,
415
  )
416
 
417
+ gen_key_btn.click(
418
+ key_gen_fn,
419
+ inputs=user_vector_textbox,
420
+ outputs=[eval_key_textbox, user_id_textbox, eval_key_len_textbox, error_box_2],
 
 
 
 
 
 
 
 
 
421
  )
422
 
423
  gr.Markdown("# Step 3: Encode the message with the private key")
424
  gr.Markdown("Client side")
425
 
426
+ encrypt_btn = gr.Button("Encode the message with the private key and send it to the server")
427
+
428
+ error_box_3 = gr.Textbox(label="Error", visible=False)
429
 
430
  with gr.Row():
431
 
 
446
  label="Encrypted vector:", max_lines=4, interactive=False
447
  )
448
 
449
+ encrypt_btn.click(
450
+ encrypt_fn,
451
  inputs=[user_vector_textbox, user_id_textbox],
452
+ outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox, error_box_3],
453
  )
454
 
455
  gr.Markdown("# Step 4: Run the FHE evaluation")
 
465
  label="Encrypted vector:", max_lines=4, interactive=False
466
  )
467
 
468
+ # decrypt_target_botton.click(
469
+ # decrypt_prediction,
470
+ # inputs=[encrypted_vect_textbox, user_id_textbox],
471
+ # outputs=[decrypt_target_textbox],
472
+ # )
473
+
474
+ clear_button.click(
475
+ clear_all_btn,
476
+ outputs=[
477
+ box_default,
478
+ error_box_1,
479
+ error_box_2,
480
+ error_box_3,
481
+ user_id_textbox,
482
+ eval_key_textbox,
483
+ quant_vect_textbox,
484
+ user_vector_textbox,
485
+ eval_key_len_textbox,
486
+ encrypted_vect_textbox,
487
+ *check_boxes,
488
+ ],
489
  )
490
 
491
  demo.launch()