File size: 5,539 Bytes
c12bd84 dfa14a8 c12bd84 e3642ff c671de9 c1a84da e3642ff c671de9 c12bd84 e3642ff c12bd84 8488477 c671de9 5129f48 c671de9 5129f48 c671de9 5129f48 c12bd84 c671de9 8488477 c12bd84 8488477 c12bd84 e854cb9 43b4e29 ca8d4b9 43b4e29 8488477 e3642ff 8488477 e3642ff 8488477 e3642ff a34a60b c671de9 e3642ff 43b4e29 8488477 a34a60b 8488477 337b761 8488477 337b761 8488477 337b761 8488477 337b761 c1a84da 337b761 c671de9 ac931c6 ca8d4b9 8488477 ac931c6 8488477 ac931c6 c671de9 8488477 c671de9 ca8e784 8488477 ca8e784 c671de9 8488477 c671de9 8488477 c671de9 8488477 c671de9 ca8d4b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import streamlit as st
import pandas as pd
import os
import fnmatch
import json
import plotly.express as px
class MultiURLData:
    """Load evaluation result files into a single model-by-task table.

    After construction, ``self.data`` is a DataFrame with one row per model
    (index = model name), one column per task, and a derived
    ``MMLU_average`` column placed third.
    """

    def __init__(self, directory='results', pattern='results*.json'):
        """Load and clean every matching result file under *directory*.

        Args:
            directory: Root folder searched recursively. The model name is
                taken from the second path component below this folder
                (``<directory>/<first>/<model>/...``).
            pattern: ``fnmatch`` pattern that result-file basenames must match.
        """
        self.data = self.process_data(directory, pattern)

    @staticmethod
    def _find_files(directory, pattern):
        """Yield paths of files below *directory* whose basename matches *pattern*."""
        for root, _dirs, files in os.walk(directory):
            for basename in files:
                if fnmatch.fnmatch(basename, pattern):
                    yield os.path.join(root, basename)

    @staticmethod
    def _model_name(filename, directory):
        """Return the second path component of *filename* relative to *directory*."""
        # Split on os.sep (not a hard-coded '/') so Windows paths work too.
        return os.path.relpath(filename, directory).split(os.sep)[1]

    @staticmethod
    def _clean_task_names(index):
        """Return *index* with raw task names rewritten into display names."""
        # 'hendrycksTest-<subject>' tasks get a more descriptive 'MMLU_' prefix.
        index = index.str.replace('hendrycksTest-', 'MMLU_', regex=True)
        # Drop the 'harness|' prefix.
        index = index.str.replace(r'harness\|', '', regex=True)
        # Drop every literal '|5'.
        # NOTE(review): this is unanchored, so e.g. '|50' would become '0' — confirm.
        index = index.str.replace(r'\|5', '', regex=True)
        return index

    def process_data(self, directory='results', pattern='results*.json'):
        """Build the model-by-task table from the result files.

        Each file must be JSON with a ``'results'`` mapping of
        task name -> metrics dict; only the ``'acc'`` metric is kept.

        Args:
            directory: Root folder searched recursively for result files.
            pattern: ``fnmatch`` pattern for result-file basenames.

        Returns:
            DataFrame indexed by model name with one column per task plus an
            ``MMLU_average`` column (row mean over columns whose name contains
            'MMLU'), placed third. When no file matches, an empty DataFrame
            containing only the ``MMLU_average`` column is returned.
        """
        dataframes = []
        for filename in self._find_files(directory, pattern):
            model_name = self._model_name(filename, directory)
            with open(filename, encoding='utf-8') as f:
                results = json.load(f)['results']
            # Transposed so rows are tasks and columns are metric names.
            df = pd.DataFrame(results).T
            df = df.rename(columns={'acc': model_name})
            df.index = self._clean_task_names(df.index)
            dataframes.append(df[[model_name]])

        if not dataframes:
            # pd.concat([]) would raise "No objects to concatenate".
            return pd.DataFrame(columns=['MMLU_average'])

        # After concat the columns are models; transpose so each row is a model.
        # NOTE(review): several matching files for the same model name (or the
        # same name under two parent folders) yield duplicate index rows.
        data = pd.concat(dataframes, axis=1).transpose()
        # Computed before the column exists, so it never averages itself.
        data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
        # Put MMLU_average third. Building the order from the other columns
        # avoids duplicating it when there are fewer than three columns.
        others = [column for column in data.columns if column != 'MMLU_average']
        return data[others[:2] + ['MMLU_average'] + others[2:]]

    def get_data(self, selected_models):
        """Return the rows of ``self.data`` whose model name is in *selected_models*."""
        return self.data[self.data.index.isin(selected_models)]
# Load every result file into the model-by-task table.
# NOTE(review): Streamlit re-runs the whole script on each interaction, so
# this reloads all files every time; consider caching if the installed
# Streamlit version offers it.
data_provider = MultiURLData()

st.title('Model Evaluation Results including MMLU by task')

filters = st.checkbox('Select Models and Evaluations')

# Defaults: show every column and every model.
all_columns = data_provider.data.columns.tolist()
all_models = data_provider.data.index.tolist()
selected_columns = all_columns
selected_models = all_models
if filters:
    # Let the user narrow down the visible columns and models.
    selected_columns = st.multiselect(
        'Select Columns',
        all_columns,
        default=all_columns,
    )
    selected_models = st.multiselect(
        'Select Models',
        all_models,
        default=all_models,
    )

st.header('Sortable table')
# filtered_data deliberately keeps every column: later code in this file reads
# task columns from it regardless of the column selection.
filtered_data = data_provider.get_data(selected_models)
filtered_data = filtered_data.sort_values(by=['MMLU_average'], ascending=False)
displayed_data = filtered_data[selected_columns]
st.dataframe(displayed_data)

# CSV download of exactly the table shown above (selected columns only).
csv = displayed_data.to_csv(index=True)
st.download_button(
    label="Download data as CSV",
    data=csv,
    file_name="model_evaluation_results.csv",
    mime="text/csv",
)
def create_plot(df, arc_column, moral_column, models=None):
    """Build a purple scatter plot of two columns of *df*, one point per model.

    Despite their names, *arc_column* and *moral_column* are simply the x- and
    y-axis column names.

    Args:
        df: Results table indexed by model name.
        arc_column: Column plotted on the x axis.
        moral_column: Column plotted on the y axis.
        models: Optional collection of model names; when given, only rows
            whose index is in it are plotted.

    Returns:
        A plotly Figure with the model name in the hover text. When at least
        two points have both coordinates, it also has an OLS trendline;
        plotly's OLS trendline needs statsmodels installed.
    """
    if models is not None:
        df = df[df.index.isin(models)]
    plot_data = pd.DataFrame({
        'Model': df.index,
        arc_column: df[arc_column],
        moral_column: df[moral_column],
    })
    # A straight-line fit needs at least two points with both values present.
    has_enough_points = len(plot_data[[arc_column, moral_column]].dropna()) >= 2
    fig = px.scatter(
        plot_data,
        x=arc_column,
        y=moral_column,
        hover_data=['Model'],
        # Mapping color= to a constant 'purple' column only creates a category
        # label drawn in plotly's default colour; set the colour explicitly.
        color_discrete_sequence=['purple'],
        trendline='ols' if has_enough_points else None,
    )
    fig.update_layout(
        showlegend=False,
        xaxis_title=arc_column,
        yaxis_title=moral_column,
    )
    return fig
def _show_scatter(df, x_column, y_column):
    """Render create_plot(df, x_column, y_column), or a notice if a column is missing."""
    missing = [column for column in (x_column, y_column) if column not in df.columns]
    if missing:
        st.write(f"Cannot plot {x_column} vs {y_column}: missing column(s) {', '.join(missing)}.")
        return
    st.plotly_chart(create_plot(df, x_column, y_column))


def _show_histogram(df, column):
    """Render a rug-marginal histogram of *column*, or a notice if it is missing."""
    if column not in df.columns:
        st.write(f"Cannot plot histogram: missing column {column}.")
        return
    st.plotly_chart(px.histogram(df, x=column, marginal="rug", hover_data=df.columns))


st.header('Custom scatter plots')
plot_columns = filtered_data.columns.tolist()
selected_x_column = st.selectbox('Select x-axis', plot_columns, index=0)
# index=1 would be out of range when only one column exists.
selected_y_column = st.selectbox(
    'Select y-axis', plot_columns, index=1 if len(plot_columns) > 1 else 0
)
if selected_x_column != selected_y_column:  # Avoid creating a plot with the same column on both axes
    _show_scatter(filtered_data, selected_x_column, selected_y_column)
else:
    st.write("Please select different columns for the x and y axes.")

# Header placed directly above the plots it introduces.
st.header('Overall benchmark comparison')
_show_scatter(filtered_data, 'arc:challenge|25', 'hellaswag|10')
_show_scatter(filtered_data, 'arc:challenge|25', 'MMLU_average')
_show_scatter(filtered_data, 'hellaswag|10', 'MMLU_average')

st.header('Top 50 models on MMLU_average')
top_50 = filtered_data.nlargest(50, 'MMLU_average')
_show_scatter(top_50, 'arc:challenge|25', 'MMLU_average')

st.header('Moral Scenarios')
_show_scatter(filtered_data, 'arc:challenge|25', 'MMLU_moral_scenarios')
_show_scatter(filtered_data, 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
_show_scatter(filtered_data, 'MMLU_average', 'MMLU_moral_scenarios')
_show_histogram(filtered_data, "MMLU_moral_scenarios")
_show_histogram(filtered_data, "MMLU_moral_disputes")
|