|
import logging |
|
from typing import Any, Dict, List, Tuple, Union |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from numpy.typing import NDArray |
|
from scipy.special import softmax |
|
from sklearn.metrics import log_loss, roc_auc_score |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def accuracy_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Score each sample as 1.0 when prediction and target agree, else 0.0.

    Both ``results["predicted_text"]`` and ``results["target_text"]`` are
    parsed with ``int`` before being compared element by element.

    Args:
        cfg: Not used by this function.
        results: Mapping holding the ``"predicted_text"`` and
            ``"target_text"`` sequences; every entry must be accepted by
            ``int``.
        val_df: Not used by this function.
        raw_results: Not used by this function.

    Returns:
        Float array holding the per-sample comparison result.
    """
    predictions = np.array(list(map(int, results["predicted_text"])))
    targets = np.array(list(map(int, results["target_text"])))
    is_correct = predictions == targets
    return is_correct.astype("float")
|
|
|
|
|
def auc_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute ROC AUC of the logits against the integer targets.

    The targets are parsed with ``int``. When ``cfg.dataset.num_classes`` is
    greater than one they are one-hot encoded before being handed to
    ``roc_auc_score`` (called with ``multi_class="ovr"``).

    Args:
        cfg: Configuration object; only ``cfg.dataset.num_classes`` is read.
        results: Mapping holding ``"logits"`` and ``"target_text"``.
        val_df: Not used by this function.
        raw_results: Not used by this function.

    Returns:
        Whatever ``roc_auc_score`` returns for the given scores and labels.
    """
    scores = np.array(results["logits"])
    labels = np.array(list(map(int, results["target_text"])))

    if cfg.dataset.num_classes > 1:
        # Index rows of the identity matrix to obtain one-hot label vectors.
        one_hot_table = np.eye(cfg.dataset.num_classes)
        labels = one_hot_table[labels]

    return roc_auc_score(labels, scores, multi_class="ovr")
|
|
|
|
|
def logloss_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the log loss of the predictions with ``sklearn``'s ``log_loss``.

    The targets are parsed with ``int``. When ``cfg.dataset.num_classes`` is
    greater than one, the targets are one-hot encoded and the logits are
    converted to probabilities with a row-wise softmax. Otherwise the values
    in ``results["logits"]`` are scored as given.

    In both cases the scored values are clipped to ``[1e-7, 1 - 1e-7]``
    before the loss is computed so that ``log(0)`` can never occur.

    Args:
        cfg: Configuration object; only ``cfg.dataset.num_classes`` is read.
        results: Mapping holding ``"logits"`` and ``"target_text"``.
        val_df: Not used by this function.
        raw_results: Not used by this function.

    Returns:
        Whatever ``log_loss`` returns for the given labels and probabilities.
    """
    # The ``eps`` keyword of ``log_loss`` was deprecated in scikit-learn 1.3
    # and removed in 1.5, so the clipping it used to perform is done here
    # explicitly instead of being requested from ``log_loss``.
    eps = 1e-7

    logits = np.array(results["logits"])
    target_text = np.array([int(text) for text in results["target_text"]])

    if cfg.dataset.num_classes > 1:
        # Index rows of the identity matrix to obtain one-hot label vectors.
        target_text = np.eye(cfg.dataset.num_classes)[target_text]
        probabilities = np.clip(softmax(logits, axis=1), eps, 1 - eps)
        # Clipping can push row sums slightly away from 1; renormalise so the
        # rows are proper distributions again (the former ``eps`` handling
        # inside ``log_loss`` clipped and then renormalised the same way).
        probabilities = probabilities / probabilities.sum(axis=1, keepdims=True)
    else:
        # NOTE(review): this branch scores results["logits"] without any
        # sigmoid/softmax - presumably they are already probabilities in the
        # single-output setup; confirm against the producer of `results`.
        probabilities = np.clip(logits, eps, 1 - eps)

    return log_loss(target_text, probabilities)
|
|
|
|
|
class Metrics:
    """
    Registry of the available metrics.

    Every entry maps a metric name to a tuple of:
        - the function that computes the metric value
        - "max" or "min": whether the metric should be maximized or minimized
        - the reduce function

    The direction is needed for early stopping (saving best checkpoint).
    The reduce function generates a single metric value and is usually
    "mean" or "none".
    """

    _metrics = {
        "AUC": (auc_score, "max", "mean"),
        "Accuracy": (accuracy_score, "max", "mean"),
        "LogLoss": (logloss_score, "min", "mean"),
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered metric names in sorted order."""
        return sorted(cls._metrics)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up the registry entry for a metric.

        Args:
            name: metrics name
        Returns:
            The tuple registered under ``name``; the "LogLoss" entry when
            ``name`` is not registered.
        """
        registry = cls._metrics
        # Resolve the fallback up front, exactly as a ``dict.get`` default
        # argument would be evaluated before the lookup.
        fallback = registry["LogLoss"]
        return registry[name] if name in registry else fallback
|
|