H2OTest / llm_studio /src /metrics /text_causal_classification_modeling_metrics.py
elineve's picture
Upload 301 files
07423df
raw
history blame
2.39 kB
import logging
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from scipy.special import softmax
from sklearn.metrics import log_loss, roc_auc_score
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def accuracy_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Return a per-sample hit indicator: 1.0 where prediction equals target.

    Args:
        cfg: Not used by this metric.
        results: Mapping holding "predicted_text" and "target_text"; every
            element of both sequences is converted with ``int()``.
        val_df: Not used by this metric.
        raw_results: Not used by this metric.

    Returns:
        Float array with one entry per sample (1.0 for a match, 0.0 otherwise).
    """
    predictions = np.array(list(map(int, results["predicted_text"])))
    targets = np.array(list(map(int, results["target_text"])))
    matches = predictions == targets
    return matches.astype("float")
def auc_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute ROC AUC of the stored logits against the integer targets.

    Args:
        cfg: Configuration object; only ``cfg.dataset.num_classes`` is read.
        results: Mapping holding "logits" (used as scores) and "target_text"
            (each element converted with ``int()``).
        val_df: Not used by this metric.
        raw_results: Not used by this metric.

    Returns:
        The value produced by ``roc_auc_score`` with ``multi_class="ovr"``.
    """
    scores = np.array(results["logits"])
    labels = np.array(list(map(int, results["target_text"])))

    class_count = cfg.dataset.num_classes
    if class_count > 1:
        # Row lookup in an identity matrix turns class ids into one-hot rows.
        labels = np.eye(class_count)[labels]

    return roc_auc_score(labels, scores, multi_class="ovr")
def logloss_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the log loss of the stored logits against the integer targets.

    When ``cfg.dataset.num_classes > 1`` the targets are one-hot encoded and
    the logits are passed through a row-wise softmax first.

    Args:
        cfg: Configuration object; only ``cfg.dataset.num_classes`` is read.
        results: Mapping holding "logits" and "target_text" (each element of
            the latter is converted with ``int()``).
        val_df: Not used by this metric.
        raw_results: Not used by this metric.

    Returns:
        The value produced by ``sklearn.metrics.log_loss``.
    """
    # Probabilities are clipped to [eps, 1 - eps] here rather than through
    # ``log_loss(..., eps=...)``: that keyword was deprecated in scikit-learn
    # 1.3 and removed in 1.5, where passing it raises a TypeError.
    eps = 1e-7

    probabilities = np.array(results["logits"])
    targets = np.array([int(text) for text in results["target_text"]])

    num_classes = cfg.dataset.num_classes
    is_multiclass = num_classes > 1
    if is_multiclass:
        # Row lookup in an identity matrix turns class ids into one-hot rows.
        targets = np.eye(num_classes)[targets]
        probabilities = softmax(probabilities, axis=1)
    # NOTE(review): on the single-output path the values under "logits" are
    # used as probabilities without a sigmoid -- confirm the caller already
    # supplies probabilities in that case.

    probabilities = np.clip(probabilities, eps, 1 - eps)
    if is_multiclass:
        # Renormalise after clipping so every row sums to one again; this is
        # what older scikit-learn releases did internally when given ``eps``.
        probabilities = probabilities / probabilities.sum(axis=1, keepdims=True)

    return log_loss(targets, probabilities)
class Metrics:
    """
    Registry of the metrics available for this problem type.

    Every entry maps a display name to a 3-tuple:
    - the metric function
    - "max" or "min": whether a larger or a smaller value is better
    - the reduce mode used to collapse results into a single value

    The direction is needed for early stopping (saving best checkpoint);
    the reduce mode is usually "mean" or "none".
    """

    _metrics: Dict[str, Tuple[Any, str, str]] = {
        "AUC": (auc_score, "max", "mean"),
        "Accuracy": (accuracy_score, "max", "mean"),
        "LogLoss": (logloss_score, "min", "mean"),
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return all registered metric names in sorted order."""
        return sorted(cls._metrics)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up a registry entry by name.

        Args:
            name: metrics name

        Returns:
            The (function, direction, reduce mode) tuple registered under
            ``name``; the "LogLoss" entry when ``name`` is not registered.
        """
        if name in cls._metrics:
            return cls._metrics[name]
        return cls._metrics["LogLoss"]