wenkai committed
Commit 4a1f168
1 Parent(s): d376991

Upload 24 files

.gitattributes CHANGED
@@ -35,3 +35,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 assets/FAPM.png filter=lfs diff=lfs merge=lfs -text
 assets/LAVIS_technical_report.pdf filter=lfs diff=lfs merge=lfs -text
+ data/go1.4-basic.obo filter=lfs diff=lfs merge=lfs -text
+ data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
+ data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
+ data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
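(Each added attribute line routes matching files through the Git LFS filter so that only a small pointer is committed; such a line is what `git lfs track` writes, e.g.:

    git lfs track "data/go1.4-basic.obo"
)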
data/emb_esm2_3b/P18281.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91714943ae1d08f860e86cfcd098f3973dc14ca63d88556223223fc9687ac7ec
+ size 901864
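(The pointer above stands in for a 901,864-byte payload. Once the LFS object is pulled, the embedding can presumably be loaded with PyTorch — a minimal sketch; the exact structure of the saved object is an assumption:

    import torch

    # Hypothetical usage: load the per-protein ESM2-3B embedding for P18281.
    emb = torch.load('data/emb_esm2_3b/P18281.pt', map_location='cpu')
    print(type(emb))  # a tensor or a dict, depending on how it was saved
)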
data/evaluate_data/evaluate_cases.py ADDED
@@ -0,0 +1,213 @@
+ import pandas as pd
+ import re
+ import random
+ import Levenshtein
+ import numpy as np
+ import difflib
+ # from torchmetrics.text import BLEUScore
+ import time
+ from multiprocessing import Pool, Queue, Process
+ import matplotlib.pyplot as plt
+ from data.evaluate_data.utils import Ontology
+ # bleu = BLEUScore(n_gram=1)
+
+ def fuzzy_match(texts):
+     text_dict = {}
+     for context in texts:
+         if context not in choices:
+             # txt_dict[txt] = process.extractOne(txt, choices)[0]
+             text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
+     return text_dict
+
+
+ def get_sim(text, label):
+     # For each gold description, keep the best Levenshtein ratio against any predicted text
+     all_s = []
+     for x in label:
+         s = 0
+         for y in text:
+             temp = Levenshtein.ratio(x, y)
+             if temp > s:
+                 s = temp
+         all_s.append(s)
+     all_s = [round(i, 3) for i in all_s]
+
+     # bs = [bleu(x, [label]) for x in text]
+     return all_s
+
+
+ def txt_map(x, txt_dict):
+     if type(x) == str:
+         x = eval(x)
+     x_ = []
+     for i in x:
+         if i == '':
+             continue
+         if i in txt_dict:
+             x_.append(txt_dict[i])
+         else:
+             x_.append(i)
+     return x_
+
+
+ def go_map(t):
+     if t in GO_dict:
+         return GO_dict[t]
+     else:
+         print(t)
+
+
+ def get_term(df):
+     from collections import Counter
+     cnt = Counter()
+     for i, row in enumerate(df.itertuples()):
+         for term in row.prop_annotations:
+             cnt[term] += 1
+     terms = list(cnt.keys())
+     # remove the three root terms
+     for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
+         if top_term in terms:
+             terms.remove(top_term)
+     terms_df = pd.DataFrame({'gos': terms})
+     terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl')
+
+
+ if __name__ == "__main__":
+     go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
+     go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
+     go_des.columns = ['GO', 'function']
+     go_des = go_des[go_des['function'].notnull()]
+     go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
+     go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
+     GO_dict = dict(zip(go_des['function'], go_des['GO']))
+
+     data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
+     data.columns = ['protein', 'pred', 'label']
+     data['label'] = data['label'].apply(lambda x: x.lower())
+     data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+     data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
+     data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test.csv', sep='|')
+     test = test.drop_duplicates()
+     test['function'] = test['function'].apply(lambda x: x.lower().strip())
+     test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
+     test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+
+     test_dict = dict()
+     for x, y in zip(test['function'], test['GO_label']):
+         temp = dict(zip(x, y))
+         test_dict.update(temp)
+     GO_dict.update(test_dict)
+
+     choices = list(test_dict.keys())
+
+     ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+     '''
+     print("Finding the GO label most similar to each predicted text ......")
+     t0 = time.time()
+     txt_dict = {}
+
+     all_txt = []
+     for txt in data['pred_list']:
+         if type(txt) == str:
+             all_txt.extend(eval(txt))
+         else:
+             all_txt.extend(txt)
+     all_txt = list(set(all_txt))
+
+     n = len(all_txt)
+     thread = 10
+     size = int(n/thread)
+     inds = list(range(0, n, size))
+     inds.append(n)
+     all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
+
+     with Pool(processes=thread) as pool:
+         result = pool.map(fuzzy_match, all_txt_sep)
+         pool.close()
+         pool.join()
+     for d in result:
+         txt_dict.update(d)
+
+     # for txt in all_txt[:10]:
+     #     fuzzy_match(txt)
+
+     data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+     data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
+     print("fuzzy matching time: {}".format(time.time() - t0))
+
+     print("calculating f1 score ......")
+     data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
+     data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
+     '''
+
+     # Prepare the case test data: GO labels predicted by BLIP-2 are the features;
+     # the labels with their ancestors added are the prediction target Y.
+     prepare_ancestors = True
+     if prepare_ancestors:
+         print("Preparing the data with ancestors added ......")
+         def prop(df):
+             prop_annotations = []
+             for i, row in df.iterrows():
+                 # Propagate annotations
+                 annot_set = set()
+                 annots = row['GO_label']
+                 for go_id in annots:
+                     annot_set |= go.get_anchestors(go_id)
+                 annots = list(annot_set)
+                 prop_annotations.append(annots)
+             df['prop_annotations'] = prop_annotations
+             return df
+
+         def pred_text_to_go(df):
+             df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+             df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+             ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+             t0 = time.time()
+             txt_dict = {}
+
+             all_txt = []
+             for txt in df['pred_list']:
+                 if type(txt) == str:
+                     all_txt.extend(eval(txt))
+                 else:
+                     all_txt.extend(txt)
+
+             all_txt = list(set(all_txt))
+             if '' in all_txt:
+                 all_txt.remove('')
+
+             n = len(all_txt)
+             thread = 10
+             size = int(n / thread)
+             inds = list(range(0, n, size))
+             inds.append(n)
+             all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
+
+             with Pool(processes=thread) as pool:
+                 result = pool.map(fuzzy_match, all_txt_sep)
+                 pool.close()
+                 pool.join()
+             for d in result:
+                 txt_dict.update(d)
+
+             # for txt in all_txt[:10]:
+             #     fuzzy_match(txt)
+
+             df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+             df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
+             print("fuzzy matching time: {}".format(time.time() - t0))
+
+             df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
+             return df
+
+         test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
+         test_pred.columns = ['protein', 'pred', 'GO_label']
+         test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+         test_pred = prop(test_pred)
+         test_pred = pred_text_to_go(test_pred)
+
+         for cat in ['mf', 'bp', 'cc']:
+             test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_case.pkl'.format(cat))
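(The script leans on difflib.get_close_matches to snap free-text predictions onto known GO descriptions; with cutoff=0. every prediction is forced onto some label. A minimal sketch of that behaviour with hypothetical choices:

    import difflib

    choices = ['protein binding', 'atp binding', 'dna binding']
    # cutoff=0. accepts even distant matches, exactly as in fuzzy_match above
    print(difflib.get_close_matches('atp bindin', choices, n=1, cutoff=0.))
    # -> ['atp binding']
)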
data/evaluate_data/evaluate_pretrain.py ADDED
@@ -0,0 +1,282 @@
+ import pandas as pd
+ import re
+ import random
+ import Levenshtein
+ import numpy as np
+ import difflib
+ # from torchmetrics.text import BLEUScore
+ import time
+ from multiprocessing import Pool, Queue, Process
+ import matplotlib.pyplot as plt
+ from data.evaluate_data.utils import Ontology
+ # bleu = BLEUScore(n_gram=1)
+
+ def fuzzy_match(texts):
+     text_dict = {}
+     for context in texts:
+         if context not in choices:
+             # txt_dict[txt] = process.extractOne(txt, choices)[0]
+             sim_list = difflib.get_close_matches(context, choices, n=1, cutoff=0.93)
+             if len(sim_list) > 0:
+                 text_dict[context] = sim_list[0]
+             else:
+                 text_dict[context] = ''
+     return text_dict
+
+
+ def get_sim(text, label):
+     # For each gold description, keep the best Levenshtein ratio against any predicted text
+     all_s = []
+     for x in label:
+         s = 0
+         for y in text:
+             temp = Levenshtein.ratio(x, y)
+             if temp > s:
+                 s = temp
+         all_s.append(s)
+     all_s = [round(i, 3) for i in all_s]
+
+     # bs = [bleu(x, [label]) for x in text]
+     return all_s
+
+
+ def txt_map(x, txt_dict):
+     if type(x) == str:
+         x = eval(x)
+     x_ = []
+     for i in x:
+         if i == '':
+             continue
+         if i in txt_dict:
+             x_.append(txt_dict[i])
+         else:
+             x_.append(i)
+     return x_
+
+
+ def go_map(t):
+     if t in GO_dict:
+         return GO_dict[t]
+     else:
+         pass
+         #print(t)
+
+
+ def get_term(df):
+     from collections import Counter
+     cnt = Counter()
+     for i, row in enumerate(df.itertuples()):
+         for term in row.prop_annotations:
+             cnt[term] += 1
+     terms = list(cnt.keys())
+     # remove the three root terms
+     for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
+         if top_term in terms:
+             terms.remove(top_term)
+     terms_df = pd.DataFrame({'gos': terms})
+     terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl')
+
+
+ if __name__ == "__main__":
+     go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
+     go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
+     go_des.columns = ['GO', 'function']
+     go_des = go_des[go_des['function'].notnull()]
+     go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
+     go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
+     GO_dict = dict(zip(go_des['function'], go_des['GO']))
+
+     data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_go_train.txt', sep='|', header=None, on_bad_lines='skip')
+     data.columns = ['name', 'pred', 'label']
+     #data['label'] = data['label'].apply(lambda x: x.lower())
+     data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+     #data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
+     data['pred_list'] = data['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))
+
+     #train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/train_exp.csv', sep='|')
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/train_exp.csv', sep='|')
+     test = test.drop_duplicates()
+     test['function'] = test['function'].apply(lambda x: x.lower().strip())
+     test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
+     test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+
+     data = pd.merge(data, test[['name', 'function']], on='name', how='left')
+     data['label_list'] = data['function']
+
+     test_dict = dict()
+     for x, y in zip(test['function'], test['GO_label']):
+         temp = dict(zip(x, y))
+         test_dict.update(temp)
+     GO_dict.update(test_dict)
+
+     choices = list(test_dict.keys())
+
+     ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+     print("Finding the GO label most similar to each predicted text ......")
+     t0 = time.time()
+     txt_dict = {}
+
+     all_txt = []
+     for txt in data['pred_list']:
+         if type(txt) == str:
+             all_txt.extend(eval(txt))
+         else:
+             all_txt.extend(txt)
+     all_txt = list(set(all_txt))
+
+     n = len(all_txt)
+     thread = 40
+     size = int(n/thread)
+     inds = list(range(0, n, size))
+     inds.append(n)
+     all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
+
+     with Pool(processes=thread) as pool:
+         result = pool.map(fuzzy_match, all_txt_sep)
+         pool.close()
+         pool.join()
+     for d in result:
+         txt_dict.update(d)
+
+     # for txt in all_txt[:10]:
+     #     fuzzy_match(txt)
+
+     data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+     data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
+     print("fuzzy matching time: {}".format(time.time() - t0))
+
+     print("calculating f1 score ......")
+     data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
+     data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
+
+
+     labels = []
+     pred_labels = []
+     for l in data['label_list_go']:
+         if type(l) == str:
+             l = eval(l)
+         labels.extend(l)
+
+     label_count = {}
+     for x in labels:
+         if x not in label_count:
+             label_count[x] = 1
+         else:
+             label_count[x] += 1
+
+     labels = list(set(labels))
+     total = len(labels)
+     recalls = []
+     precisions = []
+     tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
+     for preds, label in zip(data['pred_list_go'], data['label_list_go']):
+         if type(label) == str:
+             label = eval(label)
+         if type(preds) == str:
+             preds = eval(preds)  # the original assigned this to an unused name `txts`
+         ll = len(label)
+         for t in label:
+             # supgo = go.get_anchestors(t)
+             # if supgo.intersection(set(preds)):
+             if t in preds:
+                 tp_dict[t] += 1
+             else:
+                 fn_dict[t] += 1
+         for p in preds:
+             # supgo = go.get_anchestors(p)
+             # if not supgo.intersection(set(label)):
+             if p not in label:
+                 if p in fp_dict:
+                     fp_dict[p] += 1
+                 else:
+                     fp_dict[p] = 1
+         pred_labels.extend(preds)
+     p_total = len(set(pred_labels))
+     recall, pr = 0., 0.
+     for x in labels:
+         recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
+         pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
+     r = recall / total
+     p = pr / p_total
+     f1 = 2 * p * r / (p + r)
+
+     print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
+     print("recall: {}; precision: {}; f1 score: {}".format(r, p, f1))
+
+
+     # Prepare the data: GO labels predicted by BLIP-2 are the features;
+     # the labels with their ancestors added are the prediction target Y.
+     prepare_ancestors = False
+     if prepare_ancestors:
+         print("Preparing the data with ancestors added ......")
+         def prop(df):
+             prop_annotations = []
+             for i, row in df.iterrows():
+                 # Propagate annotations
+                 annot_set = set()
+                 annots = row['GO_label']
+                 for go_id in annots:
+                     annot_set |= go.get_anchestors(go_id)
+                 annots = list(annot_set)
+                 prop_annotations.append(annots)
+             df['prop_annotations'] = prop_annotations
+             return df
+
+         def remove_nan(x):
+             if '' in x:
+                 x.remove('')
+             return x
+
+         def pred_text_to_go(df):
+             df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+             df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))
+             ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+             t0 = time.time()
+             txt_dict = {}
+
+             all_txt = []
+             for txt in df['pred_list']:
+                 if type(txt) == str:
+                     all_txt.extend(eval(txt))
+                 else:
+                     all_txt.extend(txt)
+
+             all_txt = list(set(all_txt))
+             if '' in all_txt:
+                 all_txt.remove('')
+
+             n = len(all_txt)
+             thread = 40
+             size = int(n / thread)
+             inds = list(range(0, n, size))
+             inds.append(n)
+             all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
+
+             with Pool(processes=thread) as pool:
+                 result = pool.map(fuzzy_match, all_txt_sep)
+                 pool.close()
+                 pool.join()
+             for d in result:
+                 txt_dict.update(d)
+
+             # for txt in all_txt[:10]:
+             #     fuzzy_match(txt)
+
+             df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+             df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
+             df['pred_list'] = df['pred_list'].apply(lambda x: remove_nan(x))
+             print("fuzzy matching time: {}".format(time.time() - t0))
+
+             df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
+             return df
+
+
+         test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/pretrain/output_pretrain.txt', sep='|', header=None)
+         test_pred.columns = ['protein', 'pred', 'GO_label']
+         test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+         # The original read `test_pred = test_pred(test)`, which calls a DataFrame and raises;
+         # propagating ancestors over the predictions is assumed to be the intent.
+         test_pred = prop(test_pred)
+         # get_term needs a prop_annotations column, which plain `test` does not carry here,
+         # so it is assumed the propagated predictions were meant.
+         get_term(test_pred)
+         test_pred = pred_text_to_go(test_pred)
+
+         # The original path contained a literal, unformatted '{}'; looping over the three
+         # ontologies mirrors evaluate_cases.py and is assumed to be the intent.
+         for cat in ['mf', 'bp', 'cc']:
+             test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_pretrain.pkl'.format(cat))
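(The F1 computed above is macro-averaged per GO term: recall averages tp/(tp+fn) over the gold terms, while precision sums tp/(tp+fp) over the gold terms but divides by the number of distinct predicted terms. A toy example with hypothetical counts:

    tp = {'GO:1': 2, 'GO:2': 0}
    fn = {'GO:1': 0, 'GO:2': 1}
    fp = {'GO:1': 1, 'GO:2': 0}
    labels = ['GO:1', 'GO:2']
    r = sum(tp[x] / (tp[x] + fn[x] + 1e-8) for x in labels) / len(labels)  # 0.5
    p_total = 2  # distinct predicted terms in this toy case
    p = sum(tp[x] / (tp[x] + fp[x] + 1e-8) for x in labels) / p_total      # ~0.333
    f1 = 2 * p * r / (p + r)                                               # 0.4
)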
data/evaluate_data/evaluate_with_ancestors.py ADDED
@@ -0,0 +1,339 @@
+ import pandas as pd
+ import re
+ import random
+ import Levenshtein
+ import numpy as np
+ import difflib
+ # from torchmetrics.text import BLEUScore
+ import time
+ from multiprocessing import Pool, Queue, Process
+ import matplotlib.pyplot as plt
+ from data.evaluate_data.utils import Ontology
+ # bleu = BLEUScore(n_gram=1)
+
+ def fuzzy_match(texts):
+     text_dict = {}
+     for context in texts:
+         if context not in choices:
+             # txt_dict[txt] = process.extractOne(txt, choices)[0]
+             text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
+     return text_dict
+
+
+ def get_sim(text, label):
+     # For each gold description, keep the best Levenshtein ratio against any predicted text
+     all_s = []
+     for x in label:
+         s = 0
+         for y in text:
+             temp = Levenshtein.ratio(x, y)
+             if temp > s:
+                 s = temp
+         all_s.append(s)
+     all_s = [round(i, 3) for i in all_s]
+
+     # bs = [bleu(x, [label]) for x in text]
+     return all_s
+
+
+ def txt_map(x, txt_dict):
+     if type(x) == str:
+         x = eval(x)
+     x_ = []
+     for i in x:
+         if i == '':
+             continue
+         if i in txt_dict:
+             x_.append(txt_dict[i])
+         else:
+             x_.append(i)
+     return x_
+
+
+ def go_map(t):
+     if t in GO_dict:
+         return GO_dict[t]
+     else:
+         print(t)
+
+
+ def get_term(df):
+     from collections import Counter
+     cnt = Counter()
+     for i, row in enumerate(df.itertuples()):
+         for term in row.prop_annotations:
+             cnt[term] += 1
+     terms = list(cnt.keys())
+     # remove the three root terms
+     for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
+         if top_term in terms:
+             terms.remove(top_term)
+     terms_df = pd.DataFrame({'gos': terms})
+     terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl')
+
+
+ if __name__ == "__main__":
+     cat = 'mf'
+
+     go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
+     go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
+     go_des.columns = ['GO', 'function']
+     go_des = go_des[go_des['function'].notnull()]
+     go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
+     go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
+     GO_dict = dict(zip(go_des['function'], go_des['GO']))
+
+
+     data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')
+
+     data['label'] = data['label'].apply(lambda x: x.lower())
+     data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+     data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
+     data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+
+     train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
+     train = train.drop_duplicates()
+     train['function'] = train['function'].apply(lambda x: x.lower().strip())
+     train_dict = dict(zip(train['function'], train['GO_label']))
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
+     test = test.drop_duplicates()
+     test['function'] = test['function'].apply(lambda x: x.lower().strip())
+     test_dict = dict(zip(test['function'], test['GO_label']))
+     GO_dict.update(train_dict)
+     GO_dict.update(test_dict)
+
+     choices = []
+     for x in data['label_list'].tolist() + train['function'].tolist():
+         # label_list entries are lists while train['function'] entries are plain strings;
+         # extending with a bare string would add single characters, so wrap it (assumed intent)
+         if isinstance(x, str):
+             x = [x]
+         choices.extend(x)
+     choices = list(set(choices))
+
+
+     ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+     print("Finding the GO label most similar to each predicted text ......")
+     t0 = time.time()
+     txt_dict = {}
+
+     all_txt = []
+     for txt in data['pred_list']:
+         if type(txt) == str:
+             all_txt.extend(eval(txt))
+         else:
+             all_txt.extend(txt)
+     all_txt = list(set(all_txt))
+
+     n = len(all_txt)
+     thread = 40
+     size = int(n/thread)
+     inds = list(range(0, n, size))
+     inds.append(n)
+     all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
+
+     with Pool(processes=thread) as pool:
+         result = pool.map(fuzzy_match, all_txt_sep)
+         pool.close()
+         pool.join()
+     for d in result:
+         txt_dict.update(d)
+
+     # for txt in all_txt[:10]:
+     #     fuzzy_match(txt)
+
+     data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+     data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
+     print("fuzzy matching time: {}".format(time.time() - t0))
+
+
+     # sims = []
+     # for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
+     #     a = get_sim(text, label)
+     #     sims.append(a)
+     #
+     # data['sim'] = sims
+     # data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
+     # print("similarity: {}".format(data['avg_sim'].mean()))
+
+
+     print("calculating f1 score ......")
+     data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
+     data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
+
+
+     labels = []
+     pred_labels = []
+     for l in data['label_list_go']:
+         if type(l) == str:
+             l = eval(l)
+         labels.extend(l)
+
+     label_count = {}
+     for x in labels:
+         if x not in label_count:
+             label_count[x] = 1
+         else:
+             label_count[x] += 1
+
+     labels = list(set(labels))
+     total = len(labels)
+     recalls = []
+     precisions = []
+     tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
+     for preds, label in zip(data['pred_list_go'], data['label_list_go']):
+         if type(label) == str:
+             label = eval(label)
+         if type(preds) == str:
+             preds = eval(preds)  # the original assigned this to an unused name `txts`
+         ll = len(label)
+         for t in label:
+             supgo = go.get_anchestors(t)
+             if supgo.intersection(set(preds)):
+                 tp_dict[t] += 1
+             else:
+                 fn_dict[t] += 1
+         for p in preds:
+             supgo = go.get_anchestors(p)
+             if not supgo.intersection(set(label)):
+                 if p in fp_dict:
+                     fp_dict[p] += 1
+                 else:
+                     fp_dict[p] = 1
+         pred_labels.extend(preds)
+     p_total = len(set(pred_labels))
+     recall, pr = 0., 0.
+     for x in labels:
+         recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
+         pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
+     r = recall / total
+     p = pr / p_total
+     f1 = 2 * p * r / (p + r)
+
+     print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
+     print("f1 score: {}".format(f1))
+
+     '''
+     cat_f1 = {}
+     for x in labels:
+         if tp_dict[x] + fn_dict[x] > 0:
+             re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
+             pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
+             cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)
+
+     plt.xlabel('f score')
+     plt.ylabel('count')
+     print(np.mean(list(cat_f1.values())))
+     plt.hist(list(cat_f1.values()), color='red', bins=30)
+     plt.show()
+
+     xs, ys = [], []
+     for x in labels:
+         xs.append(label_count[x])
+         ys.append(cat_f1[x])
+     df_count = pd.DataFrame({'xs': xs, 'ys': ys})
+     df_count['xs'].loc[df_count['xs'] > 10] = 11
+     df_count['xs'] = df_count['xs'].astype(str)
+     df_count1 = df_count.groupby('xs').mean().reset_index()
+     df_count2 = df_count.groupby('xs').count().reset_index()
+
+     plt.xlabel('label count')
+     plt.ylabel('f score mean')
+     df_count1['xs'] = df_count1['xs'].astype(int)
+     plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
+     plt.show()
+
+     plt.xlabel('label count')
+     plt.ylabel('protein num')
+     df_count2['xs'] = df_count2['xs'].astype(int)
+     plt.bar(df_count2['xs'], df_count2['ys'], color='red')
+     plt.show()
+     '''
+
+
+     # Prepare the data: GO labels predicted by BLIP-2 are the features;
+     # the labels with their ancestors added are the prediction target Y.
+     print("Preparing the data with ancestors added ......")
+     train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
+     train = train.groupby('name').agg({'GO_label': list}).reset_index()
+     test = test.groupby('name').agg({'GO_label': list}).reset_index()
+
+     def prop(df):
+         prop_annotations = []
+         for i, row in df.iterrows():
+             # Propagate annotations
+             annot_set = set()
+             annots = row['GO_label']
+             for go_id in annots:
+                 annot_set |= go.get_anchestors(go_id)
+             annots = list(annot_set)
+             prop_annotations.append(annots)
+         df['prop_annotations'] = prop_annotations
+         return df
+
+     train = prop(train)
+     test = prop(test)
+
+     train_test = pd.concat([train, test])
+     get_term(train_test)
+     del train_test
+
+     def pred_text_to_go(df):
+         df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+         df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+         ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+         t0 = time.time()
+         txt_dict = {}
+
+         all_txt = []
+         for txt in df['pred_list']:
+             if type(txt) == str:
+                 all_txt.extend(eval(txt))
+             else:
+                 all_txt.extend(txt)
+
+         all_txt = list(set(all_txt))
+         if '' in all_txt:
+             all_txt.remove('')
+
+         n = len(all_txt)
+         thread = 40
+         size = int(n / thread)
+         inds = list(range(0, n, size))
+         inds.append(n)
+         all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
+
+         with Pool(processes=thread) as pool:
+             result = pool.map(fuzzy_match, all_txt_sep)
+             pool.close()
+             pool.join()
+         for d in result:
+             txt_dict.update(d)
+
+         # for txt in all_txt[:10]:
+         #     fuzzy_match(txt)
+
+         df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+         df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
+         print("fuzzy matching time: {}".format(time.time() - t0))
+
+         df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
+         return df
+
+
+     train_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_train{}.csv'.format(cat), sep='|')
+     test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')
+
+     train_pred = pred_text_to_go(train_pred)
+     test_pred = pred_text_to_go(test_pred)
+
+     train_data = pd.merge(train[['name', 'prop_annotations']],
+                           train_pred[['name', 'pred_list_go']],
+                           on='name', how='inner')
+     train_data = train_data.drop_duplicates('name')
+     train_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/train_data.pkl'.format(cat))
+
+     test_data = pd.merge(test[['name', 'prop_annotations']],
+                          test_pred[['name', 'pred_list_go']],
+                          on='name', how='inner')
+     test_data = test_data.drop_duplicates('name')
+     test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_data.pkl'.format(cat))
+     test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/valid_data.pkl'.format(cat))
+
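(Unlike evaluate_pretrain.py, this script relaxes the match: a gold term counts as a true positive when any of its ancestors, as returned by go.get_anchestors, appears among the predictions. A hypothetical three-term chain illustrates the difference:

    # GO:A is_a GO:B is_a GO:C (hypothetical ids); the ancestor set is assumed
    # to include the term itself, as in the DeepGO-style Ontology helper.
    ancestors = {'GO:A': {'GO:A', 'GO:B', 'GO:C'}, 'GO:B': {'GO:B', 'GO:C'}}
    label, preds = ['GO:A'], {'GO:B'}
    exact_tp = sum(1 for t in label if t in preds)               # 0
    ancestor_tp = sum(1 for t in label if ancestors[t] & preds)  # 1
)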
data/evaluate_data/evaluate_with_ancestors_exp.py ADDED
@@ -0,0 +1,339 @@
+ import pandas as pd
+ import re
+ import random
+ import Levenshtein
+ import numpy as np
+ import difflib
+ # from torchmetrics.text import BLEUScore
+ import time
+ from multiprocessing import Pool, Queue, Process
+ import matplotlib.pyplot as plt
+ from data.evaluate_data.utils import Ontology
+ # bleu = BLEUScore(n_gram=1)
+
+ def fuzzy_match(texts):
+     text_dict = {}
+     for context in texts:
+         if context not in choices:
+             # txt_dict[txt] = process.extractOne(txt, choices)[0]
+             text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
+     return text_dict
+
+
+ def get_sim(text, label):
+     # For each gold description, keep the best Levenshtein ratio against any predicted text
+     all_s = []
+     for x in label:
+         s = 0
+         for y in text:
+             temp = Levenshtein.ratio(x, y)
+             if temp > s:
+                 s = temp
+         all_s.append(s)
+     all_s = [round(i, 3) for i in all_s]
+
+     # bs = [bleu(x, [label]) for x in text]
+     return all_s
+
+
+ def txt_map(x, txt_dict):
+     if type(x) == str:
+         x = eval(x)
+     x_ = []
+     for i in x:
+         if i == '':
+             continue
+         if i in txt_dict:
+             x_.append(txt_dict[i])
+         else:
+             x_.append(i)
+     return x_
+
+
+ def go_map(t):
+     if t in GO_dict:
+         return GO_dict[t]
+     else:
+         print(t)
+
+
+ def get_term(df):
+     from collections import Counter
+     cnt = Counter()
+     for i, row in enumerate(df.itertuples()):
+         for term in row.prop_annotations:
+             cnt[term] += 1
+     terms = list(cnt.keys())
+     # remove the three root terms
+     for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
+         if top_term in terms:
+             terms.remove(top_term)
+     terms_df = pd.DataFrame({'gos': terms})
+     terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl')
+
+
+ if __name__ == "__main__":
+     cat = 'mf'
+
+     go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
+     go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
+     go_des.columns = ['GO', 'function']
+     go_des = go_des[go_des['function'].notnull()]
+     go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
+     go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
+     GO_dict = dict(zip(go_des['function'], go_des['GO']))
+
+
+     data = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{}.csv'.format(cat), sep='|')
+
+     data['label'] = data['label'].apply(lambda x: x.lower())
+     data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+     data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
+     data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+
+     train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/train_{}.csv'.format(cat), sep='|')
+     train = train.drop_duplicates()
+     train['function'] = train['function'].apply(lambda x: x.lower().strip())
+     train_dict = dict(zip(train['function'], train['GO_label']))
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/test_{}.csv'.format(cat), sep='|')
+     test = test.drop_duplicates()
+     test['function'] = test['function'].apply(lambda x: x.lower().strip())
+     test_dict = dict(zip(test['function'], test['GO_label']))
+     GO_dict.update(train_dict)
+     GO_dict.update(test_dict)
+
+     choices = []
+     for x in data['label_list'].tolist() + train['function'].tolist():
+         # label_list entries are lists while train['function'] entries are plain strings;
+         # extending with a bare string would add single characters, so wrap it (assumed intent)
+         if isinstance(x, str):
+             x = [x]
+         choices.extend(x)
+     choices = list(set(choices))
+
+
+     ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+     print("Finding the GO label most similar to each predicted text ......")
+     t0 = time.time()
+     txt_dict = {}
+
+     all_txt = []
+     for txt in data['pred_list']:
+         if type(txt) == str:
+             all_txt.extend(eval(txt))
+         else:
+             all_txt.extend(txt)
+     all_txt = list(set(all_txt))
+
+     n = len(all_txt)
+     thread = 40
+     size = int(n/thread)
+     inds = list(range(0, n, size))
+     inds.append(n)
+     all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
+
+     with Pool(processes=thread) as pool:
+         result = pool.map(fuzzy_match, all_txt_sep)
+         pool.close()
+         pool.join()
+     for d in result:
+         txt_dict.update(d)
+
+     # for txt in all_txt[:10]:
+     #     fuzzy_match(txt)
+
+     data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+     data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
+     print("fuzzy matching time: {}".format(time.time() - t0))
+
+
+     # sims = []
+     # for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
+     #     a = get_sim(text, label)
+     #     sims.append(a)
+     #
+     # data['sim'] = sims
+     # data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
+     # print("similarity: {}".format(data['avg_sim'].mean()))
+
+
+     print("calculating f1 score ......")
+     data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
+     data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
+
+
+     labels = []
+     pred_labels = []
+     for l in data['label_list_go']:
+         if type(l) == str:
+             l = eval(l)
+         labels.extend(l)
+
+     label_count = {}
+     for x in labels:
+         if x not in label_count:
+             label_count[x] = 1
+         else:
+             label_count[x] += 1
+
+     labels = list(set(labels))
+     total = len(labels)
+     recalls = []
+     precisions = []
+     tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
+     for preds, label in zip(data['pred_list_go'], data['label_list_go']):
+         if type(label) == str:
+             label = eval(label)
+         if type(preds) == str:
+             preds = eval(preds)  # the original assigned this to an unused name `txts`
+         ll = len(label)
+         for t in label:
+             supgo = go.get_anchestors(t)
+             if supgo.intersection(set(preds)):
+                 tp_dict[t] += 1
+             else:
+                 fn_dict[t] += 1
+         for p in preds:
+             supgo = go.get_anchestors(p)
+             if not supgo.intersection(set(label)):
+                 if p in fp_dict:
+                     fp_dict[p] += 1
+                 else:
+                     fp_dict[p] = 1
+         pred_labels.extend(preds)
+     p_total = len(set(pred_labels))
+     recall, pr = 0., 0.
+     for x in labels:
+         recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
+         pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
+     r = recall / total
+     p = pr / p_total
+     f1 = 2 * p * r / (p + r)
+
+     print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
+     print("f1 score: {}".format(f1))
+
+     '''
+     cat_f1 = {}
+     for x in labels:
+         if tp_dict[x] + fn_dict[x] > 0:
+             re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
+             pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
+             cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)
+
+     plt.xlabel('f score')
+     plt.ylabel('count')
+     print(np.mean(list(cat_f1.values())))
+     plt.hist(list(cat_f1.values()), color='red', bins=30)
+     plt.show()
+
+     xs, ys = [], []
+     for x in labels:
+         xs.append(label_count[x])
+         ys.append(cat_f1[x])
+     df_count = pd.DataFrame({'xs': xs, 'ys': ys})
+     df_count['xs'].loc[df_count['xs'] > 10] = 11
+     df_count['xs'] = df_count['xs'].astype(str)
+     df_count1 = df_count.groupby('xs').mean().reset_index()
+     df_count2 = df_count.groupby('xs').count().reset_index()
+
+     plt.xlabel('label count')
+     plt.ylabel('f score mean')
+     df_count1['xs'] = df_count1['xs'].astype(int)
+     plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
+     plt.show()
+
+     plt.xlabel('label count')
+     plt.ylabel('protein num')
+     df_count2['xs'] = df_count2['xs'].astype(int)
+     plt.bar(df_count2['xs'], df_count2['ys'], color='red')
+     plt.show()
+     '''
+
+
+     # Prepare the data: GO labels predicted by BLIP-2 are the features;
+     # the labels with their ancestors added are the prediction target Y.
+     print("Preparing the data with ancestors added ......")
+     train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/train_{}.csv'.format(cat), sep='|')
+     test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/test_{}.csv'.format(cat), sep='|')
+     train = train.groupby('name').agg({'GO_label': list}).reset_index()
+     test = test.groupby('name').agg({'GO_label': list}).reset_index()
+
+     def prop(df):
+         prop_annotations = []
+         for i, row in df.iterrows():
+             # Propagate annotations
+             annot_set = set()
+             annots = row['GO_label']
+             for go_id in annots:
+                 annot_set |= go.get_anchestors(go_id)
+             annots = list(annot_set)
+             prop_annotations.append(annots)
+         df['prop_annotations'] = prop_annotations
+         return df
+
+     train = prop(train)
+     test = prop(test)
+
+     train_test = pd.concat([train, test])
+     get_term(train_test)
+     del train_test
+
+     def pred_text_to_go(df):
+         df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
+
+         df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
+         ### If a predicted text is not among the GO label terms, treat it as the most similar GO label
+         t0 = time.time()
+         txt_dict = {}
+
+         all_txt = []
+         for txt in df['pred_list']:
+             if type(txt) == str:
+                 all_txt.extend(eval(txt))
+             else:
+                 all_txt.extend(txt)
+
+         all_txt = list(set(all_txt))
+         if '' in all_txt:
+             all_txt.remove('')
+
+         n = len(all_txt)
+         thread = 40
+         size = int(n / thread)
+         inds = list(range(0, n, size))
+         inds.append(n)
+         all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
+
+         with Pool(processes=thread) as pool:
+             result = pool.map(fuzzy_match, all_txt_sep)
+             pool.close()
+             pool.join()
+         for d in result:
+             txt_dict.update(d)
+
+         # for txt in all_txt[:10]:
+         #     fuzzy_match(txt)
+
+         df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
+         df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
+         print("fuzzy matching time: {}".format(time.time() - t0))
+
+         df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
+         return df
+
+
+     train_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_train{}.csv'.format(cat), sep='|')
+     test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{}.csv'.format(cat), sep='|')
+
+     train_pred = pred_text_to_go(train_pred)
+     test_pred = pred_text_to_go(test_pred)
+
+     train_data = pd.merge(train[['name', 'prop_annotations']],
+                           train_pred[['name', 'pred_list_go']],
+                           on='name', how='inner')
+     train_data = train_data.drop_duplicates('name')
+     train_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/train_data.pkl'.format(cat))
+
+     test_data = pd.merge(test[['name', 'prop_annotations']],
+                          test_pred[['name', 'pred_list_go']],
+                          on='name', how='inner')
+     test_data = test_data.drop_duplicates('name')
+     test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_data.pkl'.format(cat))
+     test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/valid_data.pkl'.format(cat))
+
data/evaluate_data/pretrain_output_to_deepgozero.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pandas as pd
3
+ import time
4
+ from multiprocessing import Pool
5
+ import difflib
6
+ from utils import Ontology
7
+ import os
8
+
9
+
10
+ def filter(x_list):
11
+ new_go = []
12
+ # x_list = [i.strip() for i in x.split(';')]
13
+ for i in x_list:
14
+ if i in filter_go:
15
+ new_go.append(i)
16
+ return '; '.join(new_go)
17
+
18
+
19
+ def fuzzy_match(texts):
20
+ text_dict = {}
21
+ for context in texts:
22
+ if context in choices:
23
+ text_dict[context] = context
24
+ elif context not in choices:
25
+ # txt_dict[txt] = process.extractOne(txt, choices)[0]
26
+ sim_list = difflib.get_close_matches(context.lower(), choices, n=1, cutoff=0.9)
27
+ if len(sim_list) > 0:
28
+ text_dict[context] = sim_list[0]
29
+ else:
30
+ # text_dict[context] = ''
31
+ pass
32
+ return text_dict
33
+
34
+
35
+ def txt_map(x, txt_dict):
36
+ if type(x) == str:
37
+ x = eval(x)
38
+ x_ = []
39
+ for i in x:
40
+ if i == '':
41
+ continue
42
+ if i in txt_dict:
43
+ x_.append(txt_dict[i])
44
+ else:
45
+ # x_.append(i)
46
+ pass
47
+ return x_
48
+
49
+
50
+ def go_map_prob(x, GO_dict):
51
+ res = []
52
+ for t in x:
53
+ if t[0] in GO_dict:
54
+ res.append((GO_dict[t[0]], t[1]))
55
+ else:
56
+ pass
57
+ # print("{} not in GO_dict".format(t[0]))
58
+ return res
59
+
60
+
61
+ def txt_map_prob(x, txt_dict):
62
+ if type(x) == str:
63
+ x = eval(x)
64
+ x_ = []
65
+ temp = set()
66
+ for i in x:
67
+ if i[0] == '':
68
+ continue
69
+ elif i[0] in txt_dict and txt_dict[i[0]] not in temp:
70
+ x_.append((txt_dict[i[0]].lower(), i[1]))
71
+ temp.add(txt_dict[i[0]])
72
+ # elif i[0] not in txt_dict:
73
+ # x_.append((i[0].lower(), i[1]))
74
+ # temp.add(i[0])
75
+ else:
76
+ continue
77
+ return x_
78
+
79
+
80
+ def go_map(x, GO_dict):
81
+ res = []
82
+ for t in x:
83
+ if t in GO_dict:
84
+ res.append(GO_dict[t])
85
+ else:
86
+ # pass
87
+ print("{} not in GO_dict".format(t))
88
+ return res
89
+
90
+
91
+ def prop(df):
92
+ prop_annotations = []
93
+ for i, row in df.iterrows():
94
+ # Propagate annotations
95
+ annot_set = set()
96
+ annots = row['GO_label']
97
+ for go_id in annots:
98
+ annot_set |= godb.get_anchestors(go_id)
99
+ annots = list(annot_set)
100
+ prop_annotations.append(annots)
101
+ df['prop_annotations'] = prop_annotations
102
+ return df
103
+
104
+
105
+ def pred_text_to_go(df, with_prob=False):
106
+ # df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
107
+ if with_prob:
108
+ df['pred_list_prob'] = df['pred'].apply(lambda x: [eval(i.strip()) for i in x.split(';')])
109
+ df['pred_list'] = df['pred_list_prob'].apply(lambda x: [i[0] for i in x])
110
+ else:
111
+ df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))
112
+ ### 预测的文本如果不在GO标签词中,则算作最相似的GO标签
113
+ t0 = time.time()
114
+ txt_dict = {}
115
+ all_txt = []
116
+ for txt in df['pred_list']:
117
+ if type(txt) == str:
118
+ all_txt.extend(eval(txt))
119
+ else:
120
+ all_txt.extend(txt)
121
+ all_txt = list(set(all_txt))
122
+ if '' in all_txt:
123
+ all_txt.remove('')
124
+ n = len(all_txt)
125
+ thread = 10
126
+ size = int(n / thread)
127
+ inds = list(range(0, n, size))
128
+ inds.append(n)
129
+ all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
130
+ with Pool(processes=thread) as pool:
131
+ result = pool.map(fuzzy_match, all_txt_sep)
132
+ pool.close()
133
+ pool.join()
134
+ for d in result:
135
+ txt_dict.update(d)
136
+ # print(txt_dict)
137
+ # for txt in all_txt[:10]:
138
+ # fuzzy_match(txt)
139
+ if with_prob:
140
+ df['pred_list_prob'] = df['pred_list_prob'].apply(lambda x: txt_map_prob(x, txt_dict))
141
+ print("fuzzy matching time: {}".format(time.time() - t0))
142
+ df['pred_list_go_prob'] = df['pred_list_prob'].apply(lambda x: go_map_prob(x, GO_dict))
143
+ n0 = df.shape[0]
144
+ df['len'] = df['pred_list_go_prob'].apply(lambda x: len(x))
145
+ df = df[df['len'] > 0]
146
+ df = df.drop('len', axis=1)
147
+ df = df.dropna()
148
+ print('{}条数据,不为空的预测有{}条'.format(n0, df.shape[0]))
149
+ else:
150
+ df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
151
+ df['pred_list'] = df['pred_list'].apply(lambda x: [i.lower() for i in list(set(x))])
152
+ print("fuzzy matching time: {}".format(time.time() - t0))
153
+ df['pred_list_go'] = df['pred_list'].apply(lambda x: go_map(x, GO_dict))
154
+
155
+ n0 = df.shape[0]
156
+ df['len'] = df['pred_list_go'].apply(lambda x: len(x))
157
+ df = df[df['len'] > 0]
158
+ df = df.drop('len', axis=1)
159
+ df = df.dropna()
160
+ print('{}条数据,不为空的预测有{}条'.format(n0, df.shape[0]))
161
+ return df
162
+
163
+
164
+ def cal_f1(df):
165
+ df['label_list_go'] = df['label'].apply(lambda x: [i.strip() for i in x.split(';')])
166
+ df['pred_list_go'] = df['pred_list'].apply(lambda x: [i.strip() for i in x.split(';')])
167
+
168
+ labels = []
169
+ pred_labels = []
170
+ for l in df['label_list_go']:
171
+ labels.extend(l)
172
+
173
+ label_count = {}
174
+ for x in labels:
175
+ if x not in label_count:
176
+ label_count[x] = 1
177
+ else:
178
+ label_count[x] += 1
179
+
180
+ labels = list(set(labels))
181
+ total = len(labels)
182
+ tp_dict, fp_dict, fn_dict = dict(zip(labels, [0] * len(labels))), dict(zip(labels, [0] * len(labels))), dict(
183
+ zip(labels, [0] * len(labels)))
184
+ for preds, label in zip(df['pred_list_go'], df['label_list_go']):
185
+ for t in label:
186
+ # supgo = godb.get_anchestors(t)
187
+ # if supgo.intersection(set(preds)):
188
+ if t in preds:
189
+ tp_dict[t] += 1
190
+ else:
191
+ fn_dict[t] += 1
192
+ for p in preds:
193
+ # supgo = godb.get_anchestors(p)
194
+ # if not supgo.intersection(set(label)):
195
+ if p not in label:
196
+ if p in fp_dict:
197
+ fp_dict[p] += 1
198
+ else:
199
+ fp_dict[p] = 1
200
+ pred_labels.extend(preds)
201
+ p_total = len(set(pred_labels))
202
+ recall, pr = 0., 0.
203
+ for x in labels:
204
+ recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
205
+ pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
206
+ r = recall / total
207
+ p = pr / p_total
208
+ f1 = 2 * p * r / (p + r)
209
+
210
+ print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
211
+ print("recall:{}; percision:{}; f1 score: {}".format(r, p, f1))
212
+
213
+
214
+ def cat_go(x):
215
+ try:
216
+ cat = godb.get_namespace(x)
217
+ except:
218
+ print("{} not found".format(x))
219
+ return
220
+ if cat == NAMESPACES['mf']:
221
+ return 'mf'
222
+ elif cat == NAMESPACES['bp']:
223
+ return 'bp'
224
+ elif cat == NAMESPACES['cc']:
225
+ return 'cc'
226
+ return
227
+
228
+
229
+ def remove_root(x):
230
+ if 'molecular_function' in x:
231
+ x.remove('molecular_function')
232
+ if 'biological_process' in x:
233
+ x.remove('biological_process')
234
+ if 'cellular_component' in x:
235
+ x.remove('cellular_component')
236
+ return x
237
+
238
+ if __name__ == "__main__":
239
+ NAMESPACES = {
240
+ 'cc': 'cellular_component',
241
+ 'mf': 'molecular_function',
242
+ 'bp': 'biological_process'
243
+ }
244
+ #if not os.path.exists('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl'):
245
+ if 1==1:
246
+ data = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/swissprot_domain_and_train_exp_prompt_new.csv', sep='|')
247
+ print('数据规模:{}'.format(data.shape[0]))
248
+ # data['function'] = data['function'].apply(lambda x: re.sub('[FPC]:', '', x))
249
+ # data.to_csv('swissprot_domain_and_train_exp.csv', sep='|', index=False)
250
+
251
+ godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
252
+ go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions1.4.txt', sep='|', header=None)
253
+ go_des.columns = ['id', 'text']
254
+ go_des = go_des.dropna()
255
+ go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
256
+ go_des['ont'] = go_des['id'].apply(lambda x: cat_go(x))
257
+ go_des = go_des.dropna()
258
+ go_obo_set = set(go_des['id'].tolist())
259
+ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
260
+
261
+ data['GO_label'] = data['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
262
+ data = prop(data)
263
+
264
+ # 加入父节点,得到完整的terms,映射表等等
265
+ go_dict = {}
266
+ for x_list in data['prop_annotations']:
267
+ for goid in x_list:
268
+ if goid in go_dict:
269
+ go_dict[goid] += 1
270
+ else:
271
+ go_dict[goid] = 1
272
+ df_stat = pd.DataFrame({'id': list(go_dict.keys()), 'count': list(go_dict.values())})
273
+ data_gos = set(df_stat['id'].tolist())
274
+ go_des = go_des[go_des['id'].isin(data_gos)]
275
+ filter_go = data_gos.intersection(go_obo_set)
276
+ print(f"包括父节点的GO有{len(data_gos)}个,其中在go1.4.obo中出现的GO有{len(filter_go)}个")
277
+
278
+ go_des.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
279
+ id2text_dict = dict(zip(go_des['id'], go_des['text']))
280
+ GO_dict = dict(zip(go_des['text'], go_des['id']))
281
+
282
+ choices_mf = list(set(go_des[go_des['ont'] == 'mf']['text']))
283
+ choices_bp = list(set(go_des[go_des['ont'] == 'bp']['text']))
284
+ choices_cc = list(set(go_des[go_des['ont'] == 'cc']['text']))
285
+
286
+ choices_mf = {x.lower(): x for x in choices_mf}
287
+ choices_bp = {x.lower(): x for x in choices_bp}
288
+ choices_cc = {x.lower(): x for x in choices_cc}
289
+
290
+ data['GO_label'] = data['GO_label'].apply(lambda x: filter(x))
291
+ data = data[data['GO_label'] != '']
292
+ data['function'] = data['GO_label'].apply(lambda x: [id2text_dict[i.strip()] for i in x.split(';')])
293
+ data['function'] = data['function'].apply(lambda x: '; '.join(x))
294
+
295
+ terms = pd.DataFrame({'gos': list(filter_go)})
296
+ terms.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
297
+ terms.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/terms.pkl')
298
+
299
+ terms_mf = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'mf']['id']))})
300
+ terms_mf.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/terms.pkl')
301
+ terms_mf.to_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
302
+ terms_bp = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'bp']['id']))})
303
+ terms_bp.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/bp/terms.pkl')
304
+ terms_bp.to_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
305
+ terms_cc = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'cc']['id']))})
306
+ terms_cc.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/cc/terms.pkl')
307
+ terms_cc.to_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')
308
+ else:
309
+ godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
310
+ terms = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
311
+ filter_go = set(terms['gos'].tolist())
312
+
313
+ terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
314
+ terms_bp = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
315
+ terms_cc = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')
316
+
317
+ choices_mf = {x.lower(): x for x in terms_mf['gos'].tolist()}
318
+ choices_bp = {x.lower(): x for x in terms_bp['gos'].tolist()}
319
+ choices_cc = {x.lower(): x for x in terms_cc['gos'].tolist()}
320
+
321
+ go_des = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
322
+ id2text_dict = dict(zip(go_des['id'], go_des['text']))
323
+ GO_dict = dict(zip(go_des['text'], go_des['id']))
324
+
325
+ # 对于预测文件,进行GO筛选,并用相似度算法匹配到filter_go;对于train test val 文件,进行GO筛选、加入祖先、加入interPro特征
326
+ # 加入interpro特征
327
+ df_interpro = pd.read_csv('/cluster/home/wenkai/LAVIS/data/uniprot_sprot_blip2_func_data.txt', sep='|',
328
+ nrows=546389,
329
+ header=None)
330
+ df_interpro.columns = ['name', 'seq', 'go', 'text', 'evi', 'ipr']
331
+ df_interpro = df_interpro[df_interpro['ipr'].notnull()]
332
+ df_interpro['ipr'] = df_interpro['ipr'].apply(lambda x: [i.strip() for i in x.split(';')])
333
+
334
+ iprs = []
335
+ for x in df_interpro['ipr'].tolist():
336
+ if len(x) > 0:
337
+ iprs.extend(x)
338
+ iprs = list(set(iprs))
339
+ print("ipr个数:{}".format(len(iprs)))
340
+ df_ipr = pd.DataFrame({'interpros': iprs})
341
+ df_ipr.to_pickle('/cluster/home/wenkai/LAVIS/data/interpros.pkl')
342
+ df_ipr.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/interpros.pkl')
343
+
344
+
+     '''
+     # test cases
+     df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test_2000.csv', sep='|')
+     df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
+     #df_real[col] = df_real[col].apply(lambda x: filter(x))
+     df_real = df_real[df_real[col] != '']
+     print(df_real.shape)
+     #df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
+     #df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
+     df_real = prop(df_real)
+     #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
+     #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
+     #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
+     for ont in ['mf', 'bp', 'cc']:
+         file_name = 'output_{}_test_2000'.format(ont)
+         if ont == 'mf':
+             choices = choices_mf
+         elif ont == 'bp':
+             choices = choices_bp
+         elif ont == 'cc':
+             choices = choices_cc
+         print("Standardizing the predicted text for {}...".format(file_name))
+         df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file_name), sep='|', header=None, on_bad_lines='skip')
+         df_pred.columns = ['name', 'pred', 'label']
+         n0 = df_pred.shape[0]
+         df_pred = pred_text_to_go(df_pred, with_prob=True)
+         print("{}: {} rows could not be matched to a sufficiently similar GO description".format(file_name, n0 - df_pred.shape[0]))
+         #df_pred['pred_list'] = df_pred['pred_list'].apply(lambda x: '; '.join(x))
+         #cal_f1(df_pred)
+         df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file_name), sep='|', index=False)
+ 
+         df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
+         df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
+         ipr_and_pred = []
+         for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
+             try:
+                 ipr_and_pred.append(x + y)
+             except:
+                 ipr_and_pred.append(y)
+         df_pred['ipr_and_pred'] = ipr_and_pred
+         print(df_real.isnull().sum())
+         df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
+         #df_pred = df_pred.dropna()
+         print(df_pred.shape)
+         df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+             '/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{}/test_2000_data.pkl'.format(ont))
+     '''
+ 
+     '''
+     df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/nextprot_mf.csv', sep='|')
+     df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+     df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
+     df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
+     df_real = prop(df_real)
+     df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
+     df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
+     df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
+ 
+     file = 'output_nextprot'
+     choices = choices_mf
+     df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file), sep='|', header=None, on_bad_lines='skip')
+     df_pred.columns = ['name', 'pred', 'label']
+     df_pred = pred_text_to_go(df_pred, with_prob=True)
+     df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file), sep='|', index=False)
+ 
+     df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
+     df_pred['ipr'] = [[] for _ in range(df_pred.shape[0])]
+     df_pred['ipr_and_pred'] = df_pred['pred_list_go_prob']
+     df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+         '/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/nextprot_data.pkl')
+     '''
+ 
+     # '''
+     cat_id = {'mf': '445772', 'bp': '496359', 'cc': '505955'}
+     col = 'GO_label'
+     for ont in ['mf', 'bp', 'cc']:
+     # for ont in ['mf']:
+         if ont == 'mf':
+             choices = choices_mf
+         elif ont == 'bp':
+             choices = choices_bp
+         elif ont == 'cc':
+             choices = choices_cc
+         for split in ['train', 'val', 'test']:
+         # for split in ['test']:
+             df_real = pd.read_csv(f'/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/{split}_exp_{ont}_new.csv',
+                                   sep='|')
+             df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
+             df_real[col] = df_real[col].apply(lambda x: filter(x))
+             df_real = df_real[df_real[col] != '']
+             print(df_real.shape)
+             df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+             df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
+             df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
+             df_real = prop(df_real)
+             df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
+             df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
+             df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
+ 
+             # Map the predicted text back to GO ids
+             df_pred = pd.read_csv(
+                 f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_exp_{cat_id[ont]}.txt', sep='|',
+                 header=None, on_bad_lines='skip')
+             df_pred.columns = ['name', 'pred', 'label']
+             n0 = df_pred.shape[0]
+             df_pred = pred_text_to_go(df_pred, with_prob=True)
+             print("{}: {} rows could not be matched to a sufficiently similar GO description".format(ont, n0 - df_pred.shape[0]))
+             df_pred[['name', 'pred_list_prob', 'label']].to_csv(
+                 f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_{cat_id[ont]}_standard.csv', sep='|',
+                 index=False)
+ 
+             df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
+             df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
+             ipr_and_pred = []
+             for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
+                 try:
+                     ipr_and_pred.append(x + y)
+                 except:
+                     ipr_and_pred.append(y)
+             df_pred['ipr_and_pred'] = ipr_and_pred
+ 
+             df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
+             df_pred = df_pred.dropna()
+             df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+                 f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/{split}_data_{cat_id[ont]}.pkl')
+             df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+                 f'/cluster/home/wenkai/deepgo2/data/{ont}/{split}_data_{cat_id[ont]}.pkl')
+             if split == 'val':
+                 df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+                     f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/valid_data_{cat_id[ont]}.pkl')
+                 df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
+                     f'/cluster/home/wenkai/deepgo2/data/{ont}/valid_data_{cat_id[ont]}.pkl')
+             print(f"{ont} {split} deepgozero propagation data completed")
+     # '''
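+     # Each pickle written above carries the columns
+     # [name, protein, ipr, pred_list_go_prob, ipr_and_pred, prop_annotations];
+     # judging by the output paths, this is presumably the input format consumed
+     # by the downstream deepgozero/deepgo2 training scripts outside this repo.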
data/evaluate_data/process_case.py ADDED
@@ -0,0 +1,50 @@
+ import pandas as pd
+ from utils import Ontology
+ 
+ 
+ def prop(df):
+     prop_annotations = []
+     for i, row in df.iterrows():
+         # Propagate annotations
+         annot_set = set()
+         annots = row['GO_label']
+         for go_id in annots:
+             annot_set |= godb.get_anchestors(go_id)
+         annots = list(annot_set)
+         prop_annotations.append(annots)
+     df['prop_annotations'] = prop_annotations
+     return df
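+ 
+ # prop() replaces each row's GO annotations with their full ancestor closure
+ # (Ontology.get_anchestors), the same propagation step applied to the
+ # train/val/test splits in evaluate_cases.py.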
+ 
+ godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
+ 
+ case_mf = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_mf.csv', sep='|')
+ 
+ # bp cases, including the capsaicin receptor
+ case_bp = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_bp.csv', sep='|')
+ case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+ case_bp = prop(case_bp)
+ case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: '; '.join(x))
+ case_bp['prop_annotations'] = case_bp['prop_annotations'].apply(lambda x: '; '.join(x))
+ case_bp[['name', 'protein', 'function', 'GO_label', 'id', 'prompt', 'prop_annotations']].to_pickle('/cluster/home/wenkai/deepgo2/data/bp/cases_data.pkl')
+ 
+ case_mf['GO_label'] = case_mf['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+ case_mf = prop(case_mf)
+ case_mf['GO_label'] = case_mf['GO_label'].apply(lambda x: '; '.join(x))
+ case_mf['prop_annotations'] = case_mf['prop_annotations'].apply(lambda x: '; '.join(x))
+ 
+ case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
+ case_bp = prop(case_bp)
+ case_mf[['name', 'protein', 'function', 'GO_label', 'id', 'prompt', 'prop_annotations']].to_pickle('/cluster/home/wenkai/deepgo2/data/mf/cases_data_445772.pkl')
data/evaluate_data/utils.py ADDED
@@ -0,0 +1,280 @@
+ from collections import deque, Counter
+ import warnings
+ import pandas as pd
+ import numpy as np
+ from xml.etree import ElementTree as ET
+ import math
+ 
+ BIOLOGICAL_PROCESS = 'GO:0008150'
+ MOLECULAR_FUNCTION = 'GO:0003674'
+ CELLULAR_COMPONENT = 'GO:0005575'
+ FUNC_DICT = {
+     'cc': CELLULAR_COMPONENT,
+     'mf': MOLECULAR_FUNCTION,
+     'bp': BIOLOGICAL_PROCESS}
+ 
+ NAMESPACES = {
+     'cc': 'cellular_component',
+     'mf': 'molecular_function',
+     'bp': 'biological_process'
+ }
+ 
+ EXP_CODES = set([
+     'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC',
+     'HTP', 'HDA', 'HMP', 'HGI', 'HEP'])
+ 
+ # CAFA4 Targets
+ CAFA_TARGETS = set([
+     '287', '3702', '4577', '6239', '7227', '7955', '9606', '9823', '10090',
+     '10116', '44689', '83333', '99287', '226900', '243273', '284812', '559292'])
+ 
+ 
+ def is_cafa_target(org):
+     return org in CAFA_TARGETS
+ 
+ 
+ def is_exp_code(code):
+     return code in EXP_CODES
+ 
+ 
+ def get_goplus_defs(filename='data/definitions.txt'):
+     plus_defs = {}
+     with open(filename) as f:
+         for line in f:
+             line = line.strip()
+             go_id, definition = line.split(': ')
+             go_id = go_id.replace('_', ':')
+             definition = definition.replace('_', ':')
+             plus_defs[go_id] = set(definition.split(' and '))
+     return plus_defs
+ 
+ 
+ class Ontology(object):
+ 
+     def __init__(self, filename='data/go.obo', with_rels=False):
+         self.ont = self.load(filename, with_rels)
+         self.ic = None
+         self.ic_norm = 0.0
+ 
+     def has_term(self, term_id):
+         return term_id in self.ont
+ 
+     def get_term(self, term_id):
+         if self.has_term(term_id):
+             return self.ont[term_id]
+         return None
+ 
+     def calculate_ic(self, annots):
+         cnt = Counter()
+         for x in annots:
+             cnt.update(x)
+         self.ic = {}
+         for go_id, n in cnt.items():
+             parents = self.get_parents(go_id)
+             if len(parents) == 0:
+                 min_n = n
+             else:
+                 min_n = min([cnt[x] for x in parents])
+ 
+             self.ic[go_id] = math.log(min_n / n, 2)
+             self.ic_norm = max(self.ic_norm, self.ic[go_id])
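+         # calculate_ic implements annotation-based information content:
+         # IC(t) = log2(min_parent_count / count(t)), so a term much rarer than its
+         # least-frequent parent scores high; ic_norm tracks the running maximum
+         # used by get_norm_ic for normalization.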
+ 
+     def get_ic(self, go_id):
+         if self.ic is None:
+             raise Exception('Not yet calculated')
+         if go_id not in self.ic:
+             return 0.0
+         return self.ic[go_id]
+ 
+     def get_norm_ic(self, go_id):
+         return self.get_ic(go_id) / self.ic_norm
+ 
+     def load(self, filename, with_rels):
+         ont = dict()
+         obj = None
+         with open(filename, 'r') as f:
+             for line in f:
+                 line = line.strip()
+                 if not line:
+                     continue
+                 if line == '[Term]':
+                     if obj is not None:
+                         ont[obj['id']] = obj
+                     obj = dict()
+                     obj['is_a'] = list()
+                     obj['part_of'] = list()
+                     obj['regulates'] = list()
+                     obj['alt_ids'] = list()
+                     obj['is_obsolete'] = False
+                     continue
+                 elif line == '[Typedef]':
+                     if obj is not None:
+                         ont[obj['id']] = obj
+                     obj = None
+                 else:
+                     if obj is None:
+                         continue
+                     l = line.split(": ")
+                     if l[0] == 'id':
+                         obj['id'] = l[1]
+                     elif l[0] == 'alt_id':
+                         obj['alt_ids'].append(l[1])
+                     elif l[0] == 'namespace':
+                         obj['namespace'] = l[1]
+                     elif l[0] == 'is_a':
+                         obj['is_a'].append(l[1].split(' ! ')[0])
+                     elif with_rels and l[0] == 'relationship':
+                         it = l[1].split()
+                         # add all types of relationships
+                         obj['is_a'].append(it[1])
+                     elif l[0] == 'name':
+                         obj['name'] = l[1]
+                     elif l[0] == 'is_obsolete' and l[1] == 'true':
+                         obj['is_obsolete'] = True
+         if obj is not None:
+             ont[obj['id']] = obj
+         for term_id in list(ont.keys()):
+             for t_id in ont[term_id]['alt_ids']:
+                 ont[t_id] = ont[term_id]
+             if ont[term_id]['is_obsolete']:
+                 del ont[term_id]
+         for term_id, val in ont.items():
+             if 'children' not in val:
+                 val['children'] = set()
+             for p_id in val['is_a']:
+                 if p_id in ont:
+                     if 'children' not in ont[p_id]:
+                         ont[p_id]['children'] = set()
+                     ont[p_id]['children'].add(term_id)
+ 
+         return ont
+ 
+     def get_anchestors(self, term_id):
+         if term_id not in self.ont:
+             return set()
+         term_set = set()
+         q = deque()
+         q.append(term_id)
+         while (len(q) > 0):
+             t_id = q.popleft()
+             if t_id not in term_set:
+                 term_set.add(t_id)
+                 for parent_id in self.ont[t_id]['is_a']:
+                     if parent_id in self.ont:
+                         q.append(parent_id)
+         return term_set
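+     # Note: get_anchestors returns the query term itself plus all of its
+     # ancestors, via a breadth-first walk over 'is_a' edges; when the ontology
+     # is loaded with with_rels=True, other relationship types (e.g. part_of)
+     # are folded into 'is_a' by load() above.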
+ 
+     def get_prop_terms(self, terms):
+         prop_terms = set()
+ 
+         for term_id in terms:
+             prop_terms |= self.get_anchestors(term_id)
+         return prop_terms
+ 
+     def get_parents(self, term_id):
+         if term_id not in self.ont:
+             return set()
+         term_set = set()
+         for parent_id in self.ont[term_id]['is_a']:
+             if parent_id in self.ont:
+                 term_set.add(parent_id)
+         return term_set
+ 
+     def get_namespace_terms(self, namespace):
+         terms = set()
+         for go_id, obj in self.ont.items():
+             if obj['namespace'] == namespace:
+                 terms.add(go_id)
+         return terms
+ 
+     def get_namespace(self, term_id):
+         return self.ont[term_id]['namespace']
+ 
+     def get_term_set(self, term_id):
+         if term_id not in self.ont:
+             return set()
+         term_set = set()
+         q = deque()
+         q.append(term_id)
+         while len(q) > 0:
+             t_id = q.popleft()
+             if t_id not in term_set:
+                 term_set.add(t_id)
+                 for ch_id in self.ont[t_id]['children']:
+                     q.append(ch_id)
+         return term_set
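+ 
+     # A minimal usage sketch (illustrative GO id; adjust the .obo path as needed):
+     #   go = Ontology('data/go1.4-basic.obo', with_rels=True)
+     #   go.get_parents('GO:0008137')      # direct parents
+     #   go.get_anchestors('GO:0008137')   # term + all ancestors
+     #   go.get_term_set('GO:0008137')     # term + all descendants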
+ 
+ 
+ def read_fasta(filename):
+     seqs = list()
+     info = list()
+     seq = ''
+     inf = ''
+     with open(filename, 'r') as f:
+         for line in f:
+             line = line.strip()
+             if line.startswith('>'):
+                 if seq != '':
+                     seqs.append(seq)
+                     info.append(inf)
+                     seq = ''
+                 inf = line[1:].split()[0]
+             else:
+                 seq += line
+     seqs.append(seq)
+     info.append(inf)
+     return info, seqs
+ 
+ 
+ class DataGenerator(object):
+ 
+     def __init__(self, batch_size, is_sparse=False):
+         self.batch_size = batch_size
+         self.is_sparse = is_sparse
+ 
+     def fit(self, inputs, targets=None):
+         self.start = 0
+         self.inputs = inputs
+         self.targets = targets
+         if isinstance(self.inputs, tuple) or isinstance(self.inputs, list):
+             self.size = self.inputs[0].shape[0]
+         else:
+             self.size = self.inputs.shape[0]
+         self.has_targets = targets is not None
+ 
+     def __next__(self):
+         return self.next()
+ 
+     def reset(self):
+         self.start = 0
+ 
+     def next(self):
+         if self.start < self.size:
+             batch_index = np.arange(
+                 self.start, min(self.size, self.start + self.batch_size))
+             if isinstance(self.inputs, tuple) or isinstance(self.inputs, list):
+                 res_inputs = []
+                 for inp in self.inputs:
+                     if self.is_sparse:
+                         res_inputs.append(
+                             inp[batch_index, :].toarray())
+                     else:
+                         res_inputs.append(inp[batch_index, :])
+             else:
+                 if self.is_sparse:
+                     res_inputs = self.inputs[batch_index, :].toarray()
+                 else:
+                     res_inputs = self.inputs[batch_index, :]
+             self.start += self.batch_size
+             if self.has_targets:
+                 if self.is_sparse:
+                     labels = self.targets[batch_index, :].toarray()
+                 else:
+                     labels = self.targets[batch_index, :]
+                 return (res_inputs, labels)
+             return res_inputs
+         else:
+             self.reset()
+             return self.next()
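+ 
+ # Note: DataGenerator.next() wraps around to the first batch once the data is
+ # exhausted (reset + recurse), so it never raises StopIteration; callers must
+ # bound the number of batches per epoch themselves.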
data/fasta/example.fasta ADDED
@@ -0,0 +1,2 @@
+ >P18281
+ MNPELQSAIGQGAALKHAETVDKSAPQIENVTVKKVDRSSFLEEVAKPHELKHAETVDKSGPAIPEDVHVKKVDRGAFLSEIEKAAKQ
data/fasta/prepare_custom_fasta.py ADDED
@@ -0,0 +1,7 @@
+ # prepare fasta data
+ name_list = ['P18281']
+ sequence_list = ['MNPELQSAIGQGAALKHAETVDKSAPQIENVTVKKVDRSSFLEEVAKPHELKHAETVDKSGPAIPEDVHVKKVDRGAFLSEIEKAAKQ']
+ with open('example.fasta', 'w') as f:
+     for i, j in zip(name_list, sequence_list):
+         f.write('>{}\n'.format(i))
+         f.write('{}\n'.format(j))
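+ # Extending name_list/sequence_list with more (id, sequence) pairs writes one
+ # two-line FASTA record per protein, in the same format as data/fasta/example.fasta.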
data/go1.4-basic.obo ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3da20cc774d666b4338446bc81341eaf536885dc10ccb667480a79f6b964aa3c
+ size 31134256
data/go_descriptions1.4.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/test_exp_prompt_bp_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/test_exp_prompt_cc_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/test_exp_prompt_mf_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/train_exp_prompt_bp_new.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12359211ab95f1ce1962b69f033b55e9f502a7527f49414792d1c117ec50b0be
+ size 28503657
data/swissprot_exp/train_exp_prompt_cc_new.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01c6144b0e338d3ce8ce98adfd4f9d09f56dc58cd347f4fbaafb6782d694ffd1
+ size 23292609
data/swissprot_exp/train_exp_prompt_mf_new.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca3eee941dfc0ee37f59adec6abf8a7276441f04c484a9275274c7003ef4145e
+ size 18791760
data/swissprot_exp/val_exp_prompt_bp_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/val_exp_prompt_cc_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/swissprot_exp/val_exp_prompt_mf_new.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/terms/bp_terms.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4952f3551e4fe205640b81f9a1816c15c14cc889bbe55f57d378fb3c6d57f2f7
+ size 274892
data/terms/cc_terms.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:20992c211336c4f876c920c2995ae85c1422e8742b7094c997aa70ddec7fc8fd
+ size 39440
data/terms/mf_terms.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:192861bad821ef3523ab2dcdd1db5eac093364e9b9b4869f75587d656864d29b
+ size 107802