Spaces:
Running
Running
""" | |
Baseline: based on most-common answer | |
""" | |
import pandas as pd | |
import numpy as np | |
from tqdm import tqdm | |
from .metrics import mapk, rank_biased_overlap | |
from .plots import plot_ranks | |
import logging | |
from typing import List, Callable, Optional | |
from rouge_score import rouge_scorer as rs | |
from collections import Counter | |
import random | |
# Module-level logger used by MCARank to report estimated rankings and metrics.
logger = logging.getLogger(name=__name__)
# Numerical tolerance constant (not referenced by MCARank).
tol = 1e-3
class MCARank:
    """
    Baseline ranking method based on the most common answer (MCA).

    Each model's output is compared with the consensus answer for the same
    benchmark instance. Models that agree with the consensus more often are
    ranked higher.
    """

    def __init__(
        self,
        MODELS: List,
        evaluator: Callable,
        true_ranking: Optional[List] = None,
        show_progress: Optional[bool] = False,
    ):
        """
        MODELS: names of the models to rank; must match the columns of the
            DataFrame later passed to ``fit``.
        evaluator: stored as ``self.evaluate`` (not used by this class itself).
        true_ranking: optional ground-truth ranking (best to worst), used by
            ``measure`` and ``plot``.
        show_progress: stored as ``self.show_progress`` (not used by this class itself).
        """
        self.MODELS = MODELS
        self.N = len(MODELS)
        self.evaluate = evaluator
        self.true_ranking = true_ranking
        self.show_progress = show_progress

    def fit(self, df: pd.DataFrame, measure: Optional[str] = 'equality', p: float = 0):
        """
        Estimate a ranking of the models from their benchmark outputs.

        df: DataFrame where each row is a benchmark instance and there is one
            column with the output of each model. It is not modified.
        measure: how agreement with the most common answer is scored:
            'equality'        -- count exact matches with the row's most common answer;
            'noisy_equality'  -- as 'equality', but each match/mismatch outcome is
                                 flipped with probability ``p``;
            'rouge'           -- count the rows on which a model has the best bigram
                                 ROUGE F-measure against the row's consensus bigrams.
        p: noise level, only used for 'noisy_equality'.

        Returns the estimated ranking (best to worst), also stored in ``self.ranking``.

        Raises AssertionError if the DataFrame columns do not match ``MODELS`` and
        ValueError if ``measure`` is not recognised.
        """
        # NOTE: an assert is skipped under `python -O`; kept so callers still see
        # the same AssertionError type.
        assert set(self.MODELS) == set(df.columns), "Benchmark data models inconsistent with models to be ranked."
        if measure == 'equality':
            self.ranking = self._rank_by_equality(df)
        elif measure == 'noisy_equality':
            self.ranking = self._rank_by_noisy_equality(df, p)
        elif measure == 'rouge':
            self.ranking = self._rank_by_rouge(df)
        else:
            raise ValueError(f"Measure {measure} not understood.")

        logger.info(f"Estimated ranks (best to worst): {self.ranking}")
        logger.info(f"True ranking: {self.true_ranking}")
        # ``measure`` raises without a ground truth; previously this made every
        # ``fit`` fail for the default ``true_ranking=None``.
        if self.true_ranking is not None:
            logger.info(f"RBO measure: {self.measure()}")
        return self.ranking  # Best to worst

    @staticmethod
    def _most_common_answer(df: pd.DataFrame) -> pd.Series:
        """Return the per-row most common output (first column of ``DataFrame.mode``)."""
        return df.mode(axis=1).iloc[:, 0]

    @classmethod
    def _rank_by_equality(cls, df: pd.DataFrame) -> List:
        """Rank models by how many rows their output equals that row's most common answer."""
        mca = cls._most_common_answer(df)
        wins = df.eq(mca, axis=0).astype(int)
        return wins.sum().sort_values(ascending=False).index.to_list()

    @classmethod
    def _rank_by_noisy_equality(cls, df: pd.DataFrame, p: float) -> List:
        """Like ``_rank_by_equality``, but each match outcome is flipped with probability ``p``."""
        mca = cls._most_common_answer(df)

        def flip(matched):
            # random.random() lies in [0.0, 1.0): a strict '<' makes p=0 noise-free
            # and p=1 flip every outcome.
            return not matched if random.random() < p else matched

        # Outcomes are perturbed column by column, row by row.
        wins = df.eq(mca, axis=0).apply(lambda column: column.map(flip))
        return wins.sum().sort_values(ascending=False).index.to_list()

    @staticmethod
    def _rank_by_rouge(df: pd.DataFrame, size: int = 256) -> List:
        """
        Rank models by the number of rows on which they have the highest bigram
        ROUGE F-measure against the row's consensus: the ``size`` most frequent
        bigrams pooled over all models' outputs for that row.
        Models that never win are appended in their original column order.
        """
        models = df.columns.to_list()

        def row_scores(row):
            # NOTE(review): rs._create_ngrams is applied to each output as-is; if
            # outputs are raw strings rather than token lists the bigrams may be
            # character-level -- confirm the expected input format.
            bigrams = {m: rs._create_ngrams(row[m], n=2) for m in models}
            pooled = sum(bigrams.values(), Counter())
            consensus = Counter(dict(pooled.most_common(size)))
            return pd.Series({m: rs._score_ngrams(consensus, bigrams[m]).fmeasure for m in models})

        # Winning model per row, then how many rows each model won.
        win_counts = df.apply(row_scores, axis=1).idxmax(axis=1).value_counts()
        winners = win_counts.index.tolist()
        won = set(winners)
        # Deterministic order for models without wins (a set difference made it vary).
        no_wins = [m for m in models if m not in won]
        return winners + no_wins

    def measure(self, metric='rbo', k=5, p=0.95) -> float:
        """
        Compare the estimated ranking with ``self.true_ranking``.

        metric: 'rbo' (``rank_biased_overlap`` with parameter ``p``) or 'mapk'
            (``mapk`` over the top-``k`` entries of both rankings).

        Raises ValueError for an unsupported metric, if ``fit`` has not been run,
        or if no true ranking is available.
        """
        if metric not in ['rbo', 'mapk']:
            raise ValueError(f"Metric {metric} not supported (use 'rbo'/'mapk').")
        if not hasattr(self, 'ranking'):
            raise ValueError("Ranking not estimated. Run 'fit' first.")
        if self.true_ranking is None:
            raise ValueError("True ranking not available for metric calculation.")

        if metric == 'mapk':
            if k > len(self.true_ranking):
                logger.warning(f"MAPk metric is for k={len(self.true_ranking)}, and not k={k}.")
            actual = [self.true_ranking[:k]]
            pred = [self.ranking[:k]]
            return mapk(actual, pred, k=k)
        return rank_biased_overlap(self.true_ranking, self.ranking, p=p)

    def plot(self, caselabel="output"):
        """Plot true vs. estimated ranks via ``plot_ranks``; does nothing unless both exist."""
        if hasattr(self, 'ranking') and self.true_ranking is not None:
            plot_ranks(self.true_ranking, self.ranking, "actual", "estimated", caselabel)