|
import pandas as pd |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from typing import Tuple |
|
|
|
# Multiplier applied to a tool's share of all requests before that share is
# added to the tool's accuracy (see get_weighted_accuracy).
VOLUME_FACTOR_REGULARIZATION: float = 0.5

# Source interval used to min-max scale the raw (accuracy + volume bonus)
# score. NOTE(review): presumably tool_accuracy is a 0-100 percentage and the
# volume bonus adds at most 0.5, with matching 0.5 padding on the low side —
# confirm against the data producer.
UNSCALED_WEIGHTED_ACCURACY_INTERVAL: Tuple[float, float] = (-0.5, 100.5)

# Target interval the weighted accuracy is mapped onto.
SCALED_WEIGHTED_ACCURACY_INTERVAL: Tuple[float, float] = (0, 1)
|
|
|
|
|
def scale_value(
    value: float,
    min_max_bounds: Tuple[float, float],
    scale_bounds: Tuple[float, float] = (0, 1),
) -> float:
    """Linearly map ``value`` from one interval onto another (min-max scaling).

    Args:
        value: Number to rescale. Only ``-``, ``/``, ``*`` and ``+`` are
            applied to it, so array-like operands are rescaled elementwise.
        min_max_bounds: ``(low, high)`` of the source interval.
        scale_bounds: ``(low, high)`` of the target interval, ``(0, 1)`` by
            default.

    Returns:
        ``value`` placed at the same relative position inside
        ``scale_bounds`` as it occupies inside ``min_max_bounds``.

    Raises:
        ZeroDivisionError: if the two source bounds are equal plain Python
            numbers (the source interval has zero width).
    """
    src_low, src_high = min_max_bounds
    src_span = src_high - src_low
    # Relative position inside the source interval: 0 at src_low, 1 at src_high.
    fraction = (value - src_low) / src_span

    dst_low, dst_high = scale_bounds
    return fraction * (dst_high - dst_low) + dst_low
|
|
|
|
|
def get_weighted_accuracy(row, global_requests: int):
    """Compute the weighted accuracy of a tool.

    The raw score is the tool's accuracy plus a volume bonus: the tool's share
    of all requests multiplied by ``VOLUME_FACTOR_REGULARIZATION``. The raw
    score is then min-max scaled from ``UNSCALED_WEIGHTED_ACCURACY_INTERVAL``
    onto ``SCALED_WEIGHTED_ACCURACY_INTERVAL``.

    Args:
        row: Mapping-like object exposing ``"tool_accuracy"`` and
            ``"total_requests"`` (e.g. a DataFrame row). Only key lookups and
            arithmetic are applied, so passing a whole DataFrame also works and
            yields a Series with one score per row.
        global_requests: Total number of requests across all tools.

    Returns:
        The scaled weighted accuracy (a scalar for a row, a Series for a frame).
    """
    if global_requests:
        volume_share = row["total_requests"] / global_requests
    else:
        # No requests at all, so there is no volume to reward. Dividing here
        # would raise ZeroDivisionError for Python ints or produce NaN/inf for
        # NumPy/pandas values; give every tool a zero bonus instead.
        volume_share = 0.0

    raw_score = row["tool_accuracy"] + volume_share * VOLUME_FACTOR_REGULARIZATION
    return scale_value(
        raw_score,
        UNSCALED_WEIGHTED_ACCURACY_INTERVAL,
        SCALED_WEIGHTED_ACCURACY_INTERVAL,
    )
|
|
|
|
|
def compute_weighted_accuracy(tools_accuracy: pd.DataFrame):
    """Add a ``weighted_accuracy`` column to ``tools_accuracy``.

    Each tool's score is computed by ``get_weighted_accuracy`` using the sum of
    the ``total_requests`` column as the global request count.

    Args:
        tools_accuracy: Frame with at least ``tool_accuracy`` and
            ``total_requests`` columns.

    Returns:
        The same DataFrame object. It is modified in place (the new column is
        assigned directly on it), which existing callers may rely on.
    """
    global_requests = tools_accuracy.total_requests.sum()
    # get_weighted_accuracy only performs column lookups and arithmetic, so
    # calling it once on the whole frame scores every row in a single
    # vectorized pass. This replaces a per-row Python ``apply(axis=1)``, which
    # is slower and relies on apply's special handling of empty frames.
    tools_accuracy["weighted_accuracy"] = get_weighted_accuracy(
        tools_accuracy, global_requests
    )
    return tools_accuracy
|
|
|
|
|
def _plot_sorted_tools_barplot(tools_info: pd.DataFrame, metric: str):
    """Return a gr.Plot with one horizontal bar per tool for ``metric``, highest first."""
    ranked = tools_info.sort_values(by=metric, ascending=False)

    # Draw on an explicit Axes instead of pyplot's implicit "current" axes.
    fig, ax = plt.subplots(figsize=(25, 10))
    # hue mirrors the y variable so each tool gets its own colour from the
    # palette; dodge=False keeps each bar full width instead of reserving room
    # for every hue level.
    sns.barplot(
        ranked,
        x=metric,
        y="tool",
        hue="tool",
        dodge=False,
        palette="viridis",
        ax=ax,
    )
    # Deregister the figure from pyplot so repeated calls do not keep piling
    # up open figures; the Figure object itself can still be rendered/saved.
    plt.close(fig)
    return gr.Plot(value=fig)


def plot_tools_accuracy_graph(tools_accuracy_info: pd.DataFrame):
    """Plot each tool's ``tool_accuracy`` as a horizontal bar chart, best first.

    Args:
        tools_accuracy_info: Frame with ``tool`` and ``tool_accuracy`` columns.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure.
    """
    return _plot_sorted_tools_barplot(tools_accuracy_info, "tool_accuracy")


def plot_tools_weighted_accuracy_graph(tools_accuracy_info: pd.DataFrame):
    """Plot each tool's ``weighted_accuracy`` as a horizontal bar chart, best first.

    Args:
        tools_accuracy_info: Frame with ``tool`` and ``weighted_accuracy``
            columns (see ``compute_weighted_accuracy``).

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure, coloured with the same
        explicit viridis palette as the plain accuracy chart.
    """
    # NOTE(review): set_theme changes seaborn/matplotlib defaults globally for
    # every later plot in the process. Kept for backward compatibility because
    # other plots may depend on it; confirm before removing.
    sns.set_theme(palette="viridis")
    return _plot_sorted_tools_barplot(tools_accuracy_info, "weighted_accuracy")
|
|