jinysun committed on
Commit ecdea35
1 parent: dabfe71

Upload 17 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ util/data/bindingdb_kd.tab filter=lfs diff=lfs merge=lfs -text
+ util/data/davis.tab filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -13,12 +13,16 @@ st.title("🔋DeepDAP")
  
  url1= r"https://docs.google.com/spreadsheets/d/1AKkZS04VF3osFT36aNHIb4iUbV8D1uNfsldcpHXogj0/gviz/tq?tqx=out:csv&sheet=dap"
  df1 = pd.read_csv(url1, dtype=str, encoding='utf-8')
- 
- text_search = st.text_input("🔍Search papers or molecules", value="")
- m1 = df1["Donor_Name"].str.contains(text_search)
- m2 = df1["reference"].str.contains(text_search)
- m3 = df1["Acceptor_Name"].str.contains(text_search)
- df_search = df1[m1 | m2 | m3]
+ col1, col2 = st.columns(2)
+ with col1:
+     text_search = st.text_input("🔍Search papers or molecules", value="")
+     m1 = df1["Donor_Name"].str.contains(text_search)
+     m2 = df1["reference"].str.contains(text_search)
+     m3 = df1["Acceptor_Name"].str.contains(text_search)
+     df_search = df1[m1 | m2 | m3]
+ with col2:
+     st.link_button("📝Database", r"https://docs.google.com/spreadsheets/d/1AKkZS04VF3osFT36aNHIb4iUbV8D1uNfsldcpHXogj0")
+     st.caption('🎉If you want to update the database, click the button.')
  if text_search:
      st.write(df_search)
      st.download_button("⬇️Download edited files as .csv", df_search.to_csv(), "df_search.csv", use_container_width=True)
@@ -28,16 +32,23 @@ st.download_button(
      "⬇️ Download edited files as .csv", edited_df.to_csv(), "edited_df.csv", use_container_width=True
  )
  
- molecule = st.text_input("👨‍🔬Molecule")
- smile_code = st_ketcher(molecule)
- st.markdown("🏆New SMILES of edited molecules: {smile_code }")
- 
- acceptor= st.text_input("🎈SMILES of acceptor")
- 
- donor = st.text_input("🎈SMILES of donor")
- 
+ option = st.selectbox(
+     "How would you like to be contacted?",
+     ("Donor", "Acceptor"), placeholder="Select the type of active layer..."
+ )
+ if option == 'Acceptor':
+
+     molecule = st.text_input("👨‍🔬Acceptor Molecule")
+     acceptor = st_ketcher(molecule)
+     st.markdown(f"🏆New SMILES of edited acceptor molecules: {acceptor}")
+     donor = st.text_input("📋 Donor Molecule")
+ if option == 'Donor':
+     do = st.text_input("👨‍🔬Donor Molecule")
+     donor = st_ketcher(do)
+     st.markdown(f"🏆New SMILES of edited donor molecules: {donor}")
+     acceptor = st.text_input("📋 Acceptor Molecule")
  try:
      pce = run.smiles_aas_test(str(acceptor), str(donor))
-     st.markdown("⚡PCE: ``{pce}``")
+     st.markdown(f"⚡PCE: ``{pce}``")
  except:
-     st.markdown("⚡PCE: None ")
+     st.markdown(f"⚡PCE: None ")
config/config_hparam.json ADDED
@@ -0,0 +1,26 @@
+ { "name": "biomarker_log",
+
+ "d_model_name" : "DeepChem/ChemBERTa-10M-MTR",
+ "p_model_name" : "DeepChem/ChemBERTa-77M-MLM",
+ "gpu_ids" : "0",
+ "model_mode" : "train",
+ "load_checkpoint" : "./checkpoint/bindingDB/test.ckpt",
+
+ "prot_maxlength" : 360,
+ "layer_limit" : true,
+
+ "max_epoch": 16,
+ "batch_size": 40,
+ "num_workers": 0,
+
+ "task_name" : "OSC",
+ "lr": 1e-4,
+ "layer_features" : [512, 128, 64, 1],
+ "dropout" : 0.1,
+ "loss_fn" : "MSE",
+
+ "traindata_rate" : 1.0,
+ "pretrained": {"chem":true, "prot":true},
+ "num_seed" : 111
+ }
+
config/predict.json ADDED
@@ -0,0 +1,26 @@
+ { "name": "biomarker_log",
+
+ "d_model_name" : "DeepChem/ChemBERTa-10M-MTR",
+ "p_model_name" : "DeepChem/ChemBERTa-77M-MLM",
+ "gpu_ids" : "0",
+ "model_mode" : "test",
+ "load_checkpoint" : "./OSC/test.ckpt",
+
+ "prot_maxlength" : 360,
+ "layer_limit" : true,
+
+ "max_epoch": 16,
+ "batch_size": 40,
+ "num_workers": 0,
+
+ "task_name" : "OSC",
+ "lr": 1e-4,
+ "layer_features" : [512, 128, 64, 1],
+ "dropout" : 0.1,
+ "loss_fn" : "MSE",
+
+ "traindata_rate" : 1.0,
+ "pretrained": {"chem":true, "prot":true},
+ "num_seed" : 111
+ }
+
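The two config files differ only in `model_mode` ("train" vs "test") and `load_checkpoint`. A minimal sketch of consuming one of them via the `load_hparams` helper added in `util/utils.py` below (the call site is illustrative):

```python
from util.utils import load_hparams

# load_hparams simply json.load()s the file and returns the parsed dict.
config = load_hparams('config/predict.json')
assert config['model_mode'] == 'test'
print(config['d_model_name'])    # DeepChem/ChemBERTa-10M-MTR
print(config['layer_features'])  # [512, 128, 64, 1]
```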
util/__pycache__/attention_flow.cpython-38.pyc ADDED
Binary file (6.07 kB).
 
util/__pycache__/emetric.cpython-38.pyc ADDED
Binary file (1.87 kB).
 
util/__pycache__/regression_metric.cpython-38.pyc ADDED
Binary file (1.88 kB).
 
util/__pycache__/stream.cpython-38.pyc ADDED
Binary file (2.96 kB).
 
util/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.6 kB).
 
util/attention_flow.py ADDED
@@ -0,0 +1,195 @@
+ import networkx as nx
+ import numpy as np
+ from tqdm import tqdm
+
+ import matplotlib.pyplot as plt
+
+ import seaborn as sns
+ import itertools
+ import matplotlib as mpl
+ # import cugraph as cnx
+
+ rc = {'font.size': 10, 'axes.labelsize': 10, 'legend.fontsize': 10.0,
+       'axes.titlesize': 32, 'xtick.labelsize': 20, 'ytick.labelsize': 16}
+ plt.rcParams.update(**rc)
+ mpl.rcParams['axes.linewidth'] = .5  # set the value globally
+
+
+ def plot_attention_heatmap(att, s_position, t_positions, input_tokens):
+
+     cls_att = np.flip(att[:, s_position, t_positions], axis=0)
+     xticklb = list(itertools.compress(input_tokens, [i in t_positions for i in np.arange(len(input_tokens))]))
+     yticklb = [str(i) if i % 2 == 0 else '' for i in np.arange(att.shape[0], 0, -1)]
+     ax = sns.heatmap(cls_att, xticklabels=xticklb, yticklabels=yticklb, cmap="YlOrRd")
+
+     return ax
+
+ def convert_adjmat_tomats(adjmat, n_layers, l):
+     mats = np.zeros((n_layers, l, l))
+
+     for i in np.arange(n_layers):
+         mats[i] = adjmat[(i+1)*l:(i+2)*l, i*l:(i+1)*l]
+
+     return mats
+
+ def make_residual_attention(attentions):
+     all_attention = [att.detach().cpu().numpy() for att in attentions]
+     attentions_mat = np.asarray(all_attention)[:, 0]
+
+     res_att_mat = attentions_mat.sum(axis=1) / attentions_mat.shape[1]
+     res_att_mat = res_att_mat + np.eye(res_att_mat.shape[1])[None, ...]
+     res_att_mat = res_att_mat / res_att_mat.sum(axis=-1)[..., None]
+
+     return attentions_mat, res_att_mat
+
+ ## -------------------------------------------------------- ##
+ ## -- Make flow network (No Print Node - edge Connection)-- ##
+ ## -------------------------------------------------------- ##
+
+ def make_flow_network(mat, input_tokens):
+     n_layers, length, _ = mat.shape
+     adj_mat = np.zeros(((n_layers+1)*length, (n_layers+1)*length))
+     labels_to_index = {}
+     for k in np.arange(length):
+         labels_to_index[str(k)+"_"+input_tokens[k]] = k
+
+     for i in np.arange(1, n_layers+1):
+         for k_f in np.arange(length):
+             index_from = (i)*length+k_f
+             label = "L"+str(i)+"_"+str(k_f)
+             labels_to_index[label] = index_from
+             for k_t in np.arange(length):
+                 index_to = (i-1)*length+k_t
+                 adj_mat[index_from][index_to] = mat[i-1][k_f][k_t]
+
+     net_graph = nx.from_numpy_matrix(adj_mat, create_using=nx.DiGraph())
+     for i in np.arange(adj_mat.shape[0]):
+         for j in np.arange(adj_mat.shape[1]):
+             nx.set_edge_attributes(net_graph, {(i, j): adj_mat[i, j]}, 'capacity')
+
+     return net_graph, labels_to_index
+
+
+ def make_input_node(attention_mat, res_labels_to_index):
+     input_nodes = []
+     for key in res_labels_to_index:
+         if res_labels_to_index[key] < attention_mat.shape[-1]:
+             input_nodes.append(key)
+
+     return input_nodes
+ ## ------------------------------------------------ ##
+ ## -- Draw Attention flow node - Edge Connection -- ##
+ ## ------------------------------------------------ ##
+
+ ## -- networkx graph initiation and calculation flow -- ##
+ def get_adjmat(mat, input_tokens):
+     n_layers, length, _ = mat.shape
+     adj_mat = np.zeros(((n_layers+1)*length, (n_layers+1)*length))
+     labels_to_index = {}
+     for k in np.arange(length):
+         labels_to_index[str(k)+"_"+input_tokens[k]] = k
+
+     for i in np.arange(1, n_layers+1):
+         for k_f in np.arange(length):
+             index_from = (i)*length+k_f
+             label = "L"+str(i)+"_"+str(k_f)
+             labels_to_index[label] = index_from
+             for k_t in np.arange(length):
+                 index_to = (i-1)*length+k_t
+                 adj_mat[index_from][index_to] = mat[i-1][k_f][k_t]
+
+     return adj_mat, labels_to_index
+
+ def draw_attention_graph(adjmat, labels_to_index, n_layers, length):
+     A = adjmat
+     net_graph = nx.from_numpy_matrix(A, create_using=nx.DiGraph())
+     for i in np.arange(A.shape[0]):
+         for j in np.arange(A.shape[1]):
+             nx.set_edge_attributes(net_graph, {(i, j): A[i, j]}, 'capacity')
+
+     pos = {}
+     label_pos = {}
+     for i in np.arange(n_layers+1):
+         for k_f in np.arange(length):
+             pos[i*length+k_f] = ((i+0.4)*2, length - k_f)
+             label_pos[i*length+k_f] = (i*2, length - k_f)
+
+     index_to_labels = {}
+     for key in labels_to_index:
+         index_to_labels[labels_to_index[key]] = key.split("_")[-1]
+         if labels_to_index[key] >= length:
+             index_to_labels[labels_to_index[key]] = ''
+
+     # plt.figure(1, figsize=(20,12))
+     nx.draw_networkx_nodes(net_graph, pos, node_color='green', labels=index_to_labels, node_size=50)
+     nx.draw_networkx_labels(net_graph, pos=label_pos, labels=index_to_labels, font_size=18)
+
+     all_weights = []
+     # 4a. Iterate through the graph edges to gather all the weights
+     for (node1, node2, data) in net_graph.edges(data=True):
+         all_weights.append(data['weight'])  # we'll use this when determining edge thickness
+
+     # 4b. Get unique weights
+     unique_weights = list(set(all_weights))
+
+     # 4c. Plot the edges - one by one!
+     for weight in unique_weights:
+         # 4d. Form a filtered list with just the weight you want to draw
+         weighted_edges = [(node1, node2) for (node1, node2, edge_attr) in net_graph.edges(data=True) if edge_attr['weight'] == weight]
+         # 4e. I think multiplying by [num_nodes/sum(all_weights)] makes the graph's edges look cleaner
+
+         w = weight  # (weight - min(all_weights))/(max(all_weights) - min(all_weights))
+         width = w
+         nx.draw_networkx_edges(net_graph, pos, edgelist=weighted_edges, width=width, edge_color='darkblue')
+
+     return net_graph
+
+ def compute_flows(G, labels_to_index, input_nodes, length):
+     number_of_nodes = len(labels_to_index)
+     flow_values = np.zeros((number_of_nodes, number_of_nodes))
+     for key in tqdm(labels_to_index, desc="flow algorithms", total=len(labels_to_index)):
+         if key not in input_nodes:
+             current_layer = int(labels_to_index[key] / length)
+             pre_layer = current_layer - 1
+             u = labels_to_index[key]
+             for inp_node_key in input_nodes:
+                 v = labels_to_index[inp_node_key]
+                 flow_value = nx.maximum_flow_value(G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
+                 # flow_value = cnx
+                 flow_values[u][pre_layer*length+v] = flow_value
+             flow_values[u] /= flow_values[u].sum()
+
+     return flow_values
+
+ def compute_node_flow(G, labels_to_index, input_nodes, output_nodes, length):
+     number_of_nodes = len(labels_to_index)
+     flow_values = np.zeros((number_of_nodes, number_of_nodes))
+     for key in output_nodes:
+         if key not in input_nodes:
+             current_layer = int(labels_to_index[key] / length)
+             pre_layer = current_layer - 1
+             u = labels_to_index[key]
+             for inp_node_key in input_nodes:
+                 v = labels_to_index[inp_node_key]
+                 flow_value = nx.maximum_flow_value(G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
+                 flow_values[u][pre_layer*length+v] = flow_value
+             flow_values[u] /= flow_values[u].sum()
+
+     return flow_values
+
+ def compute_joint_attention(att_mat, add_residual=True):
+     if add_residual:
+         residual_att = np.eye(att_mat.shape[1])[None, ...]
+         aug_att_mat = att_mat + residual_att
+         aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
+     else:
+         aug_att_mat = att_mat
+
+     joint_attentions = np.zeros(aug_att_mat.shape)
+
+     layers = joint_attentions.shape[0]
+     joint_attentions[0] = aug_att_mat[0]
+     for i in np.arange(1, layers):
+         joint_attentions[i] = aug_att_mat[i].dot(joint_attentions[i-1])
+
+     return joint_attentions
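`compute_joint_attention` is the standard attention-rollout recursion: head-average each layer, add the identity residual, renormalize, then chain layers by matrix product. A toy sketch, assuming `attentions` has the shape Hugging Face encoders return with `output_attentions=True` (the random tensors are stand-ins for real model output):

```python
import torch
from util.attention_flow import make_residual_attention, compute_joint_attention

# Stand-in for model attentions: 6 layers, batch 1, 8 heads, 10 tokens.
attentions = tuple(torch.softmax(torch.rand(1, 8, 10, 10), dim=-1) for _ in range(6))

# Head-average each layer, add the identity residual, and renormalize rows...
attentions_mat, res_att_mat = make_residual_attention(attentions)
# ...then chain layers: rollout[i] = A_i @ rollout[i-1].
joint = compute_joint_attention(res_att_mat, add_residual=False)
print(joint.shape)  # (6, 10, 10)
```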
util/attention_plot.py ADDED
@@ -0,0 +1,93 @@
+ import pandas as pd
+
+ import plotly.express as px
+ import plotly.graph_objects as go
+
+ def make_attention_table(att, tokens, numb, token_idx=0, layerNumb=-1):
+     token_att = att[layerNumb, token_idx, range(1, len(tokens))]
+
+     token_label = []
+     token_numb = []
+     for idx, token in enumerate(tokens[1:]):
+         token_label.append(f"<b>{token}</b>")
+         token_numb.append(f"{idx}")
+
+     pair = list(zip(token_numb, token_att))
+
+     df = pd.DataFrame(pair, columns=["Amino acid", "Attention rate"])
+     df.to_csv(f"amino_acid_seq_attention_{numb}.csv", index=None)
+
+     top3_idx = sorted(range(len(token_att)), key=lambda i: token_att[i], reverse=True)[:3]
+
+     colors = ['cornflowerblue', ] * len(token_numb)
+
+     for i in top3_idx:
+         colors[i] = 'crimson'
+
+     fig = go.Figure(data=[go.Bar(
+         x=df["Amino acid"],
+         y=df["Attention rate"],
+         # range_y=[min(token_att), max(token_att)],
+         marker_color=colors  # marker color can be a single color value or an iterable
+     )])
+
+     # fig = px.histogram(df, x="Amino acid", y="Attention rate", range_y=[min(token_att), max(token_att)])
+
+     fig.update_layout(plot_bgcolor="white")
+     fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
+     fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
+     fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
+                              'font': {'size': 40},
+                              'y': 0.96,
+                              'x': 0.5,
+                              'xanchor': 'center',
+                              'yanchor': 'top'},
+                       xaxis=dict(tickmode='array',
+                                  tickvals=token_numb,
+                                  ticktext=token_label
+                                  ),
+                       xaxis_title={'text': "Amino acid sequence",
+                                    'font': {'size': 30}},
+                       yaxis_title={'text': "Attention rate",
+                                    'font': {'size': 30}},
+                       font=dict(family="Calibri, monospace",
+                                 size=17
+                                 ))
+
+     fig.write_image(f'figures/Amino_acid_seq_{numb}.png', width=1.5*1200, height=0.75*1200, scale=2)
+     fig.show()
+
+
+ def read_attention():
+     df = pd.read_csv("../amino_acid_seq_attention.csv")
+     # d_flow_values = np.asarray(d_read_flow_values)
+
+     fig = px.bar(df, x="Amino acid", y="Attention rate", range_y=[min(df["Attention rate"]), max(df["Attention rate"])])
+
+     fig.update_layout(plot_bgcolor="white")
+     fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
+     fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
+     fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
+                              'font': {'size': 40},
+                              'y': 0.96,
+                              'x': 0.5,
+                              'xanchor': 'center',
+                              'yanchor': 'top'},
+                       xaxis_title={'text': "Amino acid sequence",
+                                    'font': {'size': 30}},
+                       yaxis_title={'text': "Attention rate",
+                                    'font': {'size': 30}},
+                       font=dict(family="Calibri, monospace",
+                                 size=17
+                                 ))
+
+     fig.write_image('figures/Amino_acid_seq.png', width=1.5*1200, height=0.75*1200, scale=2)
+     fig.show()
+
+ if __name__ == '__main__':
+     read_attention()
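A hypothetical call, assuming `att` is a `[layers, tokens, tokens]` attention stack such as the rollout from `util/attention_flow.py` and `tokens` the matching token strings. Note that the function writes into a `figures/` directory that must already exist, and `fig.write_image` requires the `kaleido` package:

```python
import numpy as np
from util.attention_plot import make_attention_table

att = np.random.rand(6, 10, 10)  # illustrative attention stack
tokens = list("MKTAYIAKQR")      # illustrative amino-acid tokens
# Writes amino_acid_seq_attention_0.csv and figures/Amino_acid_seq_0.png,
# highlighting the three highest-attention residues in crimson.
make_attention_table(att, tokens, numb=0)
```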
util/boxplot.py ADDED
@@ -0,0 +1,201 @@
+ import pandas as pd
+ import numpy as np
+
+ from scipy import stats
+ import plotly.express as px
+
+ from plotly.subplots import make_subplots
+ import plotly.graph_objects as go
+
+ ROC = 1
+ PR = 2
+
+ def add_p_value_annotation(fig, array_columns, subplot=None, _format=dict(interline=0.03, text_height=1.03, color='black')):
+     ''' Adds notations giving the p-value between two box plot data (t-test two-sided comparison)
+
+     Parameters:
+     ----------
+     fig: figure
+         plotly boxplot figure
+     array_columns: np.array
+         array of which columns to compare
+         e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
+     subplot: None or int
+         specifies if the figure has subplots and what subplot to add the notation to
+     _format: dict
+         format characteristics for the lines
+
+     Returns:
+     -------
+     fig: figure
+         figure with the added notation
+     '''
+     # Specify in what y_range to plot for each pair of columns
+     y_range = np.zeros([len(array_columns), 2])
+     for i in range(len(array_columns)):
+         y_range[i] = [1.03+i*_format['interline'], 1.04+i*_format['interline']]
+
+     # Get values from figure
+     fig_dict = fig.to_dict()
+
+     # Get indices if working with subplots
+     if subplot:
+         if subplot == 1:
+             subplot_str = ''
+         else:
+             subplot_str = str(subplot)
+         indices = []  # Change the box index to the indices of the data for that subplot
+         for index, data in enumerate(fig_dict['data']):
+             # print(index, data['xaxis'], 'x' + subplot_str)
+             if data['xaxis'] == 'x' + subplot_str:
+                 indices = np.append(indices, index)
+         indices = [int(i) for i in indices]
+         print((indices))
+     else:
+         subplot_str = ''
+
+     # Print the p-values
+     for index, column_pair in enumerate(array_columns):
+         if subplot:
+             data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
+         else:
+             data_pair = column_pair
+
+         # Make sure it is selecting the data and subplot you want
+         # print('0:', fig_dict['data'][data_pair[0]]['name'], fig_dict['data'][data_pair[0]]['xaxis'])
+         # print('1:', fig_dict['data'][data_pair[1]]['name'], fig_dict['data'][data_pair[1]]['xaxis'])
+
+         # Get the p-value
+         pvalue = stats.ttest_ind(
+             fig_dict['data'][data_pair[0]]['y'],
+             fig_dict['data'][data_pair[1]]['y'],
+             equal_var=False,
+         )[1]
+         if pvalue >= 0.05:
+             symbol = 'ns'
+         elif pvalue >= 0.01:
+             symbol = '*'
+         elif pvalue >= 0.001:
+             symbol = '**'
+         else:
+             symbol = '***'
+         # Vertical line
+         fig.add_shape(type="line",
+                       xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+                       x0=column_pair[0], y0=y_range[index][0],
+                       x1=column_pair[0], y1=y_range[index][1],
+                       line=dict(color=_format['color'], width=1.5,)
+                       )
+         # Horizontal line
+         fig.add_shape(type="line",
+                       xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+                       x0=column_pair[0], y0=y_range[index][1],
+                       x1=column_pair[1], y1=y_range[index][1],
+                       line=dict(color=_format['color'], width=1.5,)
+                       )
+         # Vertical line
+         fig.add_shape(type="line",
+                       xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+                       x0=column_pair[1], y0=y_range[index][0],
+                       x1=column_pair[1], y1=y_range[index][1],
+                       line=dict(color=_format['color'], width=1.5,)
+                       )
+         ## add text at the correct x, y coordinates
+         ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
+         fig.add_annotation(dict(font=dict(color=_format['color'], size=14),
+                                 x=(column_pair[0] + column_pair[1])/2,
+                                 y=y_range[index][1]*_format['text_height'],
+                                 showarrow=False,
+                                 text=symbol,
+                                 textangle=0,
+                                 xref="x"+subplot_str,
+                                 yref="y"+subplot_str+" domain"
+                                 ))
+     return fig
+
+
+ def box_plot(df):
+
+     fig = px.box(df, x='Task_name', y='test_auroc', color="Model")
+
+     fig.update_layout(plot_bgcolor="white")
+     fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
+     fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
+     fig.update_layout(title={'text': "<b>ROC-AUC score distribution</b>",
+                              'font': {'size': 40},
+                              'y': 0.96,
+                              'x': 0.5,
+                              'xanchor': 'center',
+                              'yanchor': 'top'},
+                       xaxis_title={'text': "Datasets",
+                                    'font': {'size': 30}},
+                       yaxis_title={'text': "ROC-AUC",
+                                    'font': {'size': 30}},
+                       font=dict(family="Calibri, monospace",
+                                 size=17
+                                 ))
+
+     fig = add_p_value_annotation(fig, [[0, 7], [3, 7], [6, 7]], subplot=1)
+
+     fig.write_image('../figures/box_plot_integration.png', width=1.5*1200, height=0.75*1200, scale=2)
+     fig.show()
+
+
+ def go_box_plot(df, metric=ROC):
+     dataset_list = ['BIOSNAP', 'DAVIS', 'BindingDB']
+     model_list = ['LR', 'DNN', 'GNN-CPI', 'DeepDTI', 'DeepDTA', 'DeepConv-DTI', 'Moltrans', 'ours']
+     clr_list = ['red', 'orange', 'green', 'indianred', 'lightseagreen', 'goldenrod', 'magenta', 'blue']
+
+     if metric == ROC:
+         # fig_title = "<b>ROC-AUC score distribution</b>"
+         file_title = "boxplot_auroc.png"
+         select_metric = "test_auroc"
+     else:
+         # fig_title = "<b>PR-AUC score distribution</b>"
+         file_title = "boxplot_auprc.png"
+         select_metric = "test_auprc"
+
+     fig = make_subplots(rows=1, cols=3, subplot_titles=[c for c in dataset_list])
+
+     groups = df.groupby(df.Task_name)
+     Legand = True
+
+     for dataset_idx, dataset in enumerate(dataset_list):
+         df_modelgroup = groups.get_group(dataset)
+         model_groups = df_modelgroup.groupby(df_modelgroup.Model)
+         if dataset_idx != 0:
+             Legand = False
+         for model_idx, model in enumerate(model_list):
+             df_data = model_groups.get_group(model)
+             fig.append_trace(go.Box(y=df_data[select_metric],
+                                     name=model,
+                                     marker_color=clr_list[model_idx],
+                                     showlegend=Legand
+                                     ),
+                              row=1,
+                              col=dataset_idx+1)
+
+     # fig.update_layout(title={'text': fig_title,
+     #                          'font':{'size':25},
+     #                          'y': 0.98,
+     #                          'x': 0.46,
+     #                          'xanchor': 'center',
+     #                          'yanchor': 'top'})
+
+     # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)
+     # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=2)
+     # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=3)
+
+     fig.write_image(f'../figures/{file_title}', width=1.5*1200, height=0.75*1200, scale=2)
+     fig.show()
+
+
+ if __name__ == '__main__':
+     df = pd.read_csv("../dataset/wandb_export_boxplotdata.csv")
+     box_plot(df)
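A minimal sketch of `add_p_value_annotation` on a single-panel figure; the toy frame is illustrative and its column names mirror `box_plot`:

```python
import pandas as pd
import plotly.express as px
from util.boxplot import add_p_value_annotation

df = pd.DataFrame({
    'Task_name': ['DAVIS'] * 6,
    'Model': ['LR'] * 3 + ['ours'] * 3,
    'test_auroc': [0.81, 0.83, 0.80, 0.90, 0.91, 0.89],
})
fig = px.box(df, x='Model', y='test_auroc', color='Model')
# Draws a bracket between box 0 (LR) and box 1 (ours), topped with a
# Welch t-test significance symbol ('ns', '*', '**', or '***').
fig = add_p_value_annotation(fig, [[0, 1]])
fig.show()
```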
util/data/bindingdb_kd.tab ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b72a38ae07a75d5d4c269d2776b6e62e0edde29ff7cf8a323158c08951f808d1
+ size 54432102
util/data/davis.tab ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d4c6809dcb7c5da2b91a32d594d6935b75484940bde4d18055eb5e1059262f4
+ size 21376712
util/emetric.py ADDED
@@ -0,0 +1,59 @@
+ import numpy as np
+
+ def get_cindex(Y, P):
+     summ = 0
+     pair = 0
+
+     for i in range(1, len(Y)):
+         for j in range(0, i):
+             if i != j:
+                 if (Y[i] > Y[j]):
+                     pair += 1
+                     summ += 1 * (P[i] > P[j]) + 0.5 * (P[i] == P[j])
+
+     if pair != 0:
+         return summ / pair
+     else:
+         return 0
+
+
+ def r_squared_error(y_obs, y_pred):
+     y_obs = np.array(y_obs)
+     y_pred = np.array(y_pred)
+     y_obs_mean = [np.mean(y_obs) for y in y_obs]
+     y_pred_mean = [np.mean(y_pred) for y in y_pred]
+
+     mult = sum((y_pred - y_pred_mean) * (y_obs - y_obs_mean))
+     mult = mult * mult
+
+     y_obs_sq = sum((y_obs - y_obs_mean) * (y_obs - y_obs_mean))
+     y_pred_sq = sum((y_pred - y_pred_mean) * (y_pred - y_pred_mean))
+
+     return mult / float(y_obs_sq * y_pred_sq)
+
+
+ def get_k(y_obs, y_pred):
+     y_obs = np.array(y_obs)
+     y_pred = np.array(y_pred)
+
+     return sum(y_obs*y_pred) / float(sum(y_pred*y_pred))
+
+
+ def squared_error_zero(y_obs, y_pred):
+     k = get_k(y_obs, y_pred)
+
+     y_obs = np.array(y_obs)
+     y_pred = np.array(y_pred)
+     y_obs_mean = [np.mean(y_obs) for y in y_obs]
+     upp = sum((y_obs - (k*y_pred)) * (y_obs - (k*y_pred)))
+     down = sum((y_obs - y_obs_mean) * (y_obs - y_obs_mean))
+
+     return 1 - (upp / float(down))
+
+
+ def get_rm2(ys_orig, ys_line):
+     r2 = r_squared_error(ys_orig, ys_line)
+     r02 = squared_error_zero(ys_orig, ys_line)
+
+     return r2 * (1 - np.sqrt(np.absolute((r2*r2) - (r02*r02))))
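A toy check of the two headline metrics (the arrays are illustrative):

```python
import numpy as np
from util.emetric import get_cindex, get_rm2

y_true = np.array([5.0, 6.2, 7.1, 8.3, 9.0])
y_pred = np.array([5.3, 5.1, 7.4, 8.1, 8.7])
# Concordance index: fraction of label-ordered pairs ranked correctly.
print(get_cindex(y_true, y_pred))  # 0.9: one of the ten pairs is inverted
print(get_rm2(y_true, y_pred))     # rm^2 external-validation metric
```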
util/load_dataset.py ADDED
@@ -0,0 +1,32 @@
+ from tdc.multi_pred import DTI
+
+ import pandas as pd
+ import numpy as np
+
+ if __name__ == '__main__':
+     bindingDB_data = DTI(name = 'BindingDB_Kd')
+     davis_data = DTI(name = 'DAVIS')
+
+     bindingDB_data.harmonize_affinities(mode = 'max_affinity')
+
+     bindingDB_data.convert_to_log(form = 'binding')
+     davis_data.convert_to_log(form = 'binding')
+
+     split_bindingDB = bindingDB_data.get_split()
+     split_davis = davis_data.get_split()
+
+     dataset_list = ["train", "valid", "test"]
+     for dataset_type in dataset_list:
+         df_bindingDB = pd.DataFrame(split_bindingDB[dataset_type])
+         df_davis = pd.DataFrame(split_davis[dataset_type])
+
+         df_bindingDB.to_csv(f"../dataset_kd/bindingDB_{dataset_type}.csv", index=False)
+         df_davis.to_csv(f"../dataset_kd/davis_{dataset_type}.csv", index=False)
+
+     Y_bindingDB = np.array(df_bindingDB.Y)
+     Y_davis = np.array(df_davis.Y)
+
+     Y_davis_log = [np.log10(Y_davis)]
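A hypothetical downstream read of the splits this script writes (the path mirrors the `to_csv` calls above; TDC's DTI splits typically carry `Drug`, `Target`, and `Y` columns, but verify against your TDC version):

```python
import pandas as pd

train = pd.read_csv("../dataset_kd/davis_train.csv")
print(train.shape)
print(train.columns.tolist())  # expect Drug/Target/Y-style columns
print(train["Y"].describe())   # log-transformed binding affinities
```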
util/make_external_validation.py ADDED
@@ -0,0 +1,28 @@
+ import numpy as np
+ import pandas as pd
+
+
+ if __name__ == '__main__':
+     smiles = pd.read_csv("../dataset/external_smiles.csv")
+     ass = pd.read_csv("../dataset/external_aas.csv")
+
+     smiles_data = list(np.array(smiles['smiles']))
+     smiles_label = list(np.array(smiles['label'].tolist()))
+     smiles_label = [x.split() for x in smiles_label]
+
+     ass_data = list(np.array(ass['aas']))
+     cyp_type = list(np.array(ass['CYP_type']))
+
+     external_dataset = []
+     for smiles_idx in range(0, len(smiles_data)):
+         for ass_idx in range(0, len(ass_data)):
+             external_data = [smiles_data[smiles_idx], ass_data[ass_idx], cyp_type[ass_idx]]
+             external_dataset.append(external_data)
+
+     df = pd.DataFrame(external_dataset, columns=['smiles', 'aas', 'CYP_type'])
+     df.to_csv('../dataset/external_dataset.csv', index=False)
+
+     print(smiles['smiles'][0])
+     print(ass['CYP_type'][0])
@@ -0,0 +1,45 @@
 
+ import json, copy
+ from easydict import EasyDict
+
+ import torch.nn as nn
+
+ class DictX(dict):
+     def __getattr__(self, key):
+         try:
+             return self[key]
+         except KeyError as k:
+             raise AttributeError(k)
+
+     def __setattr__(self, key, value):
+         self[key] = value
+
+     def __delattr__(self, key):
+         try:
+             del self[key]
+         except KeyError as k:
+             raise AttributeError(k)
+
+     def __repr__(self):
+         return '<DictX ' + dict.__repr__(self) + '>'
+
+
+ def load_hparams(file_path):
+     hparams = EasyDict()
+     with open(file_path, 'r') as f:
+         hparams = json.load(f)
+     return hparams
+
+
+ def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full BERT model
+     oldModuleList = model.encoder.layer
+     newModuleList = nn.ModuleList()
+
+     # Now iterate over all layers, keeping only the relevant layers.
+     for i in range(num_layers_to_keep):
+         newModuleList.append(oldModuleList[i])
+
+     # Create a copy of the model, modify it with the new list, and return
+     copyOfModel = copy.deepcopy(model)
+     copyOfModel.encoder.layer = newModuleList
+
+     return copyOfModel
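A hedged sketch of truncating one of the ChemBERTa encoders named in the configs with `deleteEncodingLayers`. The model name comes from the config above, the depth of 3 is illustrative, and `model.encoder.layer` assumes a BERT/RoBERTa-style model, as the inline comment requires:

```python
from transformers import AutoModel
from util.utils import deleteEncodingLayers, load_hparams

config = load_hparams('config/config_hparam.json')
model = AutoModel.from_pretrained(config['d_model_name'])  # DeepChem/ChemBERTa-10M-MTR
small = deleteEncodingLayers(model, num_layers_to_keep=3)  # keep the first 3 layers
print(len(small.encoder.layer))  # 3
```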