Corey Morris
commited on
Commit
•
9695a47
1
Parent(s):
b9b6115
Added radar chart. Compares a model to the 5 models that have the closest performance on MMLU_average
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import plotly.express as px
|
|
4 |
from result_data_processor import ResultDataProcessor
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
|
|
7 |
|
8 |
st.set_page_config(layout="wide")
|
9 |
|
@@ -47,6 +48,46 @@ def plot_top_n(df, target_column, n=10):
|
|
47 |
# Show the plot
|
48 |
st.pyplot(fig)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
data_provider = ResultDataProcessor()
|
51 |
|
52 |
# st.title('Model Evaluation Results including MMLU by task')
|
@@ -131,6 +172,7 @@ st.download_button(
|
|
131 |
mime="text/csv",
|
132 |
)
|
133 |
|
|
|
134 |
def create_plot(df, x_values, y_values, models=None, title=None):
|
135 |
if models is not None:
|
136 |
df = df[df.index.isin(models)]
|
@@ -215,6 +257,21 @@ if selected_x_column != selected_y_column: # Avoid creating a plot with the s
|
|
215 |
else:
|
216 |
st.write("Please select different columns for the x and y axes.")
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
# end of custom scatter plots
|
219 |
st.markdown("## Notable findings and plots")
|
220 |
|
|
|
4 |
from result_data_processor import ResultDataProcessor
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
+
import plotly.graph_objects as go
|
8 |
|
9 |
st.set_page_config(layout="wide")
|
10 |
|
|
|
48 |
# Show the plot
|
49 |
st.pyplot(fig)
|
50 |
|
51 |
+
# Function to create an unfilled radar chart
|
52 |
+
def create_radar_chart_unfilled(df, model_names, metrics):
|
53 |
+
fig = go.Figure()
|
54 |
+
min_value = df.loc[model_names, metrics].min().min()
|
55 |
+
max_value = df.loc[model_names, metrics].max().max()
|
56 |
+
for model_name in model_names:
|
57 |
+
values_model = df.loc[model_name, metrics]
|
58 |
+
fig.add_trace(go.Scatterpolar(
|
59 |
+
r=values_model,
|
60 |
+
theta=metrics,
|
61 |
+
name=model_name
|
62 |
+
))
|
63 |
+
|
64 |
+
fig.update_layout(
|
65 |
+
polar=dict(
|
66 |
+
radialaxis=dict(
|
67 |
+
visible=True,
|
68 |
+
range=[min_value, max_value]
|
69 |
+
)),
|
70 |
+
showlegend=True
|
71 |
+
)
|
72 |
+
return fig
|
73 |
+
|
74 |
+
|
75 |
+
# Function to create a line chart
|
76 |
+
def create_line_chart(df, model_names, metrics):
|
77 |
+
line_data = []
|
78 |
+
for model_name in model_names:
|
79 |
+
values_model = df.loc[model_name, metrics]
|
80 |
+
for metric, value in zip(metrics, values_model):
|
81 |
+
line_data.append({'Model': model_name, 'Metric': metric, 'Value': value})
|
82 |
+
|
83 |
+
line_df = pd.DataFrame(line_data)
|
84 |
+
|
85 |
+
fig = px.line(line_df, x='Metric', y='Value', color='Model', title='Comparison of Models', line_dash_sequence=['solid'])
|
86 |
+
fig.update_layout(showlegend=True)
|
87 |
+
return fig
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
data_provider = ResultDataProcessor()
|
92 |
|
93 |
# st.title('Model Evaluation Results including MMLU by task')
|
|
|
172 |
mime="text/csv",
|
173 |
)
|
174 |
|
175 |
+
|
176 |
def create_plot(df, x_values, y_values, models=None, title=None):
|
177 |
if models is not None:
|
178 |
df = df[df.index.isin(models)]
|
|
|
257 |
else:
|
258 |
st.write("Please select different columns for the x and y axes.")
|
259 |
|
260 |
+
|
261 |
+
# Section to select a model and display radar and line charts
|
262 |
+
st.header("Compare Models")
|
263 |
+
selected_model_name = st.selectbox("Select a Model:", filtered_data.index.tolist())
|
264 |
+
metrics_to_compare = ['MMLU_abstract_algebra', 'MMLU_astronomy', 'MMLU_business_ethics', 'MMLU_average', 'MMLU_moral_scenarios']
|
265 |
+
closest_models = filtered_data['MMLU_average'].sub(filtered_data.loc[selected_model_name, 'MMLU_average']).abs().nsmallest(5).index.tolist()
|
266 |
+
|
267 |
+
fig_radar = create_radar_chart_unfilled(filtered_data, closest_models, metrics_to_compare)
|
268 |
+
fig_line = create_line_chart(filtered_data, closest_models, metrics_to_compare)
|
269 |
+
|
270 |
+
st.plotly_chart(fig_radar)
|
271 |
+
st.plotly_chart(fig_line)
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
# end of custom scatter plots
|
276 |
st.markdown("## Notable findings and plots")
|
277 |
|