# href/src/plt.py — uploaded by natolambert in commit 0b8c16d ("upload plot"); 1.94 kB.
import matplotlib.pyplot as plt
import pandas as pd
from .utils import undo_hyperlink
def plot_avg_correlation(df1, df2):
"""
Plots the "average" column for each unique model that appears in both dataframes.
Parameters:
- df1: pandas DataFrame containing columns "model" and "average".
- df2: pandas DataFrame containing columns "model" and "average".
"""
# Identify the unique models that appear in both DataFrames
common_models = pd.Series(list(set(df1['model']) & set(df2['model'])))
# Set up the plot
plt.figure(figsize=(13, 6), constrained_layout=True)
# axes from 0 to 1 for x and y
plt.xlim(0.475, 0.8)
plt.ylim(0.475, 0.8)
# larger font (16)
plt.rcParams.update({'font.size': 12, 'axes.labelsize': 14,'axes.titlesize': 14})
# plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
# plt.tight_layout()
# plt.margins(0,0)
for model in common_models:
# Filter data for the current model
df1_model_data = df1[df1['model'] == model]['average'].values
df2_model_data = df2[df2['model'] == model]['average'].values
# Plotting
plt.scatter(df1_model_data, df2_model_data, label=model)
m_name = undo_hyperlink(model)
if m_name == "No text found":
m_name = "Random"
# Add text above each point like
# plt.text(x[i] + 0.1, y[i] + 0.1, label, ha='left', va='bottom')
plt.text(df1_model_data - .005, df2_model_data, m_name, horizontalalignment='right', verticalalignment='center')
# add correlation line to scatter plot
# first, compute correlation
corr = df1['average'].corr(df2['average'])
# add correlation line based on corr
plt.xlabel('HERM Eval. Set Avg.', fontsize=16)
plt.ylabel('Pref. Test Sets Avg.', fontsize=16)
# plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
return plt