from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats


# Metric selectors for go_box_plot(metric=...).
# go_box_plot only compares against ROC; any other value (including PR)
# selects the precision-recall column and output file.
ROC = 1
PR = 2


def _significance_symbol(pvalue):
    """Map a p-value to the star notation drawn above a bracket."""
    # A NaN p-value (test could not be computed) fails every comparison below
    # and would otherwise be reported as '***', i.e. highly significant.
    if np.isnan(pvalue):
        return 'n/a'
    if pvalue >= 0.05:
        return 'ns'
    if pvalue >= 0.01:
        return '*'
    if pvalue >= 0.001:
        return '**'
    return '***'


def add_p_value_annotation(fig, array_columns, subplot=None, _format=None):
    """Add significance brackets between pairs of boxes of a box-plot figure.

    For every pair in ``array_columns`` the ``y`` data of the two traces is
    compared with ``scipy.stats.ttest_ind(..., equal_var=False)`` and a
    bracket (three line shapes) plus a text label is added to the figure.
    Labels: 'ns' (p >= 0.05), '*' (p >= 0.01), '**' (p >= 0.001), '***'
    otherwise, and 'n/a' when the p-value is NaN.

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure; it is modified in place. Only ``to_dict()``,
        ``add_shape()`` and ``add_annotation()`` are used.
    array_columns: np.array or list of pairs
        which columns to compare,
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2.
        The two numbers of a pair are also used directly as the x
        coordinates of the bracket ends.
    subplot: None or int
        if truthy, the pair numbers index only the traces whose ``xaxis``
        belongs to that subplot, and the bracket is drawn on that subplot's
        axes. Subplot 1 uses the unsuffixed axes 'x'/'y'.
    _format: None or dict
        optional overrides for the keys 'interline' (vertical step between
        successive brackets), 'text_height' (multiplier applied to the
        bracket top to place the label) and 'color'. Missing keys fall back
        to interline=0.03, text_height=1.03, color='black'.

    Returns:
    -------
    fig: figure
        the same figure object, with the added notation
    """
    # Build a fresh dict per call so the defaults are never shared or mutated,
    # and so callers may override just one key.
    style = {'interline': 0.03, 'text_height': 1.03, 'color': 'black'}
    if _format:
        style.update(_format)

    # Bracket i spans these heights, in axis-domain units (above 1 = above the plot).
    y_range = [
        (1.03 + i * style['interline'], 1.04 + i * style['interline'])
        for i in range(len(array_columns))
    ]

    fig_dict = fig.to_dict()

    if subplot:
        subplot_str = '' if subplot == 1 else str(subplot)
        wanted_axis = 'x' + subplot_str
        # Positions, within the full trace list, of the traces on this subplot.
        # A trace without an explicit 'xaxis' entry is treated as being on 'x'.
        indices = [
            position
            for position, trace in enumerate(fig_dict['data'])
            if trace.get('xaxis', 'x') == wanted_axis
        ]
    else:
        subplot_str = ''
        indices = None

    xref = "x" + subplot_str
    yref = "y" + subplot_str + " domain"
    line_style = dict(color=style['color'], width=1.5)

    for index, column_pair in enumerate(array_columns):
        left, right = column_pair[0], column_pair[1]
        if subplot:
            data_pair = [indices[left], indices[right]]
        else:
            data_pair = column_pair

        pvalue = stats.ttest_ind(
            fig_dict['data'][data_pair[0]]['y'],
            fig_dict['data'][data_pair[1]]['y'],
            equal_var=False,
        )[1]
        symbol = _significance_symbol(pvalue)

        y_low, y_high = y_range[index]
        # Left tick, horizontal bar, right tick - in that order.
        segments = (
            (left, y_low, left, y_high),
            (left, y_high, right, y_high),
            (right, y_low, right, y_high),
        )
        for x0, y0, x1, y1 in segments:
            fig.add_shape(type="line",
                          xref=xref, yref=yref,
                          x0=x0, y0=y0,
                          x1=x1, y1=y1,
                          line=dict(line_style))

        fig.add_annotation(dict(font=dict(color=style['color'], size=14),
                                x=(left + right) / 2,
                                y=y_high * style['text_height'],
                                showarrow=False,
                                text=symbol,
                                textangle=0,
                                xref=xref,
                                yref=yref))
    return fig


def box_plot(df, output_path='../figures/box_plot_integration.png'):
    """Draw a single grouped ROC-AUC box plot, annotate it, save it and show it.

    Parameters:
    ----------
    df: pandas.DataFrame
        must provide the columns 'Task_name' (x axis), 'test_auroc' (y axis)
        and 'Model' (colour grouping)
    output_path: str
        where the PNG is written; the default is the path this function has
        always used. The parent directory is created if it does not exist.
    """
    fig = px.box(df, x='Task_name', y='test_auroc', color="Model")

    axis_line = 'rgba(0,0,0,0.25)'
    fig.update_xaxes(linecolor=axis_line, gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor=axis_line, gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(
        plot_bgcolor="white",
        title={'text': "<b>ROC-AUC score distribution</b>",
               'font': {'size': 40},
               'y': 0.96,
               'x': 0.5,
               'xanchor': 'center',
               'yanchor': 'top'},
        xaxis_title={'text': "Datasets",
                     'font': {'size': 30}},
        yaxis_title={'text': "ROC-AUC",
                     'font': {'size': 30}},
        font=dict(family="Calibri, monospace",
                  size=17),
    )

    # NOTE(review): the pair numbers are used both as trace indices and as the
    # x coordinates of the brackets; confirm they land where intended on the
    # categorical 'Task_name' axis.
    fig = add_p_value_annotation(fig, [[0, 7], [3, 7], [6, 7]], subplot=1)

    # Make sure the target directory exists before exporting the image.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_path, width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


def go_box_plot(df, metric=ROC, output_dir='../figures'):
    """Draw one box-plot subplot per dataset with one box per model, save and show it.

    Parameters:
    ----------
    df: pandas.DataFrame
        must provide the columns 'Task_name', 'Model' and the selected metric
        column ('test_auroc' or 'test_auprc'). Every dataset/model
        combination listed below has to be present, because the groups are
        fetched with ``get_group``.
    metric: int
        ROC selects 'test_auroc' and writes 'boxplot_auroc.png'; any other
        value selects 'test_auprc' and writes 'boxplot_auprc.png'.
    output_dir: str
        directory the PNG is written to; the default is the directory this
        function has always used. It is created if it does not exist.
    """
    dataset_list = ['BIOSNAP', 'DAVIS', 'BindingDB']
    model_list = ['LR', 'DNN', 'GNN-CPI', 'DeepDTI', 'DeepDTA', 'DeepConv-DTI', 'Moltrans', 'ours']
    clr_list = ['red', 'orange', 'green', 'indianred', 'lightseagreen', 'goldenrod', 'magenta', 'blue']

    if metric == ROC:
        file_title = "boxplot_auroc.png"
        select_metric = "test_auroc"
    else:
        file_title = "boxplot_auprc.png"
        select_metric = "test_auprc"

    fig = make_subplots(rows=1, cols=3, subplot_titles=list(dataset_list))

    groups = df.groupby(df.Task_name)

    for column, dataset in enumerate(dataset_list, start=1):
        dataset_rows = groups.get_group(dataset)
        model_groups = dataset_rows.groupby(dataset_rows.Model)
        # The same models repeat in every subplot, so only the first subplot
        # contributes legend entries.
        show_legend = column == 1
        for model, colour in zip(model_list, clr_list):
            model_rows = model_groups.get_group(model)
            # add_trace with row/col replaces the deprecated append_trace.
            fig.add_trace(go.Box(y=model_rows[select_metric],
                                 name=model,
                                 marker_color=colour,
                                 showlegend=show_legend),
                          row=1,
                          col=column)

    # Make sure the target directory exists before exporting the image.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    fig.write_image(f'{output_dir}/{file_title}', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


if __name__ == '__main__':
    # Script entry point: load the exported box-plot data and render the
    # combined ROC-AUC figure.
    df = pd.read_csv("../dataset/wandb_export_boxplotdata.csv")
    box_plot(df)