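"""Train and evaluate a simple multiple instance learning (MIL) classifier on risk data.

Expects a preprocessed .npz file (see --data_file) containing the arrays
train_X/train_Y, validate_X/validate_Y, and test_X/test_Y, where each X has
shape (num_bags, max_length, num_features) and each Y holds binary bag labels.
"""
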
import argparse
import json
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

# Build the multiple instance learning (MIL) model: shared per-instance dense layers
# (applied via TimeDistributed), mean pooling across the bag, and a sigmoid bag-level head
def create_simple_model(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)
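
# Illustrative usage (hypothetical sizes, not executed): for bags of 20 instances
# with 128 features each, the model maps a (batch, 20, 128) input to one
# probability per bag:
#   model = create_simple_model(instance_shape=(128,), max_length=20)
#   model.summary()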

# Compute balanced class weights: inverse class frequency, scaled so the
# per-sample weights average to 1 across the dataset
def compute_class_weights(labels):
    negative_count = int(np.sum(labels == 0))
    positive_count = int(np.sum(labels == 1))
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
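
# Worked example (hypothetical counts): with 90 negative and 10 positive bags,
# the weights are {0: (1/90)*(100/2) ≈ 0.56, 1: (1/10)*(100/2) = 5.0}, so each
# class contributes equally to the loss on average.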

# Function to generate batches of data
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights
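
# Note: the generator cycles forever and yields (inputs, targets, sample_weights)
# triples; Keras applies the third element as per-sample loss weights, and the
# steps_per_epoch / validation_steps arguments below bound each epoch to one pass.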

# Learning rate scheduler
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr
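
# Example schedule (assuming Adam's default initial lr of 1e-3): epochs 0-9 run
# at 1e-3, epoch 10 drops the rate to 1e-4, epoch 20 to 1e-5, and so on.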

# Train the model with checkpointing, early stopping, and step-decay learning rate
def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
    model_path = os.path.join(save_dir, "best_model.h5")
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    # restore_best_weights keeps the best-validation weights on the returned model, matching the checkpoint on disk
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min", restore_best_weights=True)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    # batch_size is omitted here: Keras rejects it when fitting from a generator, which already yields one bag per step
    model.fit(
        train_gen,
        steps_per_epoch=len(train_data),
        validation_data=val_gen,
        validation_steps=len(val_data),
        epochs=epochs,
        callbacks=[early_stopping, model_checkpoint, lr_callback],
        verbose=1,
    )
    return model

# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(X, Y, model):
    predictions = model.predict(X).flatten()
    predictions_binary = (predictions > 0.5).astype(int)  # Convert probabilities to class labels (0 or 1)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions
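
# Note: the 0.5 threshold is a default choice; AUC is computed from the raw
# probabilities, so it is unaffected by the thresholding used for precision,
# recall, and F1.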

# Function to evaluate the model on a given dataset
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)
    # Cast to plain Python floats: sklearn and Keras return numpy scalars,
    # which json.dump cannot serialize
    metrics = {
        'loss': float(eval_metrics[0]),
        'accuracy': float(eval_metrics[1]),
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1)
    }
    
    # Save the predictions for each sample
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
    
    return metrics

# Function to evaluate the model on train, validate, and test datasets
def evaluate_all_datasets(model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, save_dir):
    train_metrics = evaluate_dataset(model, train_X, train_Y, "train", save_dir)
    validate_metrics = evaluate_dataset(model, validate_X, validate_Y, "validate", save_dir)
    test_metrics = evaluate_dataset(model, test_X, test_Y, "test", save_dir)

    metrics = {
        'train': train_metrics,
        'validate': validate_metrics,
        'test': test_metrics
    }

    # Display the metrics in a tabular format
    metrics_df = pd.DataFrame(metrics).T
    print(metrics_df.to_string())

    # Save metrics to a JSON file
    with open(os.path.join(save_dir, 'evaluation_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics.json")
    
    return metrics

if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with train, validate, and test arrays.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model and evaluation metrics.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')

    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # Load the preprocessed data
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']
    test_X, test_Y = data['test_X'], data['test_Y']

    # Create the model
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    # Train the model
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)

    # Save the final model after training
    final_model_path = os.path.join(args.save_dir, "risk_classifier_model.h5")
    trained_model.save(final_model_path)
    print(f"Model saved successfully to {final_model_path}")

    # Evaluate the model
    metrics = evaluate_all_datasets(trained_model, train_X, train_Y, validate_X, validate_Y, test_X, test_Y, args.save_dir)
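
# Example invocation (hypothetical file names):
#   python train_mil_classifier.py --data_file ./preprocessed/risk_data.npz \
#       --save_dir ./model_save/ --epochs 50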