import argparse
import json
import os

import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG

# Compute AUC, precision, recall, F1 score, and accuracy for a generator.
# NOTE: the generator must be created with shuffle=False so that
# generator.classes stays aligned with the row order of model.predict().
def compute_additional_metrics(generator, model):
    y_true = generator.classes
    y_pred_prob = model.predict(generator)
    y_pred = np.argmax(y_pred_prob, axis=1)
    auc = roc_auc_score(y_true, y_pred_prob[:, 1])  # probability of the positive class (binary setup)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)
    return auc, precision, recall, f1, accuracy, y_pred_prob

# Evaluate a generator, save the raw predictions, and return the metrics.
def save_evaluation_metrics(generator, model, dataset_name, save_dir):
    auc, precision, recall, f1, accuracy, y_pred_prob = compute_additional_metrics(generator, model)
    # Cast to plain Python floats so the dict serializes cleanly with json.dump
    metrics = {
        'auc': float(auc),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'accuracy': float(accuracy)
    }
    # Save the predicted probabilities alongside the true labels
    np.savez_compressed(os.path.join(save_dir, f'{dataset_name}_predictions.npz'),
                        predictions=y_pred_prob, labels=generator.classes)
    return metrics
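
# Sketch of how the saved predictions could be reloaded for later analysis
# (hypothetical follow-up step, not executed by this script):
#   data = np.load(os.path.join(save_dir, 'test_predictions.npz'))
#   y_pred_prob, y_true = data['predictions'], data['labels']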

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate ResNet50 on benchmark datasets.')
    parser.add_argument('--dataset_dir', type=str, required=True, help='Directory containing train, validate, test, and test2 directories.')
    parser.add_argument('--save_dir', type=str, default='./results/', help='Directory to save the model and evaluation results.')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.')

    args = parser.parse_args()
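    # Example invocation (script name and paths are hypothetical):
    #   python train_eval_resnet.py --dataset_dir ./data --save_dir ./results --epochs 10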

    train_dir = os.path.join(args.dataset_dir, 'train')
    validate_dir = os.path.join(args.dataset_dir, 'validate')
    test_dir = os.path.join(args.dataset_dir, 'test')
    test2_dir = os.path.join(args.dataset_dir, 'test2')

    os.makedirs(args.save_dir, exist_ok=True)
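
    # Expected layout, one subdirectory per class as required by
    # flow_from_directory (class names below are illustrative):
    #   <dataset_dir>/train/<class_0>/*.jpg
    #   <dataset_dir>/train/<class_1>/*.jpg
    #   ... and likewise for validate/, test/, test2/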

    # Set up the ResNet50 model: ImageNet backbone plus a new classification head.
    # Assumes a GPU is present; remove the tf.device context to let TensorFlow
    # place ops automatically.
    with tf.device('GPU:0'):
        resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
        x = tf.keras.layers.GlobalAveragePooling2D()(resnet.output)  # resnet.output is conv5_block3_out
        x = tf.keras.layers.Dense(2, activation='softmax')(x)  # binary classification with one-hot labels
        model = tf.keras.Model(inputs=resnet.input, outputs=x)
        # Instantiate Recall/Precision explicitly so the history keys used below
        # ('recall', 'val_recall', ...) are guaranteed to match
        model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy',
                               tf.keras.metrics.Recall(name='recall'),
                               tf.keras.metrics.Precision(name='precision')])

        # Image data generators. Evaluation generators must use shuffle=False so
        # that generator.classes stays aligned with prediction order, and no
        # augmentation so the metrics reflect the raw images.
        train_datagen = IDG(rescale=1/255.0, horizontal_flip=True)
        eval_datagen = IDG(rescale=1/255.0)

        batch_size = 64

        train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                            class_mode='categorical', batch_size=batch_size)
        # Separate non-shuffled, non-augmented view of the training set for evaluation
        train_eval_generator = eval_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                                class_mode='categorical', batch_size=batch_size,
                                                                shuffle=False)
        validate_generator = eval_datagen.flow_from_directory(validate_dir, target_size=(224, 224),
                                                              class_mode='categorical', batch_size=batch_size,
                                                              shuffle=False)
        test_generator = eval_datagen.flow_from_directory(test_dir, target_size=(224, 224),
                                                          class_mode='categorical', batch_size=batch_size,
                                                          shuffle=False)
        test2_generator = eval_datagen.flow_from_directory(test2_dir, target_size=(224, 224),
                                                           class_mode='categorical', batch_size=batch_size,
                                                           shuffle=False)
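
        # Optional sanity check: class-index mappings are derived from sorted
        # subdirectory names, so they should agree across all splits
        assert train_generator.class_indices == test_generator.class_indices == test2_generator.class_indices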

        # Train the model. The shuffle argument is omitted: it is ignored when
        # fitting from a generator, which already shuffles its batches.
        hist = model.fit(train_generator, epochs=args.epochs, validation_data=validate_generator, verbose=1)

        # Save the trained model
        model.save(os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))
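        # To reload the trained model later (sketch):
        #   model = tf.keras.models.load_model(
        #       os.path.join(args.save_dir, 'risk_classifier_resnet_model.hdf5'))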

        # Save training history separately
        training_log = {
            'loss': hist.history['loss'],
            'val_loss': hist.history['val_loss'],
            'accuracy': hist.history['accuracy'],
            'val_accuracy': hist.history['val_accuracy'],
            'recall': hist.history['recall'],
            'val_recall': hist.history['val_recall'],
            'precision': hist.history['precision'],
            'val_precision': hist.history['val_precision']
        }
        with open(os.path.join(args.save_dir, 'resnet_training_log.json'), 'w') as f:
            json.dump(training_log, f)

        # Evaluate the model on each split (using the non-shuffled generators) and save metrics
        train_metrics = save_evaluation_metrics(train_eval_generator, model, "train", args.save_dir)
        validate_metrics = save_evaluation_metrics(validate_generator, model, "validate", args.save_dir)
        test_metrics = save_evaluation_metrics(test_generator, model, "test", args.save_dir)
        test2_metrics = save_evaluation_metrics(test2_generator, model, "test2", args.save_dir)

        # Save the evaluation metrics in a JSON file
        evaluation_metrics = {
            'train_metrics': train_metrics,
            'validate_metrics': validate_metrics,
            'test_metrics': test_metrics,
            'test2_metrics': test2_metrics
        }

        with open(os.path.join(args.save_dir, 'resnet_evaluation_metrics.json'), 'w') as f:
            json.dump(evaluation_metrics, f)

        print("Training and evaluation metrics saved.")