import json
import re
import time
import uuid
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from utils_demo import *

from concrete.ml.deployment import FHEModelClient, FHEModelServer

# Probability threshold above which a token is flagged as sensitive and anonymized
TOLERANCE_PROBA = 0.77

CURRENT_DIR = Path(__file__).parent

DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"


class FHEAnonymizer:
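    """Anonymize sensitive words in a text with a classifier running in FHE.

    Each token is embedded with the "obi/deid_roberta_i2b2" model, encrypted
    client-side, classified homomorphically server-side, and, when the
    decrypted probability reaches TOLERANCE_PROBA, replaced by a short UUID.
    """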

    def __init__(self):
        # Embedding model and tokenizer used to represent each token
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = PUNCTUATION_LIST
        self.uuid_map = read_json(MAPPING_UUID_PATH)

        # FHE client (key generation, encryption, decryption) and server (inference)
        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)

    def generate_key(self):
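        """Generate the private and evaluation keys, then save the evaluation key to disk."""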
        # Clean up artifacts left over from any previous run
        clean_directory()

        self.client.generate_private_and_evaluation_keys()

        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)

        # Persist the evaluation key so the server can reuse it across sessions
        evaluation_key_path = KEYS_DIR / "evaluation_key"
        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)

    def encrypt_query(self, text: str):
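        """Tokenize the text, embed each non-whitespace token, and encrypt the embeddings."""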
        # Split the text into word-like tokens and whitespace/punctuation runs
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        encrypted_tokens = []

        for token in tokens:
            # Whitespace-only tokens are not classified
            if re.match(r"^\s+$", token):
                continue

            emb_x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)

            encrypted_tokens.append(encrypted_x)

        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)

    def run_server(self):
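        """Run the FHE inference on every encrypted token and store the encrypted results."""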
        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")

        # Reload the evaluation key from disk if it was generated in a previous session
        if not hasattr(self, "evaluation_key"):
            self.evaluation_key = (KEYS_DIR / "evaluation_key").read_bytes()

        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.time()
            enc_y = self.server.run(enc_x, self.evaluation_key)
            # Per-token FHE execution time, converted from seconds to minutes
            timing.append((time.time() - start_time) / 60.0)
            encrypted_output.append(enc_y)

        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)

        return encrypted_output, timing

    def decrypt_output(self, text):
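        """Decrypt the predictions and replace sensitive tokens with consistent UUIDs."""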
        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")

        # Re-tokenize the text exactly as in `encrypt_query`
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        decrypted_output, identified_words_with_prob = [], []

        i = 0
        for token in tokens:
            # Whitespace tokens were never encrypted, so no prediction exists for them
            if re.match(r"^\s+$", token):
                continue

            encrypted_token = encrypted_output[i]
            prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
            # Probability that the token belongs to the sensitive (positive) class
            probability = prediction_proba[0][1]
            i += 1

            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))

                # Reuse the UUID already assigned to this token, or create a new 8-character one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                decrypted_output.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                decrypted_output.append(token)

        # Persist the token-to-UUID mapping so anonymization stays consistent across queries
        with open(MAPPING_UUID_PATH, "w") as file:
            json.dump(self.uuid_map, file)

        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)

    def run_server_and_decrypt_output(self, text):
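        """Run the FHE inference on the encrypted query, then decrypt the result."""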
        self.run_server()
        self.decrypt_output(text)
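

# Minimal usage sketch (an illustration, not part of the original file): it
# assumes the compiled model artifacts exist under DEPLOYMENT_DIR and that
# `utils_demo` provides the helpers used above. The sample query is purely
# illustrative.
if __name__ == "__main__":
    anonymizer = FHEAnonymizer()
    anonymizer.generate_key()

    query = "John Doe visited Boston on 2021-05-17."
    anonymizer.encrypt_query(query)
    anonymizer.run_server_and_decrypt_output(query)

    # The anonymized sentence and the flagged words are written under KEYS_DIR
    print(read_pickle(KEYS_DIR / "reconstructed_sentence"))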