File size: 3,968 Bytes
58df7f1 ce79710 58df7f1 4f6d82f 58df7f1 9b80a97 58df7f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import shutil
from pathlib import Path
from typing import Any, List, Tuple
import numpy
import pandas
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
# Maximum input size to be displayed in the Hugging Face Space browser UI using Gradio.
# Inputs that are too large slow down the server: https://github.com/gradio-app/gradio/issues/1877
INPUT_BROWSER_LIMIT = 635
# URL of the server (a local address on port 8000)
SERVER_URL = "http://localhost:8000/"
# Directory layout, resolved relative to the location of this file
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
# NOTE(review): presumably holds the FHE keys, judging by the directory name - confirm
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
# Directories that `clean_directory` wipes and recreates
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
# Columns that define the target:
# index 0 is the numeric label of the disease, index 1 is the name of the disease
TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]
# Paths of the preprocessed training and testing CSV files (relative to the working directory)
TRAINING_FILENAME = "./data/Training_preprocessed.csv"
TESTING_FILENAME = "./data/Testing_preprocessed.csv"
# pylint: disable=invalid-name
def pretty_print(inputs: Any) -> List[str]:
    """
    Prettify and sort the input as a list of strings.

    Each name has its underscores replaced by spaces and is title-cased
    (``"high_fever"`` becomes ``"High Fever"``). Nested lists or tuples are flattened
    one level, duplicates are removed and the result is sorted alphabetically.

    Args:
        inputs (Any): A single name, or an iterable of names which may itself contain
            lists or tuples of names.

    Returns:
        List[str]: The prettified, de-duplicated and sorted list of names.
    """
    # A lone string is one name, not an iterable of characters
    if isinstance(inputs, str):
        inputs = [inputs]

    pretty_names = set()
    for item in inputs:
        # Flatten one level of nesting; a bare name is handled as a group of one
        group = item if isinstance(item, (list, tuple)) else (item,)
        pretty_names.update(name.replace("_", " ").title() for name in group)

    # The set removed the duplicates, sorting gives a stable display order
    return sorted(pretty_names)
def clean_directory() -> None:
    """
    Reset the deployment directories.

    Every directory listed in ALL_DIRS is removed if it already exists, then created again
    (along with any missing parent), so that each one ends up present and empty.
    """
    print("Cleaning...\n")

    for directory in ALL_DIRS:
        # isdir() is False for a missing path, so it also covers the existence check
        if os.path.isdir(directory):
            shutil.rmtree(directory)

        directory.mkdir(parents=True, exist_ok=True)
def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
    """Return the disease name given its encoded label.

    Args:
        encoded_prediction (int): The encoded prediction
        file_name (str): The data file path

    Returns:
        str: The according disease name

    Raises:
        ValueError: If the encoded label matches no disease name, or more than one.
    """
    label_column = TARGET_COLUMNS[0]
    name_column = TARGET_COLUMNS[1]

    df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS).drop_duplicates()

    # `usecols` keeps the column order of the CSV file, not the order of TARGET_COLUMNS,
    # so the name is selected by column label rather than by position
    matching_names = df.loc[df[label_column] == encoded_prediction, name_column]

    # Exactly one (label, name) pair is expected once the duplicates are dropped
    if len(matching_names) != 1:
        raise ValueError(
            f"Expected exactly one disease name for the encoded label {encoded_prediction!r} "
            f"in '{file_name}', found {len(matching_names)}."
        )

    return matching_names.iloc[0]
def load_data() -> Tuple[
    Tuple[pandas.DataFrame, pandas.DataFrame], Tuple[pandas.Series, pandas.Series]
]:
    """
    Load the training and testing sets and split them into features and targets.

    The data is read from TRAINING_FILENAME and TESTING_FILENAME.

    Args:
        None

    Return:
        Tuple[Tuple[pandas.DataFrame, pandas.DataFrame], Tuple[pandas.Series, pandas.Series]]:
            The features as `(X_train, X_test)` followed by the targets as `(y_train, y_test)`.
    """
    # Load data
    df_train = pandas.read_csv(TRAINING_FILENAME)
    df_test = pandas.read_csv(TESTING_FILENAME)

    # Separate the target from the training / testing set:
    # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
    # TARGET_COLUMNS[1] -> "prognosis" -> contains the name of the disease
    y_train = df_train[TARGET_COLUMNS[0]]
    y_test = df_test[TARGET_COLUMNS[0]]

    # Both target columns are removed from the features; `errors="ignore"` tolerates a file
    # that lacks one of them (`columns=` already implies the column axis)
    X_train = df_train.drop(columns=TARGET_COLUMNS, errors="ignore")
    X_test = df_test.drop(columns=TARGET_COLUMNS, errors="ignore")

    return (X_train, X_test), (y_train, y_test)
def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
    """
    Train a Concrete ML XGBoost classifier and compile it.

    Nothing is read from disk: the classifier is fitted on the given training data and then
    compiled, using the same training set as the compilation input.

    Args:
        X_train (pandas.DataFrame): Training set
        y_train (numpy.ndarray): Targets of the training set

    Return:
        The fitted Concrete ML model and the circuit returned by its compilation
    """
    # Hyper-parameters of the classifier
    model = ConcreteXGBoostClassifier(max_depth=1, n_bits=3, n_estimators=3, n_jobs=-1)

    # Fit on the training data
    model.fit(X_train, y_train)

    # Compile the fitted model, with the training set as input
    fhe_circuit = model.compile(X_train)

    return model, fhe_circuit
|