File size: 6,583 Bytes
70884da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import json
import pandas as pd
# Build the multiple instance learning (MIL) bag classifier.
def create_simple_model(instance_shape, max_length):
    """Build a bag-level binary classifier for multiple instance learning.

    Every instance slot of a bag is passed through the same small MLP
    (Dense 256 -> Dropout -> Dense 64 -> Dropout, each wrapped in
    ``TimeDistributed``). The per-instance embeddings are then mean-pooled
    across the bag, layer-normalised and mapped to one sigmoid probability.

    Args:
        instance_shape: Shape of a single instance; only its last entry (the
            feature dimension) is used.
        max_length: Number of instance slots per bag (the input's second axis).

    Returns:
        An uncompiled ``Model`` whose input is named ``"bag_input"`` with shape
        ``(batch, max_length, instance_shape[-1])`` and whose output is a
        ``(batch, 1)`` probability.
    """
    n_features = instance_shape[-1]
    bag = layers.Input(shape=(max_length, n_features), name="bag_input")

    # Per-instance encoder: the same weights are applied to every bag slot.
    x = layers.TimeDistributed(layers.Flatten())(bag)
    for units in (256, 64):
        x = layers.TimeDistributed(layers.Dense(units, activation="relu"))(x)
        x = layers.TimeDistributed(layers.Dropout(0.5))(x)

    # Bag summary: plain mean over all instance slots (no masking is applied,
    # so padded slots, if any, take part in the average).
    pooled = layers.GlobalAveragePooling1D()(x)
    pooled = layers.LayerNormalization()(pooled)
    probability = layers.Dense(1, activation="sigmoid")(pooled)
    return Model(bag, probability)
# Compute balanced per-class loss weights for binary (0/1) labels.
def compute_class_weights(labels):
    """Return "balanced" loss weights for the binary classes 0 and 1.

    Each present class ``c`` gets ``total / (2 * count_c)``, where ``total`` is
    the number of labels equal to 0 or 1. This up-weights the rarer class.

    Args:
        labels: Array-like of 0/1 labels. Plain Python lists are accepted.

    Returns:
        ``{0: weight_0, 1: weight_1}``. A class with no samples gets the
        placeholder weight ``1.0`` instead of raising ``ZeroDivisionError``.
        No sample of that class exists to use it, but keeping both keys means
        lookups by label never fail.
    """
    # asarray makes ``labels == cls`` an element-wise comparison for lists too
    # (for a list it would otherwise be a single ``False``).
    labels = np.asarray(labels)
    counts = {cls: int(np.count_nonzero(labels == cls)) for cls in (0, 1)}
    total_count = counts[0] + counts[1]
    return {
        cls: (1 / count) * (total_count / 2) if count else 1.0
        for cls, count in counts.items()
    }
# Endless batch generator yielding (inputs, targets, sample weights).
def data_generator(data, labels, batch_size=1):
    """Cycle over ``data``/``labels`` forever, yielding weighted batches.

    Class weights are computed once from the full ``labels`` with
    ``compute_class_weights`` and turned into a per-sample weight vector. The
    data is walked in its stored order, with no shuffling. The last batch of
    each pass may be shorter than ``batch_size``.

    Args:
        data: Sliceable collection of model inputs.
        labels: Sliceable collection of 0/1 labels aligned with ``data``.
        batch_size: Number of samples per yielded batch.

    Yields:
        ``(x_batch, y_batch, w_batch)``, each a ``float32`` NumPy array.
    """
    weight_of = compute_class_weights(labels)
    n_samples = len(data)
    while True:
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            x_batch = np.array(data[start:stop], dtype=np.float32)
            y_batch = np.array(labels[start:stop], dtype=np.float32)
            per_sample = [weight_of[int(y)] for y in y_batch]
            w_batch = np.array(per_sample, dtype=np.float32)
            yield x_batch, y_batch, w_batch
# Step-decay learning-rate schedule for tf.keras.callbacks.LearningRateScheduler.
def lr_scheduler(epoch, lr, decay_rate=0.1, decay_step=10):
    """Multiply the learning rate by ``decay_rate`` every ``decay_step`` epochs.

    Used as the ``schedule`` callable of ``LearningRateScheduler``, which passes
    the epoch index and the current learning rate. The *current* rate is
    scaled, so decays compound. With the defaults, the rate drops 10x at
    epochs 10, 20, 30, ...

    Args:
        epoch: Epoch index.
        lr: Learning rate in effect before this epoch.
        decay_rate: Multiplicative factor applied at each decay boundary.
        decay_step: Number of epochs between decays; must be positive.

    Returns:
        The learning rate to use for ``epoch``.

    Raises:
        ValueError: If ``decay_step`` is not positive.
    """
    if decay_step <= 0:
        raise ValueError(f"decay_step must be positive, got {decay_step}")
    # Epoch 0 is a multiple of every step but must not trigger a decay.
    if epoch and epoch % decay_step == 0:
        return lr * decay_rate
    return lr
# Train the model with class-weighted generators, checkpointing and early stopping.
def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
    """Compile and fit ``model`` on bag data and return it.

    Training uses Adam with binary cross-entropy and tracks accuracy and AUC.
    Batches come from ``data_generator``, one bag per step, with per-sample
    class weights. Three callbacks are used:

    * ``ModelCheckpoint`` writes the best ``val_loss`` model to
      ``<save_dir>/best_model.h5``;
    * ``EarlyStopping`` stops after 10 epochs without ``val_loss``
      improvement and restores the best weights;
    * ``LearningRateScheduler`` applies ``lr_scheduler``.

    Args:
        train_data, train_labels: Training bags and their 0/1 labels.
        val_data, val_labels: Validation bags and their 0/1 labels.
        model: An uncompiled ``tf.keras`` model; it is compiled in place.
        save_dir: Directory for the checkpoint; created if missing.
        epochs: Maximum number of training epochs.

    Returns:
        The same ``model`` object after training.
    """
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, "best_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor="val_loss",
        verbose=1,
        mode="min",
        save_best_only=True,
        save_weights_only=False,
    )
    # restore_best_weights: the returned (and later evaluated) model should be
    # the best-val_loss one, not the last epoch of the patience window.
    # Whether weights are also restored when all epochs run without stopping
    # depends on the Keras version; best_model.h5 always holds the best model.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, mode="min", restore_best_weights=True
    )
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)

    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])

    # data_generator defaults to one bag per batch, so one pass over the data
    # is len(data) steps. batch_size must not be passed to fit() for
    # generator input; the generator already defines the batching.
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    model.fit(
        train_gen,
        steps_per_epoch=len(train_data),
        validation_data=val_gen,
        validation_steps=len(val_data),
        epochs=epochs,
        callbacks=[early_stopping, model_checkpoint, lr_callback],
        verbose=1,
    )
    return model
# Compute ROC AUC plus threshold-based precision, recall and F1 for a model.
def compute_additional_metrics(X, Y, model, threshold=0.5):
    """Score ``model`` on ``(X, Y)`` with AUC, precision, recall and F1.

    Args:
        X: Inputs accepted by ``model.predict``.
        Y: Binary (0/1) ground-truth labels.
        model: Object whose ``predict`` returns class-1 probabilities.
        threshold: Probabilities strictly above this count as class 1.

    Returns:
        ``(auc, precision, recall, f1, predictions)``. ``predictions`` is the
        flattened array of predicted probabilities. ``auc`` is ``nan`` when
        ``Y`` contains only one class, because ROC AUC is undefined there.
        Precision, recall and F1 are 0.0 when their denominator is zero.
    """
    predictions = np.asarray(model.predict(X)).flatten()
    predictions_binary = (predictions > threshold).astype(int)
    y_true = np.asarray(Y).ravel()

    # roc_auc_score raises ValueError on single-class labels; report NaN so a
    # small or skewed split can still be evaluated.
    if np.unique(y_true).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(y_true, predictions)

    # zero_division=0 yields the same 0.0 as sklearn's default, without the
    # UndefinedMetricWarning.
    precision = precision_score(y_true, predictions_binary, zero_division=0)
    recall = recall_score(y_true, predictions_binary, zero_division=0)
    f1 = f1_score(y_true, predictions_binary, zero_division=0)
    return auc, precision, recall, f1, predictions
# Evaluate the model on one dataset split and save its per-sample predictions.
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    """Compute metrics for one split and save its predictions.

    ``loss`` and ``accuracy`` come from ``model.evaluate``. This reads
    positions 0 and 1 of its result, which assumes the model was compiled
    with ``"accuracy"`` as its first metric, as ``train`` does. AUC,
    precision, recall and F1 come from ``compute_additional_metrics``.

    Args:
        model: A compiled model.
        X, Y: Inputs and 0/1 labels of the split.
        dataset_name: Split name, used as the predictions file prefix.
        save_dir: Output directory; created if missing.

    Returns:
        Dict with keys ``loss``, ``accuracy``, ``auc``, ``precision``,
        ``recall`` and ``f1_score``, all plain Python floats.

    Side effects:
        Writes ``<save_dir>/<dataset_name>_predictions.npz`` containing the
        arrays ``predictions`` and ``labels``.
    """
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    # Plain floats keep the dict JSON-serialisable (NumPy float32 is not).
    metrics = {
        'loss': float(eval_metrics[0]),
        'accuracy': float(eval_metrics[1]),
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
    }
    # Save the predictions for each sample next to the labels they score.
    os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(
        os.path.join(save_dir, f'{dataset_name}_predictions.npz'),
        predictions=predictions,
        labels=Y,
    )
    return metrics
# Evaluate the model on the train, validation and test splits and persist the results.
def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
    """Evaluate every split, print a summary table and save it as JSON.

    Splits are evaluated in the order train, validate, test via
    ``evaluate_dataset``, which also writes each split's predictions file.

    Args:
        model: A compiled model.
        train_X, train_Y: Training inputs and labels.
        validate_X, validate_Y: Validation inputs and labels.
        test_X, test_Y: Test inputs and labels.
        save_dir: Output directory; created if missing.

    Returns:
        ``{'train': {...}, 'validate': {...}, 'test': {...}}``, mapping each
        split name to its metrics dict.

    Side effects:
        Prints the metrics table and writes ``<save_dir>/evaluation_metrics.json``.
    """
    os.makedirs(save_dir, exist_ok=True)
    splits = {
        'train': (train_X, train_Y),
        'validate': (validate_X, validate_Y),
        'test': (test_X, test_Y),
    }
    metrics = {
        name: evaluate_dataset(model, X, Y, name, save_dir)
        for name, (X, Y) in splits.items()
    }

    # Display the metrics as a table: one row per split, one column per metric.
    metrics_df = pd.DataFrame(metrics).T
    print(metrics_df.to_string())

    metrics_path = os.path.join(save_dir, 'evaluation_metrics.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4)
    print(f"Evaluation metrics saved to {metrics_path}")
    return metrics
def main():
    """Command-line entry point: train, save and evaluate the MIL risk classifier.

    Reads ``train_X``/``train_Y``, ``validate_X``/``validate_Y`` and
    ``test_X``/``test_Y`` from the ``--data_file`` ``.npz`` archive and trains
    the model. It then saves the model to
    ``<save_dir>/risk_classifier_model.h5`` and writes evaluation metrics and
    predictions into ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training, validation and test data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
    # NOTE(review): --epochs is parsed but not forwarded to train(), which is
    # called with its own epoch count.
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    args = parser.parse_args()

    # All artefacts (checkpoint, final model, metrics, predictions) go here.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the preprocessed splits. Indexing an NpzFile returns fully loaded
    # arrays, so they stay valid after the archive is closed.
    with np.load(args.data_file) as data:
        train_X, train_Y = data['train_X'], data['train_Y']
        validate_X, validate_Y = data['validate_X'], data['validate_Y']
        test_X, test_Y = data['test_X'], data['test_Y']

    # Assumes train_X is (n_bags, max_length, n_features) — TODO confirm
    # against the preprocessing step that wrote the archive.
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir)

    # Persist the trained model before reporting success.
    final_model_path = os.path.join(args.save_dir, "risk_classifier_model.h5")
    trained_model.save(final_model_path)
    print(f"Model saved successfully to {final_model_path}")

    evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)


if __name__ == "__main__":
    main()