import os
import numpy as np
import pandas as pd
import torch
import json
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
import argparse
# Function to load and preprocess the dataset
def load_and_preprocess_data(metadata_file, data_dir, random_state=None):
    """Load test metadata and keep binary-labelled patients that have feature data.

    Rows whose ``Class`` is 1 or 3 are kept and relabelled to a binary target
    (1 -> 1, 3 -> 0); all other classes are dropped. Only patients whose
    ``PatientID`` equals an entry name in ``data_dir`` are retained, and the
    surviving rows are shuffled.

    Args:
        metadata_file: Path to the metadata CSV; must contain ``PatientID`` and
            ``Class`` columns.
        data_dir: Directory whose entries are named after patient IDs.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle (and hence the order of saved predictions) is reproducible.
            ``None`` keeps the previous unseeded behaviour.

    Returns:
        Shuffled DataFrame indexed by ``PatientID`` (as strings) with the
        binary ``Class`` column.
    """
    dff = pd.read_csv(metadata_file)
    # Drop the leftover index column produced when a frame was saved with index=True.
    dff = dff.drop(columns=['Unnamed: 0'], errors='ignore')

    # .copy() so the relabelling writes to an independent frame rather than a
    # view of dff (avoids SettingWithCopyWarning / silently lost writes).
    classified_df = dff[dff['Class'].isin([1, 3])].copy()
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df = classified_df.set_index('PatientID')
    # Directory entries are strings; cast IDs so numeric PatientIDs still match.
    df.index = df.index.astype(str)

    # Keep only patients that have a matching entry in data_dir.
    df = df[df.index.isin(os.listdir(data_dir))]
    return df.sample(frac=1, random_state=random_state)
# Function to create bags of tiles
def _stack_bags(bags):
    """Combine per-patient bags into one array, tolerating different bag lengths."""
    if not bags:
        return np.array([])
    if all(b.shape == bags[0].shape for b in bags):
        return np.stack(bags)
    # Ragged bags: NumPy refuses to build a ragged numeric array implicitly, so
    # store each bag as one element of a 1-D object array instead.
    out = np.empty(len(bags), dtype=object)
    for i, b in enumerate(bags):
        out[i] = b
    return out


def create_bags(df, data_dir):
    """Build one bag of tile features per patient for multiple-instance evaluation.

    Each directory ``data_dir/<PatientID>`` is read as a set of ``torch.save``-d
    tile tensors. Tiles are loaded in sorted filename order, stacked and squeezed
    along axis 1, so each tile tensor is assumed to be ``(1, feature_dim)``
    (TODO confirm against the feature extractor). Patients with no tiles are
    skipped with a message.

    Args:
        df: DataFrame indexed by patient ID with a ``Class`` column.
        data_dir: Root directory holding one sub-directory per patient.

    Returns:
        ``{'test2': {'X': bags, 'Y': labels}}`` with ``labels`` a 1-D array
        aligned with ``bags``. ``bags`` has shape
        ``(n_patients, n_tiles, feature_dim)`` when all bags share a length,
        otherwise it is a 1-D object array of ``(n_tiles_i, feature_dim)`` arrays.
    """
    bags, labels = [], []
    for pID, row in df.iterrows():
        fol_p = os.path.join(data_dir, str(pID))
        # os.listdir order is arbitrary; sort so the bag contents are reproducible.
        tile_files = sorted(os.listdir(fol_p))
        if not tile_files:
            print(f"Skipping patient {pID}: no tiles found in {fol_p}")
            continue
        # NOTE(review): torch.load unpickles its input -- only use on trusted feature files.
        # map_location='cpu' lets tensors saved on a GPU be converted with .numpy().
        tile_data = [
            torch.load(os.path.join(fol_p, tile), map_location='cpu').numpy()
            for tile in tile_files
        ]
        bags.append(np.squeeze(np.stack(tile_data), axis=1))
        labels.append(row['Class'])

    data = {'test2': {'X': _stack_bags(bags), 'Y': np.array(labels)}}
    print(f"Data[test2]['X'] shape: {data['test2']['X'].shape}, dtype: {data['test2']['X'].dtype}")
    return data
# Function to pad the data to ensure uniform bag length
def prepare_data_with_padding(data, max_length=2000):
    """Pad or truncate every bag to exactly ``max_length`` rows and stack them.

    Args:
        data: Iterable of 2-D arrays shaped ``(n_tiles, feature_dim)`` that all
            share ``feature_dim``.
        max_length: Row count of every output bag. Shorter bags are zero-padded
            at the end; longer bags keep only their first ``max_length`` rows so
            that the result can be stacked into one regular array.

    Returns:
        Array of shape ``(len(data), max_length, feature_dim)`` (an empty 1-D
        array when ``data`` is empty).
    """
    padded_data = []
    for bag in data:
        bag = np.asarray(bag)
        n_missing = max_length - len(bag)
        if n_missing > 0:
            # Match the bag's dtype so float32 features are not upcast to float64.
            padding = np.zeros((n_missing, bag.shape[1]), dtype=bag.dtype)
            bag = np.vstack((bag, padding))
        else:
            bag = bag[:max_length]
        padded_data.append(bag)
    return np.array(padded_data)
# Function to compute additional metrics using sklearn
def compute_additional_metrics(X, Y, model, threshold=0.5):
    """Compute ROC AUC, precision, recall and F1 for a binary classifier.

    Args:
        X: Inputs accepted by ``model.predict``.
        Y: Binary ground-truth labels (0/1), one per sample.
        model: Object whose ``predict`` returns one positive-class score per
            sample (any shape that flattens to ``len(Y)`` values).
        threshold: Score above which a sample is predicted as class 1.

    Returns:
        ``(auc, precision, recall, f1, predictions)`` where ``predictions`` are
        the flattened raw scores. ``auc`` is NaN when ``Y`` contains a single
        class, since ROC AUC is undefined there.
    """
    predictions = np.asarray(model.predict(X)).flatten()
    predictions_binary = (predictions > threshold).astype(int)

    # roc_auc_score raises ValueError with only one class present; report NaN
    # instead of aborting the whole evaluation.
    if len(np.unique(Y)) < 2:
        auc = float('nan')
    else:
        auc = roc_auc_score(Y, predictions)

    # zero_division=0 returns the same 0.0 as the default, without the warning.
    precision = precision_score(Y, predictions_binary, zero_division=0)
    recall = recall_score(Y, predictions_binary, zero_division=0)
    f1 = f1_score(Y, predictions_binary, zero_division=0)
    return auc, precision, recall, f1, predictions
# Function to evaluate the model on a given dataset using sklearn metrics
def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    """Evaluate ``model`` on one dataset and save its per-sample predictions.

    Loss (and accuracy, if compiled) come from ``model.evaluate``. AUC,
    precision, recall and F1 come from ``compute_additional_metrics``. Raw
    predictions and labels are written to
    ``<save_dir>/<dataset_name>_predictions.npz``.

    Args:
        model: Compiled model exposing ``evaluate`` and ``predict``.
        X: Model inputs.
        Y: Binary ground-truth labels.
        dataset_name: Prefix for the saved predictions file.
        save_dir: Output directory; created if it does not exist.

    Returns:
        Dict with keys ``loss``, ``accuracy``, ``auc``, ``precision``,
        ``recall``, ``f1_score`` holding plain Python floats, ready for
        ``json.dump``. Undefined values (e.g. NaN AUC, or accuracy when
        ``evaluate`` returns only the loss) are ``None``.
    """
    # Assumes the model was compiled with metrics=['accuracy'], so evaluate()
    # returns [loss, accuracy]; atleast_1d also copes with a scalar loss-only result.
    eval_metrics = np.atleast_1d(model.evaluate(X, Y, verbose=0))
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)

    metrics = {
        'loss': eval_metrics[0],
        'accuracy': eval_metrics[1] if len(eval_metrics) > 1 else None,
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
    }
    # json cannot serialise e.g. np.float32, and NaN is not valid JSON:
    # cast to float and map NaN to None.
    metrics = {
        key: None if value is None or np.isnan(value) else float(value)
        for key, value in metrics.items()
    }

    os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'), predictions=predictions, labels=Y)
    return metrics
if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description='Evaluate a trained model on a secondary test dataset (test2).')
    parser.add_argument('--metadata_file', type=str, required=True, help='Path to the metadata CSV file for test2.')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted tissue features.')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
    parser.add_argument('--save_dir', type=str, default='./evaluation_results_test2/', help='Directory to save evaluation results.')
    parser.add_argument('--max_length', type=int, default=2000, help='Number of tiles each bag is padded/truncated to.')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race and is a no-op if the directory exists.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load and preprocess the test2 data
    df_test2 = load_and_preprocess_data(args.metadata_file, args.data_dir)
    if df_test2.empty:
        raise SystemExit("No test2 patients with class 1/3 and matching feature directories were found.")
    data_test2 = create_bags(df_test2, args.data_dir)

    # Pad every bag to the fixed length the model expects.
    test2_X = prepare_data_with_padding(data_test2['test2']['X'], max_length=args.max_length)
    test2_Y = np.array(data_test2['test2']['Y']).flatten()

    model = tf.keras.models.load_model(args.model_path)
    test2_metrics = evaluate_dataset(model, test2_X, test2_Y, "test2", args.save_dir)

    metrics_path = os.path.join(args.save_dir, 'evaluation_metrics_test2.json')
    with open(metrics_path, 'w') as f:
        # default=float converts any remaining NumPy scalar (e.g. np.float32) to a JSON number.
        json.dump(test2_metrics, f, indent=4, default=float)
    print(f"Evaluation metrics saved to {metrics_path}")