Zamanonymize3 / fhe_anonymizer.py
jfrery-zama's picture
update
b160148
raw
history blame
2.74 kB
import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
import uuid
import json
# Directory that holds this module. FHEAnonymizer looks up its model files,
# its FHE deployment directory and its UUID-mapping JSON relative to it.
base_dir = Path(__file__).parents[0]
class FHEAnonymizer:
    """Pseudonymize words that an FHE-evaluated NER classifier flags as sensitive.

    Each word token is embedded with a FastText model and classified through the
    Concrete ML client/server deployment found in ``base_dir / "deployment"``:
    the client quantizes and encrypts the embedding, the server evaluates it on
    the ciphertext, and the client decrypts the result. Words whose class-1
    probability is >= 0.5 are replaced by an 8-character identifier. The same
    word reuses the same identifier through ``self.uuid_map``.
    """

    # Token alternatives, tried in order at each position:
    #   1. "words": word chars plus ". / - @" between word boundaries, so
    #      e-mails, dotted numbers and paths stay in one token;
    #   2. runs of whitespace and common punctuation;
    #   3. any other single character (e.g. "(", ")", "$", "#").
    # Alternative 3 guarantees "".join(tokens) == text. Without it re.findall
    # silently skipped such characters and the anonymized output lost them.
    _TOKEN_PATTERN = re.compile(
        r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+|.)",
        re.DOTALL,
    )
    # A token is sent to the classifier only if it starts with a word character.
    _WORD_START = re.compile(r"\w+")

    def __init__(self, punctuation_list=".,!?:;"):
        """Load the embedding model, classifier, UUID map and generate FHE keys.

        Args:
            punctuation_list: stored as ``self.punctuation_list``; it is not
                read by the tokenizer in this class.
        """
        self.embeddings_model = gensim.models.FastText.load(
            str(base_dir / "models/without_pronoun_embedded_model.model")
        )
        self.punctuation_list = punctuation_list
        # Deserialized Concrete ML model. __call__ classifies through the
        # client/server deployment below; this attribute is kept for callers.
        with open(
            base_dir / "models/without_pronoun_cml_xgboost.model", "r"
        ) as model_file:
            self.fhe_ner_detection = load(file=model_file)
        # Maps an already-seen word to its pseudonym. Explicit utf-8 so that
        # non-ASCII words load the same regardless of the platform encoding.
        with open(
            base_dir / "original_document_uuid_mapping.json", "r", encoding="utf-8"
        ) as file:
            self.uuid_map = json.load(file)
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        # Keys are generated once per instance. Only the serialized evaluation
        # keys are passed to the server in fhe_inference.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Classify *x* under FHE and return the decrypted, dequantized output.

        The client quantizes, encrypts and serializes *x*. The server runs the
        model on the ciphertext with the evaluation keys. The client then
        deserializes, decrypts and dequantizes the result.
        """
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(enc_y)

    @classmethod
    def _tokenize(cls, text):
        """Split *text* into word/separator tokens that concatenate back to *text*."""
        return cls._TOKEN_PATTERN.findall(text)

    def _pseudonym(self, token):
        """Return the identifier for *token*, creating and recording one if new."""
        if token not in self.uuid_map:
            # NOTE(review): 8 hex chars of a random UUID; collisions between
            # different words are unlikely but not checked.
            self.uuid_map[token] = str(uuid.uuid4())[:8]
        return self.uuid_map[token]

    def __call__(self, text: str):
        """Anonymize *text*.

        Args:
            text: input string. Every character is preserved except that
                flagged words are replaced.

        Returns:
            A ``(anonymized_text, identified)`` tuple, where ``identified`` lists
            ``(word, probability)`` for every replaced occurrence, in order.
        """
        identified_words_with_prob = []
        processed_tokens = []
        for token in self._tokenize(text):
            # Whitespace, punctuation and other non-word tokens are copied
            # verbatim. Raw tokens are deliberately never printed or logged,
            # since they contain the un-anonymized input.
            if not token.strip() or not self._WORD_START.match(token):
                processed_tokens.append(token)
                continue
            # [None] adds a leading batch axis of size 1 to the word embedding.
            x = self.embeddings_model.wv[token][None]
            probability = self.fhe_inference(x)[0][1]
            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))
                processed_tokens.append(self._pseudonym(token))
            else:
                processed_tokens.append(token)
        return "".join(processed_tokens), identified_words_with_prob