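"""Evaluate a trained Keras model on a secondary test dataset (test2).

Builds one bag of pre-extracted tile features per patient, pads each bag to a
fixed length, evaluates the saved model, and writes the metrics (loss, accuracy,
AUC, precision, recall, F1) and the raw predictions to the output directory.
"""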
import argparse
import json
import os

import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score


def load_and_preprocess_data(metadata_file, data_dir):
    """Load the test2 metadata, keep classes 1 and 3 as binary labels, and index by PatientID."""
    dff = pd.read_csv(metadata_file)
    # Drop the index column that pandas writes when a DataFrame is saved with to_csv().
    if 'Unnamed: 0' in dff.columns:
        del dff['Unnamed: 0']

    # Keep only classes 1 and 3 and map them to binary labels (1 -> 1, 3 -> 0).
    # .copy() avoids pandas' SettingWithCopyWarning when remapping the column below.
    classified_df = dff[dff['Class'].isin([1, 3])].copy()
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df = classified_df.set_index('PatientID')

    # Keep only patients that actually have a feature folder on disk, then shuffle.
    available_patients = set(os.listdir(data_dir))
    df = df.loc[df.index.intersection(available_patients)]
    df = df.sample(frac=1)

    return df


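# NOTE (assumption, inferred from create_bags below): the feature directory is
# expected to be laid out as data_dir/<PatientID>/<tile file>, where each tile
# file is a torch-saved tensor of shape (1, feature_dim). create_bags() stacks
# the tiles of one patient into a single (num_tiles, feature_dim) bag.

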
def create_bags(df, data_dir):
    """Build one bag of tile features per patient, together with its binary label."""
    data = {'test2': {'X': [], 'Y': []}}
    for pID, row in df.iterrows():
        fol_p = os.path.join(data_dir, pID)
        tiles = os.listdir(fol_p)
        tile_data = []
        for tile in tiles:
            tile_p = os.path.join(fol_p, tile)
            # Load the (1, feature_dim) feature tensor on CPU and convert it to numpy.
            np1 = torch.load(tile_p, map_location='cpu').numpy()
            tile_data.append(np1)
        # Stack the tiles into a (num_tiles, feature_dim) bag.
        bag = np.squeeze(np.array(tile_data), axis=1)
        bag_label = row['Class']
        data['test2']['X'].append(bag)
        data['test2']['Y'].append(np.array([bag_label]))
    # Bags contain different numbers of tiles, so store X as an object array;
    # the bags are padded to a common length later by prepare_data_with_padding().
    data['test2']['X'] = np.array(data['test2']['X'], dtype=object)
    data['test2']['Y'] = np.array(data['test2']['Y'])
    print(f"Data[test2]['X'] shape: {data['test2']['X'].shape}, dtype: {data['test2']['X'].dtype}")
    return data


def prepare_data_with_padding(data, max_length=2000):
    """Zero-pad (or truncate) every bag to exactly max_length tiles."""
    padded_data = []
    for bag in data:
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            padded_bag = np.vstack((bag, padding))
        else:
            # Truncate oversized bags so all bags share the same shape.
            padded_bag = bag[:max_length]
        padded_data.append(padded_bag)
    return np.array(padded_data)


def compute_additional_metrics(X, Y, model):
    """Compute AUC, precision, recall, and F1 from the model's predictions."""
    predictions = model.predict(X).flatten()
    # Threshold the predicted probabilities at 0.5 for the class-based metrics.
    predictions_binary = (predictions > 0.5).astype(int)
    auc = roc_auc_score(Y, predictions)
    precision = precision_score(Y, predictions_binary)
    recall = recall_score(Y, predictions_binary)
    f1 = f1_score(Y, predictions_binary)
    return auc, precision, recall, f1, predictions


def evaluate_dataset(model, X, Y, dataset_name, save_dir):
    """Evaluate the model on one dataset, save its raw predictions, and return the metrics."""
    # Keras reports [loss, accuracy]; the remaining metrics come from sklearn.
    eval_metrics = model.evaluate(X, Y, verbose=0)
    auc, precision, recall, f1, predictions = compute_additional_metrics(X, Y, model)

    # Cast everything to plain floats so the dictionary serializes cleanly to JSON.
    metrics = {
        'loss': float(eval_metrics[0]),
        'accuracy': float(eval_metrics[1]),
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1)
    }

    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'),
                        predictions=predictions, labels=Y)

    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate a trained model on a secondary test dataset (test2).')
    parser.add_argument('--metadata_file', type=str, required=True, help='Path to the metadata CSV file for test2.')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted tissue features.')
    parser.add_argument('--model_path', type=str, required=True, help='Path to the saved model file.')
    parser.add_argument('--save_dir', type=str, default='./evaluation_results_test2/', help='Directory to save evaluation results.')

    args = parser.parse_args()
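    # Example invocation (hypothetical script name and paths):
    #   python evaluate_test2.py --metadata_file metadata_test2.csv \
    #       --data_dir features/test2/ --model_path model.h5 \
    #       --save_dir ./evaluation_results_test2/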

    # Create the output directory if it does not already exist.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the test2 metadata and build one padded bag of tile features per patient.
    df_test2 = load_and_preprocess_data(args.metadata_file, args.data_dir)
    data_test2 = create_bags(df_test2, args.data_dir)

    test2_X = prepare_data_with_padding(data_test2['test2']['X'], max_length=2000)
    test2_Y = np.array(data_test2['test2']['Y']).flatten()

    # Load the trained model and evaluate it on test2.
    model = tf.keras.models.load_model(args.model_path)
    test2_metrics = evaluate_dataset(model, test2_X, test2_Y, "test2", args.save_dir)

    # Save the evaluation metrics to JSON.
    with open(os.path.join(args.save_dir, 'evaluation_metrics_test2.json'), 'w') as f:
        json.dump(test2_metrics, f, indent=4)

    print("Evaluation metrics saved to evaluation_metrics_test2.json")
|