chore: add encryption
- app.py (+35 -11)
- utils_demo.py (+5 -0)
app.py
CHANGED
@@ -3,7 +3,7 @@
 import os
 import re
 from typing import Dict, List
-
+import numpy
 import gradio as gr
 import pandas as pd
 from fhe_anonymizer import FHEAnonymizer
@@ -11,16 +11,23 @@ from openai import OpenAI
 from utils_demo import *
 from concrete.ml.deployment import FHEModelClient
 
+
 ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
 ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
 MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
 
+subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
+time.sleep(3)
+
 clean_directory()
 
 anonymizer = FHEAnonymizer()
 
 client = OpenAI(api_key=os.environ.get("openaikey"))
 
+# Generate a random user ID
+user_id = numpy.random.randint(0, 2**32)
+print(f"Your user ID is: {user_id}....")
 
 def select_static_sentences_fn(selected_sentences: List):
 
@@ -39,11 +46,7 @@ def key_gen_fn() -> Dict:
     Returns:
         dict: A dictionary containing the generated keys and related information.
     """
-    print("Key
-
-    # Generate a random user ID
-    user_id = np.random.randint(0, 2**32)
-    print(f"Your user ID is: {user_id}....")
+    print("Step 1: Key Generation:")
 
     client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
     client.load()
@@ -74,16 +77,16 @@ def key_gen_fn() -> Dict:
 
 
 def encrypt_query_fn(query):
-    print(f"Query: {query}")
 
-
+    print(f"Step 2 Query encryption: {query=}")
+
+    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
 
     if not evaluation_key_path.is_file():
         error_message = "Error β: Please generate the key first!"
         return {output_encrypted_box: gr.update(value=error_message)}
 
     if is_user_query_valid(query):
-        # TODO: check if the query is related to our context
         error_msg = (
            "Unable to process β: The request exceeds the length limit or falls "
            "outside the scope of this document. Please refine your query."
@@ -91,9 +94,30 @@ def encrypt_query_fn(query):
         print(error_msg)
         return {query_box: gr.update(value=error_msg)}
 
-
+    # Retrieve the client API
+    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
+    client.load()
+
+    # Pattern to identify words and non-words (including punctuation, spaces, etc.)
+    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", query)
+    encrypted_tokens = []
+
+    for token in tokens:
+        if bool(re.match(r"^\s+$", token)):
+            continue
+        # Directly append non-word tokens or whitespace to processed_tokens
+
+        # Prediction for each word
+        emb_x = get_batch_text_representation([token], EMBEDDINGS_MODEL, TOKENIZER)
+        encrypted_x = client.quantize_encrypt_serialize(emb_x)
+        assert isinstance(encrypted_x, bytes)
+
+        encrypted_tokens.append(encrypted_x)
+
+    write_pickle(KEYS_DIR / f"{user_id}/encrypted_input", encrypted_tokens)
+
 
-
+    #anonymizer.encrypt_query(query)
 
     encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
 
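Note on the flow added above: key generation and query encryption both go through Concrete ML's FHEModelClient, and the encrypted tokens are meant to be processed by the FastAPI server started via uvicorn. The following is a minimal, self-contained sketch of that round trip under stated assumptions: it uses the standard Concrete ML deployment API (FHEModelClient / FHEModelServer), the DEPLOYMENT_DIR and KEYS_DIR paths and the stand-in embedding are illustrative, and it presumes DEPLOYMENT_DIR already holds the compiled model artifacts (client.zip / server.zip). It is not the repository's exact code.

# Hedged sketch of the client/server FHE flow wired up in app.py.
# Assumes Concrete ML's deployment API; paths and embedding shape are illustrative.
from pathlib import Path

import numpy as np
from concrete.ml.deployment import FHEModelClient, FHEModelServer

DEPLOYMENT_DIR = Path("deployment")   # assumed: contains the compiled client.zip / server.zip
KEYS_DIR = Path(".fhe_keys")          # assumed: per-user key directory
EMBEDDING_DIM = 1024                  # assumed: must match the dimension the FHE model was compiled for

user_id = np.random.randint(0, 2**32)

# Step 1: key generation (what key_gen_fn does)
client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
client.load()
client.generate_private_and_evaluation_keys(force=False)
evaluation_key = client.get_serialized_evaluation_keys()

# Step 2: query encryption (what encrypt_query_fn does per token)
# A real token embedding comes from get_batch_text_representation; this is a stand-in.
emb_x = np.random.rand(1, EMBEDDING_DIM).astype(np.float32)
encrypted_x = client.quantize_encrypt_serialize(emb_x)
assert isinstance(encrypted_x, bytes)

# Step 3: the server runs the FHE model on the ciphertext (the uvicorn server:app process)
server = FHEModelServer(path_dir=DEPLOYMENT_DIR)
server.load()
encrypted_result = server.run(encrypted_x, evaluation_key)

# Step 4: decrypt and dequantize on the client side
result = client.deserialize_decrypt_dequantize(encrypted_result)
print(result.shape)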
utils_demo.py
CHANGED
@@ -6,6 +6,7 @@ import shutil
 import string
 from collections import Counter
 from pathlib import Path
+from transformers import AutoModel, AutoTokenizer
 
 import numpy as np
 import torch
@@ -35,6 +36,10 @@ PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"
 
 ALL_DIRS = [KEYS_DIR]
 
+# Load tokenizer and model
+TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
+EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
+
 PUNCTUATION_LIST = list(string.punctuation)
 PUNCTUATION_LIST.remove("%")
 PUNCTUATION_LIST.remove("$")
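Note: utils_demo.py now loads the obi/deid_roberta_i2b2 tokenizer and model, which app.py consumes through get_batch_text_representation. That helper is not part of this diff; the sketch below is an assumed mean-pooling implementation, included only to make the embedding step concrete, and its signature and behavior may differ from the repository's version.

# Hedged sketch: get_batch_text_representation is not shown in this diff.
# A common mean-pooling implementation over the last hidden state is assumed here.
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
    """Return one mean-pooled embedding per input text (illustrative implementation)."""
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            outputs = model(**inputs)
            # Mean-pool the last hidden state over non-padding tokens
            mask = inputs["attention_mask"].unsqueeze(-1)
            summed = (outputs.last_hidden_state * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            embeddings.append((summed / counts).cpu().numpy())
    return np.concatenate(embeddings, axis=0)

emb = get_batch_text_representation(["Alice"], EMBEDDINGS_MODEL, TOKENIZER)
print(emb.shape)  # (1, hidden_size) for the embeddings model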