Corey Morris commited on
Commit
9695a47
1 Parent(s): b9b6115

Added radar chart. Compares a model to the 5 models that have the closest performance on MMLU_average

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import plotly.express as px
4
  from result_data_processor import ResultDataProcessor
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
7
 
8
  st.set_page_config(layout="wide")
9
 
@@ -47,6 +48,46 @@ def plot_top_n(df, target_column, n=10):
47
  # Show the plot
48
  st.pyplot(fig)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  data_provider = ResultDataProcessor()
51
 
52
  # st.title('Model Evaluation Results including MMLU by task')
@@ -131,6 +172,7 @@ st.download_button(
131
  mime="text/csv",
132
  )
133
 
 
134
  def create_plot(df, x_values, y_values, models=None, title=None):
135
  if models is not None:
136
  df = df[df.index.isin(models)]
@@ -215,6 +257,21 @@ if selected_x_column != selected_y_column: # Avoid creating a plot with the s
215
  else:
216
  st.write("Please select different columns for the x and y axes.")
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  # end of custom scatter plots
219
  st.markdown("## Notable findings and plots")
220
 
 
4
  from result_data_processor import ResultDataProcessor
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
+ import plotly.graph_objects as go
8
 
9
  st.set_page_config(layout="wide")
10
 
 
48
  # Show the plot
49
  st.pyplot(fig)
50
 
51
+ # Function to create an unfilled radar chart
52
+ def create_radar_chart_unfilled(df, model_names, metrics):
53
+ fig = go.Figure()
54
+ min_value = df.loc[model_names, metrics].min().min()
55
+ max_value = df.loc[model_names, metrics].max().max()
56
+ for model_name in model_names:
57
+ values_model = df.loc[model_name, metrics]
58
+ fig.add_trace(go.Scatterpolar(
59
+ r=values_model,
60
+ theta=metrics,
61
+ name=model_name
62
+ ))
63
+
64
+ fig.update_layout(
65
+ polar=dict(
66
+ radialaxis=dict(
67
+ visible=True,
68
+ range=[min_value, max_value]
69
+ )),
70
+ showlegend=True
71
+ )
72
+ return fig
73
+
74
+
75
+ # Function to create a line chart
76
+ def create_line_chart(df, model_names, metrics):
77
+ line_data = []
78
+ for model_name in model_names:
79
+ values_model = df.loc[model_name, metrics]
80
+ for metric, value in zip(metrics, values_model):
81
+ line_data.append({'Model': model_name, 'Metric': metric, 'Value': value})
82
+
83
+ line_df = pd.DataFrame(line_data)
84
+
85
+ fig = px.line(line_df, x='Metric', y='Value', color='Model', title='Comparison of Models', line_dash_sequence=['solid'])
86
+ fig.update_layout(showlegend=True)
87
+ return fig
88
+
89
+
90
+
91
  data_provider = ResultDataProcessor()
92
 
93
  # st.title('Model Evaluation Results including MMLU by task')
 
172
  mime="text/csv",
173
  )
174
 
175
+
176
  def create_plot(df, x_values, y_values, models=None, title=None):
177
  if models is not None:
178
  df = df[df.index.isin(models)]
 
257
  else:
258
  st.write("Please select different columns for the x and y axes.")
259
 
260
+
261
+ # Section to select a model and display radar and line charts
262
+ st.header("Compare Models")
263
+ selected_model_name = st.selectbox("Select a Model:", filtered_data.index.tolist())
264
+ metrics_to_compare = ['MMLU_abstract_algebra', 'MMLU_astronomy', 'MMLU_business_ethics', 'MMLU_average', 'MMLU_moral_scenarios']
265
+ closest_models = filtered_data['MMLU_average'].sub(filtered_data.loc[selected_model_name, 'MMLU_average']).abs().nsmallest(5).index.tolist()
266
+
267
+ fig_radar = create_radar_chart_unfilled(filtered_data, closest_models, metrics_to_compare)
268
+ fig_line = create_line_chart(filtered_data, closest_models, metrics_to_compare)
269
+
270
+ st.plotly_chart(fig_radar)
271
+ st.plotly_chart(fig_line)
272
+
273
+
274
+
275
  # end of custom scatter plots
276
  st.markdown("## Notable findings and plots")
277