# omics-plip-1 / benchmark_train_resnet50.py
# (Hugging Face upload header: VatsalPatel18, "Upload 19 files", commit 70884da)
import os
import numpy as np
import tensorflow as tf
import json
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
import pandas as pd
# Function to compute additional metrics like AUC, Precision, Recall, and F1 Score
def compute_additional_metrics(generator, model):
    """Compute AUC, precision, recall, F1 and accuracy over a Keras generator.

    NOTE(review): ``generator.classes`` lists labels in fixed directory order,
    while ``model.predict(generator)`` follows the generator's iteration
    order. The two only line up when the generator was created with
    ``shuffle=False`` — confirm at every call site, otherwise all returned
    metrics are meaningless.

    Args:
        generator: A ``DirectoryIterator`` exposing integer labels via
            ``.classes``.
        model: Trained Keras model whose output is an (n_samples, 2) softmax.

    Returns:
        Tuple ``(auc, precision, recall, f1, accuracy, y_pred_prob)`` where
        ``y_pred_prob`` is the raw probability array from ``model.predict``.
    """
    y_true = generator.classes
    y_pred_prob = model.predict(generator)
    # Hard class decisions from the softmax probabilities.
    y_pred = np.argmax(y_pred_prob, axis=1)
    # Column 1 is P(class == 1); this assumes binary classification,
    # matching the 2-unit softmax head built in __main__.
    auc = roc_auc_score(y_true, y_pred_prob[:, 1])
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return auc, precision, recall, f1, accuracy, y_pred_prob
# Function to save evaluation metrics
def save_evaluation_metrics(generator, model, dataset_name, save_dir):
    """Evaluate ``model`` on ``generator``, persist predictions, return metrics.

    Writes a compressed ``.npz`` containing the raw prediction probabilities
    and the generator's ground-truth label vector, and returns the scalar
    metrics as a JSON-serializable dict.

    Args:
        generator: Evaluation ``DirectoryIterator``; must use ``shuffle=False``
            so predictions align with ``generator.classes``.
        model: Trained Keras model to evaluate.
        dataset_name: Tag used in the saved file name, e.g. "test".
        save_dir: Existing directory the ``.npz`` is written into.

    Returns:
        Dict with keys 'auc', 'precision', 'recall', 'f1_score', 'accuracy',
        all plain Python floats.
    """
    auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
    # Cast to native float so the caller's json.dump never trips over
    # numpy scalar types returned by sklearn.
    metrics = {
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'accuracy': float(accuracy)
    }
    # Save raw predictions alongside ground-truth labels for later analysis.
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'),
                        predictions=y_pred_prob, labels=generator.classes)
    return metrics
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate ResNet50 on benchmark datasets.')
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='Directory containing train, validate, test, and test2 directories.')
    parser.add_argument('--save_dir', type=str, default='./results/',
                        help='Directory to save the model and evaluation results.')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')
    args = parser.parse_args()

    train_dir = os.path.join(args.dataset_dir, 'train')
    validate_dir = os.path.join(args.dataset_dir, 'validate')
    test_dir = os.path.join(args.dataset_dir, 'test')
    test2_dir = os.path.join(args.dataset_dir, 'test2')
    os.makedirs(args.save_dir, exist_ok=True)

    # Set up ResNet50 model: ImageNet backbone + global pooling + 2-way softmax head.
    with tf.device('GPU:0'):
        resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet',
                                                input_shape=(224, 224, 3))
        last_layer = resnet.get_layer('conv5_block3_out')
        last_output = last_layer.output
        x = tf.keras.layers.GlobalAveragePooling2D()(last_output)
        x = tf.keras.layers.Dense(2, activation='softmax')(x)  # Assuming binary classification
        model = tf.keras.Model(inputs=resnet.input, outputs=x)
        model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy', 'Recall', 'Precision'])

        # Image data generators. Augmentation (horizontal flips) is applied to
        # training only; evaluation uses plain rescaling.
        train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
        validate_datagen = IDG(rescale=1/255.0)
        test_datagen = IDG(rescale=1/255.0)
        batch_size = 64

        # Shuffled, augmented stream used only for fitting.
        train_generator = train_datagen.flow_from_directory(
            train_dir, target_size=(224, 224),
            class_mode='categorical', batch_size=batch_size)
        # BUGFIX: evaluation generators must use shuffle=False. With the
        # flow_from_directory default (shuffle=True), model.predict() output
        # order does not match generator.classes, so every metric computed in
        # compute_additional_metrics() would be meaningless.
        validate_generator = validate_datagen.flow_from_directory(
            validate_dir, target_size=(224, 224),
            class_mode='categorical', batch_size=batch_size, shuffle=False)
        test_generator = test_datagen.flow_from_directory(
            test_dir, target_size=(224, 224),
            class_mode='categorical', batch_size=batch_size, shuffle=False)
        test2_generator = test_datagen.flow_from_directory(
            test2_dir, target_size=(224, 224),
            class_mode='categorical', batch_size=batch_size, shuffle=False)
        # Separate deterministic, non-augmented view of the training set,
        # used only for computing training-set metrics after fitting.
        train_eval_generator = test_datagen.flow_from_directory(
            train_dir, target_size=(224, 224),
            class_mode='categorical', batch_size=batch_size, shuffle=False)

        # Training the model (generator-based fit; Keras ignores the `shuffle`
        # argument for generators, so it is omitted).
        hist = model.fit(train_generator, epochs=args.epochs,
                         validation_data=validate_generator, verbose=1)

        # Save the trained model.
        model.save(os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))

        # Save training history separately. Keras names the string metrics
        # 'Recall'/'Precision' as lowercase 'recall'/'precision' in history.
        training_log = {
            'loss': hist.history['loss'],
            'val_loss': hist.history['val_loss'],
            'accuracy': hist.history['accuracy'],
            'val_accuracy': hist.history['val_accuracy'],
            'recall': hist.history['recall'],
            'val_recall': hist.history['val_recall'],
            'precision': hist.history['precision'],
            'val_precision': hist.history['val_precision']
        }
        with open(os.path.join(args.save_dir, 'resnet_training_log.json'), 'w') as f:
            json.dump(training_log, f)

        # Evaluate the model on each dataset (non-shuffled generators only)
        # and save per-dataset predictions + metrics.
        train_metrics = save_evaluation_metrics(train_eval_generator, model, "train", args.save_dir)
        validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
        test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
        test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)

        # Save the evaluation metrics in a JSON file.
        evaluation_metrics = {
            'train_metrics': train_metrics,
            'validate_metrics': validate_metrics,
            'test_metrics': test_metrics,
            'test2_metrics': test2_metrics
        }
        with open(os.path.join(args.save_dir, 'resnet_evaluation_metrics.json'), 'w') as f:
            json.dump(evaluation_metrics, f)

        print("Training and evaluation metrics saved.")