"""Shared constants and helpers: paths, model loading, text checks and file I/O."""

import json
import os
import pickle as pkl
import re
import shutil
import string
from collections import Counter
from pathlib import Path

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

# Core Application URL
SERVER_URL = "http://localhost:8000/"

# Maximum length for user queries
MAX_USER_QUERY_LEN = 80

# Base Directories
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
DATA_PATH = CURRENT_DIR / "files"

# Deployment Directories
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"

# All Directories
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]

# Model and Data Files
LOGREG_MODEL_PATH = CURRENT_DIR / "models" / "cml_logreg.model"
ORIGINAL_FILE_PATH = DATA_PATH / "original_document.txt"
ANONYMIZED_FILE_PATH = DATA_PATH / "anonymized_document.txt"
MAPPING_UUID_PATH = DATA_PATH / "original_document_uuid_mapping.json"
MAPPING_SENTENCES_PATH = DATA_PATH / "mapping_clear_to_anonymized.pkl"
PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"

# List of example queries for easy access
DEFAULT_QUERIES = {
    "Example Query 1": "Who visited microsoft.com on September 18?",
    "Example Query 2": "Does Kate have a driving licence?",
    "Example Query 3": "What's David Johnson's phone number?",
}

# Load tokenizer and model (executed once, when this module is imported)
TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

# Every ASCII punctuation character except "%" and "$", joined into one string
PUNCTUATION_LIST = list(string.punctuation)
PUNCTUATION_LIST.remove("%")
PUNCTUATION_LIST.remove("$")
PUNCTUATION_LIST = "".join(PUNCTUATION_LIST)


def clean_directory() -> None:
    """Delete every directory in `ALL_DIRS` (if present) and recreate it empty."""
    print("Cleaning...\n")
    for target_dir in ALL_DIRS:
        # `isdir` is False for missing paths, so no separate existence check is needed.
        if os.path.isdir(target_dir):
            shutil.rmtree(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)


def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
    """Get mean-pooled representations of given texts in batches.

    Args:
        texts: Sliceable sequence of texts; each slice of `batch_size` items is
            passed directly to `tokenizer`.
        model: Callable returning an object with a `last_hidden_state` tensor
            (presumably a Hugging Face encoder such as `EMBEDDINGS_MODEL`).
        tokenizer: Callable producing the model inputs, including an
            `attention_mask` entry.
        batch_size (int): Number of texts processed per forward pass.

    Returns:
        np.ndarray: One mean-pooled vector per input text, stacked in input order.
    """
    mean_pooled_batch = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i : i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=False)

        last_hidden_states = outputs.last_hidden_state

        # Broadcast the attention mask over the hidden dimension so that padded
        # positions contribute nothing to the sum below.
        input_mask_expanded = (
            inputs["attention_mask"].unsqueeze(-1).expand(last_hidden_states.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)

        # Mean over the unmasked positions only.
        mean_pooled = sum_embeddings / sum_mask
        mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy())

    return np.array(mean_pooled_batch)


def is_user_query_valid(user_query: str) -> bool:
    """Report whether `user_query` fails the default-query / length check.

    NOTE(review): despite its name, this function returns True for queries that
    FAIL the check. The return values are deliberately left unchanged because the
    callers are not visible from this module; confirm how they interpret the
    result before renaming or inverting it.

    Args:
        user_query (str): The input text to be checked. `None` is tolerated.

    Returns:
        bool: True if `user_query` is not one of the `DEFAULT_QUERIES` values and
            is either `None` or longer than `MAX_USER_QUERY_LEN` characters;
            False otherwise (including for the empty string).
    """
    # The predefined example queries always pass the check.
    is_default_query = user_query in DEFAULT_QUERIES.values()

    # True when a query is present and does not exceed the length limit.
    is_within_max_length = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN

    return not is_default_query and not is_within_max_length


def compare_texts_ignoring_extra_spaces(original_text, modified_text):
    """Check if the modified_text is identical to the original_text except for whitespace.

    Both texts are split on any whitespace and re-joined with single spaces, so
    leading/trailing whitespace and runs of spaces, tabs or newlines are ignored.

    Args:
        original_text (str): The original text for comparison.
        modified_text (str): The modified text to compare against the original.

    Returns:
        (bool): True if both texts contain the same whitespace-separated tokens in
            the same order; False otherwise.
    """
    normalized_original = " ".join(original_text.split())
    normalized_modified = " ".join(modified_text.split())

    return normalized_original == normalized_modified


def is_strict_deletion_only(original_text, modified_text):
    """Check that `modified_text` uses only tokens taken from `original_text`.

    Words and punctuation are separated, both texts are lower-cased and split on
    whitespace, and the resulting tokens are counted. Token order is not compared.

    Args:
        original_text (str): The reference text.
        modified_text (str): The text expected to be a reduced version of it.

    Returns:
        (bool): True if every token of `modified_text` occurs in `original_text`
            at least as many times; False otherwise.
    """
    # Define a regex pattern that matches a word character next to a punctuation
    # or a punctuation next to a word character, without a space between them.
    pattern = r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])"

    # Replace instances found by the pattern with a space
    original_text = re.sub(pattern, " ", original_text)
    modified_text = re.sub(pattern, " ", modified_text)

    # Tokenize the texts into words, considering also punctuation
    original_words = Counter(original_text.lower().split())
    modified_words = Counter(modified_text.lower().split())

    # A Counter yields 0 for absent keys and every modified count is >= 1, so this
    # single comparison also rejects tokens that never appear in the original.
    return all(original_words[token] >= count for token, count in modified_words.items())


def read_txt(file_path) -> str:
    """Read and return the whole content of a UTF-8 text file."""
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def write_txt(file_path, data) -> None:
    """Write the string `data` to a UTF-8 text file, replacing existing content."""
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(data)


def write_pickle(file_path, data) -> None:
    """Save data to a pickle file."""
    with open(file_path, "wb") as f:
        pkl.dump(data, f)


def read_pickle(file_name):
    """Load data from a pickle file.

    SECURITY: unpickling can execute arbitrary code; only use this on files that
    come from a trusted source.
    """
    with open(file_name, "rb") as file:
        return pkl.load(file)


def read_json(file_name):
    """Load data from a json file."""
    # Explicit UTF-8 keeps decoding independent of the platform's default locale
    # and consistent with `write_json`.
    with open(file_name, "r", encoding="utf-8") as file:
        return json.load(file)


def write_json(file_name, data) -> None:
    """Save data to a json file (4-space indent, keys sorted)."""
    with open(file_name, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4, sort_keys=True)


def write_bytes(path: Path, data) -> None:
    """Save binary data to `path`, which must provide an `open` method (e.g. `Path`)."""
    with path.open("wb") as f:
        f.write(data)


def read_bytes(path: Path) -> bytes:
    """Load data from a binary file at `path`, which must provide an `open` method."""
    with path.open("rb") as f:
        return f.read()