import math
from typing import Dict

import pandas as pd
from sklearn.metrics import mean_squared_error

from src.utils import validate_y
class MSEMetric:
    """Root-mean-squared-error metric computed per column and averaged.

    Despite the class name, every value produced here is an RMSE (the square
    root of the mean squared error), not a raw MSE.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def evaluate_class_rmse(y_pred: pd.DataFrame, y_true: pd.DataFrame) -> Dict[str, float]:
        """Return the RMSE of every non-``text_id`` column of *y_pred*.

        Both frames are first passed to ``validate_y``. The set of scored
        columns is taken from *y_pred* (all of its columns except
        ``text_id``); *y_true* must contain each of them.

        Args:
            y_pred: Predictions; must contain a ``text_id`` column.
            y_true: Ground-truth values for the same columns.

        Returns:
            Mapping of column name to that column's RMSE.

        Raises:
            KeyError: If *y_pred* has no ``text_id`` column, or *y_true* is
                missing a column being scored (unless ``validate_y`` rejects
                the frame earlier with its own error).
        """
        validate_y(y_pred)
        validate_y(y_true)

        target_columns = y_pred.drop(columns=["text_id"]).columns

        # NOTE(review): values are compared row-by-row in the order given; the
        # frames are not aligned on ``text_id`` here -- confirm callers pass
        # them in matching row order.
        #
        # RMSE is computed as sqrt(MSE) instead of passing ``squared=False``,
        # because that keyword was deprecated in scikit-learn 1.4 and removed
        # in 1.6. Arguments follow scikit-learn's (y_true, y_pred) order; the
        # result is identical either way since squared error is symmetric.
        return {
            column: math.sqrt(mean_squared_error(y_true[column], y_pred[column]))
            for column in target_columns
        }

    def evaluate(self, y_pred: pd.DataFrame, y_true: pd.DataFrame) -> float:
        """Return the unweighted mean of the per-column RMSE values.

        Args:
            y_pred: Predictions; must contain a ``text_id`` column.
            y_true: Ground-truth values for the same columns.

        Returns:
            Arithmetic mean of the values from ``evaluate_class_rmse``.

        Raises:
            ZeroDivisionError: If *y_pred* has no columns besides ``text_id``,
                so there is nothing to average.
        """
        per_column = self.evaluate_class_rmse(y_pred, y_true)
        return sum(per_column.values()) / len(per_column)