# encrypted-anonymization / fhe_anonymizer.py
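"""FHE anonymization demo built with Concrete ML.

The client embeds each token of a query, encrypts the embeddings, and sends
them to a server that runs a PII-detection model under fully homomorphic
encryption. The client then decrypts the predictions and replaces tokens
flagged as identifiers with short UUIDs.
"""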
import json
import re
import time
import uuid
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
from utils_demo import *
from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer
# Probability threshold above which a token is treated as PII and replaced
TOLERANCE_PROBA = 0.77

CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
class FHEAnonymizer:
    def __init__(self):
        # Load the tokenizer and the embedding model used to vectorize each token
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = PUNCTUATION_LIST
        self.uuid_map = read_json(MAPPING_UUID_PATH)

        # FHE client (encrypts/decrypts) and server (runs the encrypted model)
        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)
    def generate_key(self):
        clean_directory()

        # Create the private and evaluation keys on the client side
        self.client.generate_private_and_evaluation_keys()

        # Get the serialized evaluation keys
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)

        evaluation_key_path = KEYS_DIR / "evaluation_key"
        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)
    def encrypt_query(self, text: str):
        # Split the text into word tokens and separator tokens (punctuation, spaces, etc.)
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        encrypted_tokens = []
        for token in tokens:
            # Skip pure-whitespace tokens; they are never sent to the server
            if re.match(r"^\s+$", token):
                continue

            # Embed the token, then quantize, encrypt, and serialize it
            emb_x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)
            encrypted_tokens.append(encrypted_x)

        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)
    def run_server(self):
        # Requires `generate_key` to have been called first, so that
        # `self.evaluation_key` is available
        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")
        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.time()
            enc_y = self.server.run(enc_x, self.evaluation_key)
            timing.append((time.time() - start_time) / 60.0)  # in minutes
            encrypted_output.append(enc_y)
        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)
        return encrypted_output, timing
    def decrypt_output(self, text):
        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")
        # Re-tokenize the original text the same way as in `encrypt_query`
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
        decrypted_output, identified_words_with_prob = [], []

        i = 0
        for token in tokens:
            # Whitespace tokens were never encrypted, so skip them here as well
            if re.match(r"^\s+$", token):
                continue

            # Decrypt and dequantize the server's prediction for this token
            encrypted_token = encrypted_output[i]
            prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
            probability = prediction_proba[0][1]
            i += 1

            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))
                # Reuse the existing UUID if available, otherwise generate a new one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                decrypted_output.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                decrypted_output.append(token)

        # Persist the updated token-to-UUID mapping
        with open(MAPPING_UUID_PATH, "w") as file:
            json.dump(self.uuid_map, file)

        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)
    def run_server_and_decrypt_output(self, text):
        self.run_server()
        self.decrypt_output(text)
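

# A minimal usage sketch (illustrative only, not part of the original module).
# It assumes the deployment artifacts under DEPLOYMENT_DIR and the helpers
# from utils_demo (read_pickle, clean_directory, ...) are available, and the
# example query text is made up.
if __name__ == "__main__":
    anonymizer = FHEAnonymizer()
    anonymizer.generate_key()

    query = "John Doe lives in Boston and works at Acme Corp."
    anonymizer.encrypt_query(query)
    anonymizer.run_server_and_decrypt_output(query)

    # The anonymized sentence and the flagged tokens were persisted by decrypt_output
    print(read_pickle(KEYS_DIR / "reconstructed_sentence"))
    print(read_pickle(KEYS_DIR / "identified_words_with_prob"))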