File size: 7,433 Bytes
ecdea35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import pandas as pd
import numpy as np

from scipy import stats
import plotly.express as px

from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Metric selectors for go_box_plot(metric=...): ROC selects the 'test_auroc'
# column; any other value (PR by convention) selects 'test_auprc'.
ROC, PR = 1, 2

def _significance_symbol(pvalue):
    ''' Maps a p-value to its significance marker ('ns', '*', '**' or '***') '''
    # The comparisons are written with "<" so that a NaN p-value (test not
    # computable) fails every branch and ends up as 'ns' instead of '***'.
    if pvalue < 0.001:
        return '***'
    if pvalue < 0.01:
        return '**'
    if pvalue < 0.05:
        return '*'
    return 'ns'


def add_p_value_annotation(fig, array_columns, subplot=None, _format=None):
    ''' Adds notations giving the p-value between two box plot data (t-test two-sided comparison)

    For every pair in array_columns, scipy's ttest_ind is run with
    equal_var=False on the 'y' values of the two traces. A bracket (two
    vertical ticks joined by a horizontal bar) and a significance symbol are
    then drawn above the plot area, in "domain" coordinates of the y axis:
    'ns' (p >= 0.05), '*' (p >= 0.01), '**' (p >= 0.001), '***' (p < 0.001).
    A p-value that cannot be computed (NaN) is labelled 'ns'.

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: np.array
        array of which columns to compare
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
        The pair values are also used as the x positions of the bracket.
    subplot: None or int
        specifies if the figures has subplots and what subplot to add the notation to
        (1 is the first subplot). When given, the column numbers are counted
        among the traces of that subplot only. None (or 0) means no subplots.
    _format: dict or None
        format characteristics for the lines. Any subset of the keys
        'interline' (vertical gap between stacked brackets, default 0.03),
        'text_height' (multiplier placing the symbol above its bracket,
        default 1.03) and 'color' (default 'black'); missing keys keep their
        default value.

    Returns:
    -------
    fig: figure
        figure with the added notation (the same object, modified in place)
    '''
    # Start from fresh defaults on every call: no shared mutable default
    # argument, and a partial override no longer fails on the missing keys.
    fmt = dict(interline=0.03, text_height=1.03, color='black')
    if _format is not None:
        fmt.update(_format)

    # Get values from figure
    traces = fig.to_dict()['data']

    # Axis references: plotly names the first axes 'x'/'y', then 'x2'/'y2', ...
    axis_suffix = '' if not subplot or subplot == 1 else str(subplot)
    x_ref = 'x' + axis_suffix
    y_ref = 'y' + axis_suffix + ' domain'

    # With subplots, translate "n-th box of the subplot" into a trace index.
    # A trace without an explicit 'xaxis' entry sits on the default axis 'x'.
    if subplot:
        subplot_traces = [
            position for position, trace in enumerate(traces)
            if trace.get('xaxis', 'x') == x_ref
        ]
    else:
        subplot_traces = None

    for index, column_pair in enumerate(array_columns):
        first, second = column_pair[0], column_pair[1]
        if subplot_traces is None:
            first_trace, second_trace = first, second
        else:
            first_trace, second_trace = subplot_traces[first], subplot_traces[second]

        # Get the p-value (index 1 of the ttest_ind result)
        pvalue = stats.ttest_ind(
            traces[first_trace]['y'],
            traces[second_trace]['y'],
            equal_var=False,
        )[1]
        symbol = _significance_symbol(pvalue)

        # Each further pair is stacked 'interline' higher than the previous one
        y_low = 1.03 + index * fmt['interline']
        y_high = 1.04 + index * fmt['interline']

        # Bracket: left tick, horizontal bar, right tick
        segments = (
            (first, y_low, first, y_high),
            (first, y_high, second, y_high),
            (second, y_low, second, y_high),
        )
        for x0, y0, x1, y1 in segments:
            fig.add_shape(
                type="line",
                xref=x_ref, yref=y_ref,
                x0=x0, y0=y0,
                x1=x1, y1=y1,
                line=dict(color=fmt['color'], width=1.5),
            )

        # Significance symbol, centred between the two columns
        fig.add_annotation(dict(
            font=dict(color=fmt['color'], size=14),
            x=(first + second) / 2,
            y=y_high * fmt['text_height'],
            showarrow=False,
            text=symbol,
            textangle=0,
            xref=x_ref,
            yref=y_ref,
        ))
    return fig


def box_plot(df):
    ''' Draws, saves and shows the grouped box plot of the ROC-AUC scores

    Parameters:
    ----------
    df: pd.DataFrame
        table with the columns 'Task_name' (x axis), 'test_auroc' (y axis)
        and 'Model' (box color)

    Returns:
    -------
    None
        the figure is written to '../figures/box_plot_integration.png' and
        then displayed
    '''
    axis_line_color = 'rgba(0,0,0,0.25)'

    figure = px.box(df, x='Task_name', y='test_auroc', color="Model")

    # White background, faint horizontal grid, no vertical grid
    figure.update_layout(plot_bgcolor="white")
    figure.update_xaxes(linecolor=axis_line_color, gridcolor='rgba(0,0,0,0)', mirror=False)
    figure.update_yaxes(linecolor=axis_line_color, gridcolor='rgba(0,0,0,0.07)', mirror=False)

    # Title, axis labels and base font
    heading = dict(
        text="<b>ROC-AUC score distribution</b>",
        font=dict(size=40),
        y=0.96,
        x=0.5,
        xanchor='center',
        yanchor='top',
    )
    x_label = dict(text="Datasets", font=dict(size=30))
    y_label = dict(text="ROC-AUC", font=dict(size=30))
    base_font = dict(family="Calibri, monospace", size=17)
    figure.update_layout(
        title=heading,
        xaxis_title=x_label,
        yaxis_title=y_label,
        font=base_font,
    )

    # Compare traces 0, 3 and 6 against trace 7
    # NOTE(review): needs at least 8 traces in the figure - confirm against the data
    compared_pairs = [[0, 7], [3, 7], [6, 7]]
    figure = add_p_value_annotation(figure, compared_pairs, subplot=1)

    figure.write_image(
        '../figures/box_plot_integration.png',
        width=1.5*1200,
        height=0.75*1200,
        scale=2,
    )
    figure.show()
    


def go_box_plot(df, metric=ROC):
    ''' Draws one box-plot subplot per dataset, one box per model, then saves and shows it

    Parameters:
    ----------
    df: pd.DataFrame
        table with the columns 'Task_name', 'Model' and the selected metric
        column ('test_auroc' or 'test_auprc'). Every dataset and model listed
        in this function must be present, otherwise pandas' get_group raises
        a KeyError.
    metric: int
        ROC plots 'test_auroc' and saves 'boxplot_auroc.png'; any other value
        plots 'test_auprc' and saves 'boxplot_auprc.png'

    Returns:
    -------
    None
        the figure is written to '../figures/<file name>' and then displayed
    '''
    dataset_list = ['BIOSNAP', 'DAVIS', 'BindingDB']
    model_list = ['LR', 'DNN', 'GNN-CPI', 'DeepDTI', 'DeepDTA', 'DeepConv-DTI', 'Moltrans', 'ours']
    clr_list = ['red', 'orange', 'green', 'indianred', 'lightseagreen', 'goldenrod', 'magenta', 'blue']

    if metric == ROC:
        # fig_title = "<b>ROC-AUC score distribution</b>"
        file_title = "boxplot_auroc.png"
        select_metric = "test_auroc"
    else:
        # fig_title = "<b>PR-AUC score distribution</b>"
        file_title = "boxplot_auprc.png"
        select_metric = "test_auprc"

    # One column per dataset, so the grid follows dataset_list
    fig = make_subplots(rows=1, cols=len(dataset_list), subplot_titles=list(dataset_list))

    groups = df.groupby('Task_name')

    for dataset_idx, dataset in enumerate(dataset_list):
        model_groups = groups.get_group(dataset).groupby('Model')
        # Each model appears in every subplot; list it in the legend only once
        show_legend = dataset_idx == 0
        for model, color in zip(model_list, clr_list):
            df_data = model_groups.get_group(model)
            # add_trace with row/col replaces the deprecated append_trace
            fig.add_trace(
                go.Box(
                    y=df_data[select_metric],
                    name=model,
                    marker_color=color,
                    showlegend=show_legend,
                ),
                row=1,
                col=dataset_idx + 1,
            )

    # fig.update_layout(title={'text': fig_title,
    #                         'font':{'size':25},
    #                         'y': 0.98,
    #                         'x': 0.46,
    #                         'xanchor': 'center',
    #                         'yanchor': 'top'})

    #    fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)
    #    fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=2)
    #    fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=3)

    fig.write_image(f'../figures/{file_title}', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


if __name__ == '__main__':
    # Script entry point: load the results table and draw the ROC-AUC box plot.
    # NOTE(review): the file name suggests a Weights & Biases export - confirm.
    csv_path = "../dataset/wandb_export_boxplotdata.csv"
    df = pd.read_csv(csv_path)
    box_plot(df)