File size: 4,378 Bytes
174cd37 646bd9e 174cd37 646bd9e 174cd37 646bd9e 174cd37 646bd9e 174cd37 646bd9e 2b591f4 646bd9e 174cd37 646bd9e d0b1031 1dfccc3 174cd37 646bd9e 174cd37 df6182e 174cd37 646bd9e 174cd37 646bd9e 174cd37 646bd9e 174cd37 646bd9e 174cd37 df6182e 174cd37 df6182e 174cd37 df6182e 174cd37 646bd9e 174cd37 1dfccc3 174cd37 646bd9e 174cd37 b160148 174cd37 df6182e 174cd37 646bd9e d0b1031 174cd37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import json
import re
import time
import uuid
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
from utils_demo import *
from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer
# Minimum decrypted probability at which FHEAnonymizer treats a token as
# identifying information and replaces it with a UUID.
TOLERANCE_PROBA = 0.77
# Directory that contains this source file; other paths are built from it.
CURRENT_DIR = Path(__file__).parent
# Passed to FHEModelClient / FHEModelServer as their deployment directory.
DEPLOYMENT_DIR = CURRENT_DIR.joinpath("deployment")
# Holds the FHE keys and the pickled intermediate artifacts of a run.
KEYS_DIR = DEPLOYMENT_DIR.joinpath(".fhe_keys")
class FHEAnonymizer:
    """Anonymize text by classifying each token under fully homomorphic encryption.

    The steps exchange data through files in ``KEYS_DIR``:

    1. :meth:`generate_key` creates the client keys and saves the serialized
       evaluation key to ``KEYS_DIR / "evaluation_key"``.
    2. :meth:`encrypt_query` embeds every non-whitespace token of the query,
       encrypts the embeddings and pickles them.
    3. :meth:`run_server` evaluates the FHE model on each encrypted token.
    4. :meth:`decrypt_output` decrypts the predictions and replaces tokens whose
       probability is at least ``TOLERANCE_PROBA`` with a short UUID.
    """

    # Splits text into runs of word characters plus ``. / - @`` (so values such
    # as "john@x.com" or "12/05/2020" stay one token) and runs of
    # whitespace/punctuation. encrypt_query and decrypt_output must tokenize
    # identically: the i-th encrypted prediction belongs to the i-th token.
    _TOKEN_PATTERN = re.compile(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)")
    # Tokens made only of whitespace are never sent to the model.
    _WHITESPACE_ONLY = re.compile(r"^\s+$")

    def __init__(self):
        # Tokenizer and embedding model used to turn each token into features.
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
        self.punctuation_list = PUNCTUATION_LIST
        # token -> UUID mapping; reused and extended by decrypt_output, which
        # writes it back to MAPPING_UUID_PATH.
        self.uuid_map = read_json(MAPPING_UUID_PATH)
        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)
        # Set by generate_key(); run_server() falls back to the copy on disk.
        self.evaluation_key = None

    @classmethod
    def _tokenize(cls, text: str):
        """Return the tokens of *text* that are sent to the model, in order.

        Whitespace-only runs are dropped.
        """
        return [
            token
            for token in cls._TOKEN_PATTERN.findall(text)
            if not cls._WHITESPACE_ONLY.match(token)
        ]

    def _get_evaluation_key(self) -> bytes:
        """Return the serialized evaluation key.

        The key is loaded from ``KEYS_DIR`` when this instance has not generated
        one itself.
        """
        key = getattr(self, "evaluation_key", None)
        if key is None:
            key = (KEYS_DIR / "evaluation_key").read_bytes()
            self.evaluation_key = key
        return key

    def generate_key(self) -> None:
        """Create fresh private/evaluation keys and save the evaluation key."""
        # NOTE(review): clean_directory comes from utils_demo; presumably it
        # removes artifacts left by a previous run — confirm.
        clean_directory()
        # Creates the private and evaluation keys on the client side.
        self.client.generate_private_and_evaluation_keys()
        # The serialized evaluation key is what the server needs to compute.
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)
        evaluation_key_path = KEYS_DIR / "evaluation_key"
        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)

    def encrypt_query(self, text: str) -> None:
        """Embed, quantize and encrypt every non-whitespace token of *text*.

        The encrypted tokens are pickled, in token order, to
        ``KEYS_DIR / "encrypted_quantized_query"``.
        """
        encrypted_tokens = []
        for token in self._tokenize(text):
            # Each token is embedded on its own (a batch of one).
            emb_x = get_batch_text_representation(
                [token], self.embeddings_model, self.tokenizer
            )
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)
            encrypted_tokens.append(encrypted_x)
        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)

    def run_server(self):
        """Run the FHE model on every encrypted token.

        Returns:
            ``(encrypted_output, timing)``. ``encrypted_output`` holds the
            encrypted predictions in token order. ``timing`` holds the elapsed
            time of each evaluation, in minutes. Both are also pickled to
            ``KEYS_DIR``.

        Raises:
            FileNotFoundError: if no evaluation key is held by this instance
                and none has been saved to ``KEYS_DIR``.
        """
        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")
        evaluation_key = self._get_evaluation_key()
        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.perf_counter()
            encrypted_output.append(self.server.run(enc_x, evaluation_key))
            timing.append((time.perf_counter() - start_time) / 60.0)
        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)
        return encrypted_output, timing

    def decrypt_output(self, text: str) -> None:
        """Decrypt the server predictions and anonymize *text*.

        *text* must be the same string that was passed to
        :meth:`encrypt_query`, because predictions are matched to tokens by
        position. Each token whose probability is >= ``TOLERANCE_PROBA`` is
        replaced with a UUID. Existing UUIDs are reused from ``self.uuid_map``,
        new ones are added to it, and the map is saved to
        ``MAPPING_UUID_PATH``. The reconstructed sentence and the list of
        ``(token, probability)`` pairs are pickled to ``KEYS_DIR``.

        Raises:
            ValueError: if the number of stored predictions differs from the
                number of tokens in *text*.
        """
        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")
        tokens = self._tokenize(text)
        if len(tokens) != len(encrypted_output):
            raise ValueError(
                f"Found {len(encrypted_output)} encrypted predictions but the text "
                f"has {len(tokens)} tokens; pass the same text given to encrypt_query()."
            )

        decrypted_output, identified_words_with_prob = [], []
        for token, encrypted_token in zip(tokens, encrypted_output):
            prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
            # NOTE(review): column 1 presumably holds the probability of the
            # "identifying information" class — confirm against the model.
            probability = prediction_proba[0][1]
            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))
                # Reuse an existing UUID so a token is anonymized consistently.
                if token not in self.uuid_map:
                    self.uuid_map[token] = str(uuid.uuid4())[:8]
                decrypted_output.append(self.uuid_map[token])
            else:
                decrypted_output.append(token)

        # Persist the (possibly extended) token -> UUID mapping.
        with open(MAPPING_UUID_PATH, "w", encoding="utf-8") as file:
            json.dump(self.uuid_map, file)
        # Whitespace tokens were dropped, so tokens are re-joined with single
        # spaces; the original spacing around punctuation is not preserved.
        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)

    def run_server_and_decrypt_output(self, text: str) -> None:
        """Run :meth:`run_server` and then :meth:`decrypt_output` on *text*."""
        self.run_server()
        self.decrypt_output(text)
|