import evaluate
from datasets import Features, Value
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix


_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
         and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
         and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
         Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""

_DESCRIPTION = """
This evaluator computes multiple classification metrics to assess the performance of a model. Metrics calculated include:
- Accuracy: The proportion of correct predictions among the total number of cases processed. Computed as (TP + TN) / (TP + TN + FP + FN), where TP, TN, FP, and FN denote true positives, true negatives, false positives, and false negatives respectively.
- Precision, Recall, and F1-Score: Reported as macro averages (the unweighted mean of the per-class scores) and micro averages (computed globally from the total counts of true positives, false positives, and false negatives across all classes).
- Confusion Matrix: A matrix whose entry in row i and column j counts the samples with true label i that were predicted as label j.

"""

_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted labels.
    references (`list` of `str`): Ground truth labels.

Returns:
    Dict containing:
        accuracy (float): Proportion of correct predictions. Value ranges between 0 (worst) and 1 (best).
        precision_macro (float), recall_macro (float), f1_macro (float): Macro averages of precision, recall, and F1-score respectively.
        precision_micro (float), recall_micro (float), f1_micro (float): Micro averages of precision, recall, and F1-score respectively.
        confusion_matrix (list of lists): 2D list representing the confusion matrix of the classification results.
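
Examples (illustrative sketch, not from the original module; assumes the
evaluator class defined below is instantiated directly rather than loaded
through evaluate.load):
    >>> classification_metric = ClassificationEvaluator()
    >>> results = classification_metric.compute(
    ...     predictions=["cat", "dog", "cat"], references=["cat", "dog", "dog"]
    ... )
    >>> round(results["accuracy"], 2)
    0.67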
"""


class ClassificationEvaluator(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=Features(
                {"predictions": Value("string"), "references": Value("string")}
            ),
        )

    def _compute(self, predictions, references):
        # Overall accuracy: the fraction of predictions that exactly match the references
        accuracy = float(accuracy_score(references, predictions))

        # Calculate macro and micro averages for precision, recall, and F1-score
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            references, predictions, average='macro'
        )
        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            references, predictions, average='micro'
        )

        # Calculate the confusion matrix
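        # (rows correspond to true labels and columns to predicted labels,
        # both following sklearn's sorted label order)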
        conf_matrix = confusion_matrix(references, predictions)

        return {
            "accuracy": accuracy,
            "precision_macro": float(precision_macro),
            "recall_macro": float(recall_macro),
            "f1_macro": float(f1_macro),
            "precision_micro": float(precision_micro),
            "recall_micro": float(recall_micro),
            "f1_micro": float(f1_micro),
            "confusion_matrix": conf_matrix.tolist()
        }
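

# Minimal usage sketch (an assumption for illustration; the original file ends
# with the class definition above): instantiate the evaluator directly and
# compute all metrics on a few toy labels.
if __name__ == "__main__":
    metric = ClassificationEvaluator()
    results = metric.compute(
        predictions=["cat", "dog", "cat", "bird"],
        references=["cat", "dog", "dog", "bird"],
    )
    # Expect accuracy == 0.75 and a 3x3 confusion matrix over the sorted
    # label set ["bird", "cat", "dog"].
    print(results)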