# train_risk_classifier.py

import argparse
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# Baseline multiple instance learning (MIL) model: a shared per-instance MLP
# followed by mean pooling over the bag of instances.
def create_simple_model2(instance_shape, max_length):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    # Apply the same feature extractor to every instance in the bag.
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    # Mean-pool the instance embeddings into a single bag-level embedding.
    aggregated = layers.GlobalAveragePooling1D()(dropout_2)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)

# Attention-based MIL model: instances attend to one another through
# multi-head self-attention before the bag embedding is pooled.
def create_simple_model(instance_shape, max_length, num_heads=4, key_dim=64):
    inputs = layers.Input(shape=(max_length, instance_shape[-1]), name="bag_input")
    flatten = layers.TimeDistributed(layers.Flatten())(inputs)
    dense_1 = layers.TimeDistributed(layers.Dense(256, activation="relu"))(flatten)
    dropout_1 = layers.TimeDistributed(layers.Dropout(0.5))(dense_1)
    dense_2 = layers.TimeDistributed(layers.Dense(64, activation="relu"))(dropout_1)
    dropout_2 = layers.TimeDistributed(layers.Dropout(0.5))(dense_2)
    # Self-attention over the instances; the scores are returned but unused here.
    attention_output, attention_scores = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        value_dim=64,
        dropout=0.1,
        use_bias=True,
    )(query=dropout_2, value=dropout_2, key=dropout_2, return_attention_scores=True)
    aggregated = layers.GlobalAveragePooling1D()(attention_output)
    norm_1 = layers.LayerNormalization()(aggregated)
    output = layers.Dense(1, activation="sigmoid")(norm_1)
    return Model(inputs, output)
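
# Quick sanity check for the attention model (the feature size of 512 and the
# bag length of 50 are illustrative assumptions, not values the script fixes):
#
#     model = create_simple_model(instance_shape=(512,), max_length=50)
#     model.summary()  # final layer shape (None, 1): one risk probability per bag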

# Compute balanced class weights: weight_c = total_count / (2 * count_c),
# so each class contributes equally to the loss.
def compute_class_weights(labels):
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count
    return {0: (1 / negative_count) * (total_count / 2), 1: (1 / positive_count) * (total_count / 2)}
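
# Worked example of the weighting above (counts are illustrative): with 80
# negative and 20 positive labels, total_count = 100, so the weights are
# (1/80) * 50 = 0.625 for class 0 and (1/20) * 50 = 2.5 for class 1; each
# class then contributes an equal 50 "effective" samples to the loss. Note
# that the formula assumes both classes are present in `labels`, since it
# divides by the per-class counts.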

# Generator that cycles through the data forever, yielding
# (bag, label, sample_weight) tuples one batch at a time.
def data_generator(data, labels, batch_size=1):
    class_weights = compute_class_weights(labels)
    while True:
        for i in range(0, len(data), batch_size):
            batch_data = np.array(data[i:i + batch_size], dtype=np.float32)
            batch_labels = np.array(labels[i:i + batch_size], dtype=np.float32)
            # The third tuple element is a per-sample weight, which is how
            # class weighting reaches Keras through a generator.
            batch_weights = np.array([class_weights[int(label)] for label in batch_labels], dtype=np.float32)
            yield batch_data, batch_labels, batch_weights
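
# What the generator yields with the default batch_size of 1 (the bag length
# of 50 and feature size of 512 are illustrative assumptions):
#
#     gen = data_generator(train_X, train_Y)
#     x, y, w = next(gen)
#     # x.shape == (1, 50, 512); y.shape == (1,); w.shape == (1,)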

# Step-decay learning rate schedule: multiply the LR by 0.1 every 10 epochs.
def lr_scheduler(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0 and epoch:
        return lr * decay_rate
    return lr
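
# With Adam's default initial learning rate of 1e-3, this schedule gives
# 1e-3 for epochs 0-9, 1e-4 for epochs 10-19, 1e-5 for epochs 20-29, etc.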

# Train the model with best-checkpoint saving, early stopping, and LR decay.
def train(train_data, train_labels, val_data, val_labels, model, save_dir, epochs=50):
    model_path = os.path.join(save_dir, "risk_classifier_model.h5")
    # Keep only the weights that achieve the best validation loss.
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, monitor="val_loss", verbose=1, mode="min", save_best_only=True, save_weights_only=False)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, mode="min")
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy", "AUC"])
    train_gen = data_generator(train_data, train_labels)
    val_gen = data_generator(val_data, val_labels)
    # The generators yield one bag per step, so steps_per_epoch is the number
    # of bags. batch_size must not be passed to fit() alongside a generator.
    model.fit(train_gen, steps_per_epoch=len(train_data), validation_data=val_gen, validation_steps=len(val_data), epochs=epochs, callbacks=[early_stopping, model_checkpoint, lr_callback], verbose=1)
    return model
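
# The script expects a .npz archive holding the keys train_X, train_Y,
# validate_X, and validate_Y. A minimal sketch of producing such a file
# with random data (the shapes are illustrative assumptions):
#
#     import numpy as np
#     np.savez(
#         "risk_data.npz",
#         train_X=np.random.rand(100, 50, 512).astype(np.float32),  # (bags, instances, features)
#         train_Y=np.random.randint(0, 2, size=100).astype(np.float32),
#         validate_X=np.random.rand(20, 50, 512).astype(np.float32),
#         validate_Y=np.random.randint(0, 2, size=20).astype(np.float32),
#     )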

if __name__ == "__main__":
    # Command-line arguments.
    parser = argparse.ArgumentParser(description='Train a multiple instance learning classifier on risk data.')
    parser.add_argument('--data_file', type=str, required=True, help='Path to the saved .npz file with training and validation data.')
    parser.add_argument('--save_dir', type=str, default='./model_save/', help='Directory to save the model.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # Load the preprocessed data.
    data = np.load(args.data_file)
    train_X, train_Y = data['train_X'], data['train_Y']
    validate_X, validate_Y = data['validate_X'], data['validate_Y']

    # Derive the model dimensions from the data: each bag holds max_length
    # instances with train_X.shape[-1] features apiece.
    instance_shape = (train_X.shape[-1],)
    max_length = train_X.shape[1]
    model = create_simple_model(instance_shape, max_length)

    # Train for the requested number of epochs (previously --epochs was
    # parsed but ignored in favor of a hardcoded value).
    trained_model = train(train_X, train_Y, validate_X, validate_Y, model, args.save_dir, epochs=args.epochs)

    print(f"Model saved successfully to {os.path.join(args.save_dir, 'risk_classifier_model.h5')}")
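
# Example invocation (paths are illustrative):
#     python train_risk_classifier.py --data_file risk_data.npz --save_dir ./model_save/ --epochs 50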