Corey Morris commited on
Commit
337b761
1 Parent(s): ba99486

fixed reversed plot. extracted making chart into a method

Browse files
Files changed (1) hide show
  1. app.py +31 -23
app.py CHANGED
@@ -75,27 +75,35 @@ selected_models = st.multiselect(
75
  filtered_data = data_provider.get_data(selected_models)
76
  st.dataframe(filtered_data)
77
 
78
- # Create a plot with new data
79
- df = pd.DataFrame({
80
- 'Model': list(filtered_data['Model Name']),
81
- # use debug to troubheshoot error
82
- 'arc:challenge|25': list(filtered_data['arc:challenge|25']),
83
- 'moral_scenarios|5': list(filtered_data['moral_scenarios|5']),
84
- })
85
-
86
- # Calculate color column
87
- df['color'] = 'purple'
88
- df.loc[df['moral_scenarios|5'] < df['arc:challenge|25'], 'color'] = 'red'
89
- df.loc[df['moral_scenarios|5'] > df['arc:challenge|25'], 'color'] = 'blue'
90
-
91
- # Create the scatter plot
92
- fig = px.scatter(df, x='arc:challenge|25', y='moral_scenarios|5', color='color', hover_data=['Model'])
93
- fig.update_layout(showlegend=False, # hide legend
94
- xaxis = dict(autorange="reversed"), # reverse X-axis
95
- yaxis = dict(autorange="reversed")) # reverse Y-axis
96
-
97
- # Show the plot in Streamlit
 
 
 
 
 
 
 
 
 
 
 
98
  st.plotly_chart(fig)
99
-
100
-
101
-
 
75
  filtered_data = data_provider.get_data(selected_models)
76
  st.dataframe(filtered_data)
77
 
78
+ def create_plot(df, model_column, arc_column, moral_column, models=None):
79
+ # Filter the dataframe if specific models are provided
80
+ if models is not None:
81
+ df = df[df[model_column].isin(models)]
82
+
83
+ # Create a plot with new data
84
+ plot_data = pd.DataFrame({
85
+ 'Model': list(df[model_column]),
86
+ arc_column: list(df[arc_column]),
87
+ moral_column: list(df[moral_column]),
88
+ })
89
+
90
+ # Calculate color column
91
+ plot_data['color'] = 'purple'
92
+ plot_data.loc[plot_data[moral_column] < plot_data[arc_column], 'color'] = 'red'
93
+ plot_data.loc[plot_data[moral_column] > plot_data[arc_column], 'color'] = 'blue'
94
+
95
+ # Create the scatter plot
96
+ fig = px.scatter(plot_data, x=arc_column, y=moral_column, color='color', hover_data=['Model'])
97
+ fig.update_layout(showlegend=False, # hide legend
98
+ xaxis_title='ARC Accuracy',
99
+ yaxis_title='Moral Scenarios Accuracy',
100
+ xaxis = dict(),
101
+ yaxis = dict())
102
+
103
+ return fig
104
+
105
+ # models_to_plot = ['Model1', 'Model2', 'Model3']
106
+ # fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
107
+
108
+ fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5')
109
  st.plotly_chart(fig)