kcelia commited on
Commit
cf6aebf
β€’
1 Parent(s): 1a494e6

chore: add encryption

Browse files
Files changed (2) hide show
  1. app.py +35 -11
  2. utils_demo.py +5 -0
app.py CHANGED
@@ -3,7 +3,7 @@
3
  import os
4
  import re
5
  from typing import Dict, List
6
-
7
  import gradio as gr
8
  import pandas as pd
9
  from fhe_anonymizer import FHEAnonymizer
@@ -11,16 +11,23 @@ from openai import OpenAI
11
  from utils_demo import *
12
  from concrete.ml.deployment import FHEModelClient
13
 
 
14
  ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
15
  ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
16
  MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
17
 
 
 
 
18
  clean_directory()
19
 
20
  anonymizer = FHEAnonymizer()
21
 
22
  client = OpenAI(api_key=os.environ.get("openaikey"))
23
 
 
 
 
24
 
25
  def select_static_sentences_fn(selected_sentences: List):
26
 
@@ -39,11 +46,7 @@ def key_gen_fn() -> Dict:
39
  Returns:
40
  dict: A dictionary containing the generated keys and related information.
41
  """
42
- print("Key Gen..")
43
-
44
- # Generate a random user ID
45
- user_id = np.random.randint(0, 2**32)
46
- print(f"Your user ID is: {user_id}....")
47
 
48
  client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
49
  client.load()
@@ -74,16 +77,16 @@ def key_gen_fn() -> Dict:
74
 
75
 
76
  def encrypt_query_fn(query):
77
- print(f"Query: {query}")
78
 
79
- evaluation_key_path = KEYS_DIR / "evaluation_key"
 
 
80
 
81
  if not evaluation_key_path.is_file():
82
  error_message = "Error ❌: Please generate the key first!"
83
  return {output_encrypted_box: gr.update(value=error_message)}
84
 
85
  if is_user_query_valid(query):
86
- # TODO: check if the query is related to our context
87
  error_msg = (
88
  "Unable to process ❌: The request exceeds the length limit or falls "
89
  "outside the scope of this document. Please refine your query."
@@ -91,9 +94,30 @@ def encrypt_query_fn(query):
91
  print(error_msg)
92
  return {query_box: gr.update(value=error_msg)}
93
 
94
- anonymizer.encrypt_query(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")
97
 
98
  encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
99
 
 
3
  import os
4
  import re
5
  from typing import Dict, List
6
+ import numpy
7
  import gradio as gr
8
  import pandas as pd
9
  from fhe_anonymizer import FHEAnonymizer
 
11
  from utils_demo import *
12
  from concrete.ml.deployment import FHEModelClient
13
 
14
+
15
  ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
16
  ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
17
  MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
18
 
19
+ subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
20
+ time.sleep(3)
21
+
22
  clean_directory()
23
 
24
  anonymizer = FHEAnonymizer()
25
 
26
  client = OpenAI(api_key=os.environ.get("openaikey"))
27
 
28
+ # Generate a random user ID
29
+ user_id = numpy.random.randint(0, 2**32)
30
+ print(f"Your user ID is: {user_id}....")
31
 
32
  def select_static_sentences_fn(selected_sentences: List):
33
 
 
46
  Returns:
47
  dict: A dictionary containing the generated keys and related information.
48
  """
49
+ print("Step 1: Key Generation:")
 
 
 
 
50
 
51
  client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
52
  client.load()
 
77
 
78
 
79
  def encrypt_query_fn(query):
 
80
 
81
+ print(f"Step 2 Query encryption: {query=}")
82
+
83
+ evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
84
 
85
  if not evaluation_key_path.is_file():
86
  error_message = "Error ❌: Please generate the key first!"
87
  return {output_encrypted_box: gr.update(value=error_message)}
88
 
89
  if is_user_query_valid(query):
 
90
  error_msg = (
91
  "Unable to process ❌: The request exceeds the length limit or falls "
92
  "outside the scope of this document. Please refine your query."
 
94
  print(error_msg)
95
  return {query_box: gr.update(value=error_msg)}
96
 
97
+ # Retrieve the client API
98
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
99
+ client.load()
100
+
101
+ # Pattern to identify words and non-words (including punctuation, spaces, etc.)
102
+ tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", query)
103
+ encrypted_tokens = []
104
+
105
+ for token in tokens:
106
+ if bool(re.match(r"^\s+$", token)):
107
+ continue
108
+ # Directly append non-word tokens or whitespace to processed_tokens
109
+
110
+ # Prediction for each word
111
+ emb_x = get_batch_text_representation([token], EMBEDDINGS_MODEL, TOKENIZER)
112
+ encrypted_x = client.quantize_encrypt_serialize(emb_x)
113
+ assert isinstance(encrypted_x, bytes)
114
+
115
+ encrypted_tokens.append(encrypted_x)
116
+
117
+ write_pickle(KEYS_DIR / f"{user_id}/encrypted_input", encrypted_tokens)
118
+
119
 
120
+ #anonymizer.encrypt_query(query)
121
 
122
  encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
123
 
utils_demo.py CHANGED
@@ -6,6 +6,7 @@ import shutil
6
  import string
7
  from collections import Counter
8
  from pathlib import Path
 
9
 
10
  import numpy as np
11
  import torch
@@ -35,6 +36,10 @@ PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"
35
 
36
  ALL_DIRS = [KEYS_DIR]
37
 
 
 
 
 
38
  PUNCTUATION_LIST = list(string.punctuation)
39
  PUNCTUATION_LIST.remove("%")
40
  PUNCTUATION_LIST.remove("$")
 
6
  import string
7
  from collections import Counter
8
  from pathlib import Path
9
+ from transformers import AutoModel, AutoTokenizer
10
 
11
  import numpy as np
12
  import torch
 
36
 
37
  ALL_DIRS = [KEYS_DIR]
38
 
39
+ # Load tokenizer and model
40
+ TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
41
+ EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
42
+
43
  PUNCTUATION_LIST = list(string.punctuation)
44
  PUNCTUATION_LIST.remove("%")
45
  PUNCTUATION_LIST.remove("$")