Upload 24 files
Browse files- .gitattributes +4 -0
- data/emb_esm2_3b/P18281.pt +3 -0
- data/evaluate_data/evaluate_cases.py +213 -0
- data/evaluate_data/evaluate_pretrain.py +282 -0
- data/evaluate_data/evaluate_with_ancestors.py +339 -0
- data/evaluate_data/evaluate_with_ancestors_exp.py +339 -0
- data/evaluate_data/pretrain_output_to_deepgozero.py +477 -0
- data/evaluate_data/process_case.py +50 -0
- data/evaluate_data/utils.py +280 -0
- data/fasta/example.fasta +2 -0
- data/fasta/prepare_custom_fasta.py +7 -0
- data/go1.4-basic.obo +3 -0
- data/go_descriptions1.4.txt +0 -0
- data/swissprot_exp/test_exp_prompt_bp_new.csv +0 -0
- data/swissprot_exp/test_exp_prompt_cc_new.csv +0 -0
- data/swissprot_exp/test_exp_prompt_mf_new.csv +0 -0
- data/swissprot_exp/train_exp_prompt_bp_new.csv +3 -0
- data/swissprot_exp/train_exp_prompt_cc_new.csv +3 -0
- data/swissprot_exp/train_exp_prompt_mf_new.csv +3 -0
- data/swissprot_exp/val_exp_prompt_bp_new.csv +0 -0
- data/swissprot_exp/val_exp_prompt_cc_new.csv +0 -0
- data/swissprot_exp/val_exp_prompt_mf_new.csv +0 -0
- data/terms/bp_terms.pkl +3 -0
- data/terms/cc_terms.pkl +3 -0
- data/terms/mf_terms.pkl +3 -0
.gitattributes
CHANGED
@@ -35,3 +35,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/FAPM.png filter=lfs diff=lfs merge=lfs -text
|
37 |
assets/LAVIS_technical_report.pdf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/FAPM.png filter=lfs diff=lfs merge=lfs -text
|
37 |
assets/LAVIS_technical_report.pdf filter=lfs diff=lfs merge=lfs -text
|
38 |
+
data/go1.4-basic.obo filter=lfs diff=lfs merge=lfs -text
|
39 |
+
data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
|
40 |
+
data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
|
41 |
+
data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
|
data/emb_esm2_3b/P18281.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91714943ae1d08f860e86cfcd098f3973dc14ca63d88556223223fc9687ac7ec
|
3 |
+
size 901864
|
data/evaluate_data/evaluate_cases.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import Levenshtein
|
5 |
+
import numpy as np
|
6 |
+
import difflib
|
7 |
+
# from torchmetrics.text import BLEUScore
|
8 |
+
import time
|
9 |
+
from multiprocessing import Pool, Queue, Process
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from data.evaluate_data.utils import Ontology
|
12 |
+
# bleu = BLEUScore(n_gram=1)
|
13 |
+
|
14 |
+
def fuzzy_match(texts, candidates=None):
    """Map every text that is not a known label to its closest known label.

    Args:
        texts: Iterable of predicted text fragments.
        candidates: Optional collection of known label strings. When omitted,
            the module-level ``choices`` list is used, which is what the
            single-argument ``pool.map(fuzzy_match, ...)`` callers rely on.

    Returns:
        A dict ``{text: closest_label}`` containing only the texts that are
        not already an exact member of the candidates. Texts are left out
        when there is no candidate at all to match against.
    """
    if candidates is None:
        # Backward-compatible default: the global built in the __main__ section.
        candidates = choices
    candidates = list(candidates)
    known = set(candidates)  # O(1) membership instead of scanning the list per text
    text_dict = {}
    for context in texts:
        if context in known:
            continue
        # cutoff=0. accepts any similarity, so only an empty candidate list
        # can produce an empty result here.
        closest = difflib.get_close_matches(context, candidates, n=1, cutoff=0.)
        if closest:
            text_dict[context] = closest[0]
    return text_dict
|
21 |
+
|
22 |
+
|
23 |
+
def get_sim(text, label):
    """Return, for each label string, its best similarity against the texts.

    Args:
        text: Re-iterable collection of predicted strings.
        label: Iterable of reference strings.

    Returns:
        A list with one entry per element of ``label``: the highest
        ``Levenshtein.ratio`` between that label and any element of ``text``,
        rounded to 3 decimals, or 0 when ``text`` is empty.
    """
    return [
        round(max((Levenshtein.ratio(x, y) for y in text), default=0), 3)
        for x in label
    ]
|
36 |
+
|
37 |
+
|
38 |
+
def txt_map(x, txt_dict):
    """Replace each text in ``x`` by its mapped value, dropping empty strings.

    Args:
        x: A list of texts, or the string representation of such a list.
        txt_dict: Mapping from a text to its replacement.

    Returns:
        A new list in the original order where every non-empty element is
        replaced by ``txt_dict[element]`` when present and kept unchanged
        otherwise.
    """
    import ast  # local import so the block does not depend on a file-level change

    if isinstance(x, str):
        # literal_eval only parses Python literals; eval() would execute
        # whatever expression the string happens to contain.
        x = ast.literal_eval(x)
    return [txt_dict.get(i, i) for i in x if i != '']
|
50 |
+
|
51 |
+
|
52 |
+
def go_map(t, mapping=None):
    """Translate a function description into its GO identifier.

    Args:
        t: Description text to look up.
        mapping: Optional ``{description: GO id}`` dict. Defaults to the
            module-level ``GO_dict`` built in the __main__ section.

    Returns:
        The mapped GO id, or ``None`` when ``t`` is unknown; unknown texts
        are printed so they can be inspected.
    """
    if mapping is None:
        mapping = GO_dict
    try:
        return mapping[t]
    except KeyError:
        print(t)
        return None
|
57 |
+
|
58 |
+
|
59 |
+
def get_term(df, out_path='/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl'):
    """Collect the distinct propagated GO terms of ``df`` and pickle them.

    Args:
        df: DataFrame with a ``prop_annotations`` column holding an iterable
            of GO ids per row.
        out_path: Destination of the pickled one-column (``gos``) DataFrame.
            Defaults to the location the script has always written to.

    Returns:
        None. The result is written to ``out_path``.
    """
    # dict.fromkeys drops duplicates while keeping first-seen order.
    seen = dict.fromkeys(
        term for annotations in df['prop_annotations'] for term in annotations
    )
    # Remove the top-level terms so they never become prediction targets.
    for top_term in ('GO:0005575', 'GO:0003674', 'GO:0008150'):
        seen.pop(top_term, None)
    terms_df = pd.DataFrame({'gos': list(seen)})
    terms_df.to_pickle(out_path)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
    # Ontology graph; used below to expand labels with their ancestors.
    go = Ontology('/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)

    # Description table: column 0 is the GO id (with '_' rewritten to ':'
    # below), column 1 the textual description.
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    # .copy() so the column assignments below act on an independent frame
    # instead of a filtered view of the original.
    go_des = go_des[go_des['function'].notnull()].copy()
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))

    # Model output for the case study: protein | predicted text | label text.
    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
    data.columns = ['protein', 'pred', 'label']
    data['label'] = data['label'].apply(lambda x: x.lower())
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

    data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    # Reference set: its description -> GO id pairs extend GO_dict and its
    # descriptions are the candidates for fuzzy matching.
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test.csv', sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
    test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])

    test_dict = dict()
    for x, y in zip(test['function'], test['GO_label']):
        test_dict.update(dict(zip(x, y)))
    GO_dict.update(test_dict)

    # Global read by fuzzy_match() inside the worker processes.
    choices = list(test_dict.keys())

    # NOTE: the text-level fuzzy-match / F1 section that used to sit here was
    # disabled (wrapped in a string literal) in the original script and has
    # been dropped; it had no runtime effect.

    # Prepare the case-study data: GO labels predicted by blip2 are the
    # features, labels expanded with their ancestors are the target Y.
    prepare_ancestors = True
    if prepare_ancestors:
        print("准备加入祖先后的数据......")

        def prop(df):
            """Add a ``prop_annotations`` column: GO_label plus all ancestors."""
            prop_annotations = []
            for _, row in df.iterrows():
                # Propagate annotations
                annot_set = set()
                for go_id in row['GO_label']:
                    annot_set |= go.get_anchestors(go_id)
                prop_annotations.append(list(annot_set))
            df['prop_annotations'] = prop_annotations
            return df

        def pred_text_to_go(df):
            """Turn the raw ``pred`` text of ``df`` into GO ids (``pred_list_go``)."""
            df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
            df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

            # A predicted text that is not a known GO label text counts as
            # the most similar known label.
            t0 = time.time()
            all_txt = {txt for txts in df['pred_list'] for txt in txts}
            all_txt.discard('')
            all_txt = list(all_txt)

            n = len(all_txt)
            thread = 10
            # At least 1: a zero step makes range() raise when there are
            # fewer unique texts than workers.
            size = max(1, n // thread)
            all_txt_sep = [all_txt[i: i + size] for i in range(0, n, size)]

            # pool.map blocks until every chunk is done, so no explicit
            # close()/join() is needed inside the context manager.
            with Pool(processes=thread) as pool:
                result = pool.map(fuzzy_match, all_txt_sep)

            txt_dict = {}
            for d in result:
                txt_dict.update(d)

            df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
            df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
            print("fuzzy matching time: {}".format(time.time() - t0))

            df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
            return df

        test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
        test_pred.columns = ['protein', 'pred', 'GO_label']
        test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
        test_pred = prop(test_pred)
        test_pred = pred_text_to_go(test_pred)

        # The same frame is written once per ontology sub-directory.
        for cat in ['mf', 'bp', 'cc']:
            test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_case.pkl'.format(cat))
|
data/evaluate_data/evaluate_pretrain.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import Levenshtein
|
5 |
+
import numpy as np
|
6 |
+
import difflib
|
7 |
+
# from torchmetrics.text import BLEUScore
|
8 |
+
import time
|
9 |
+
from multiprocessing import Pool, Queue, Process
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from data.evaluate_data.utils import Ontology
|
12 |
+
# bleu = BLEUScore(n_gram=1)
|
13 |
+
|
14 |
+
def fuzzy_match(texts, candidates=None):
    """Map every unknown text to a near-identical known label, or to ''.

    Args:
        texts: Iterable of predicted text fragments.
        candidates: Optional collection of known label strings. When omitted,
            the module-level ``choices`` list is used, which is what the
            single-argument ``pool.map(fuzzy_match, ...)`` callers rely on.

    Returns:
        A dict with one entry per text that is not an exact member of the
        candidates: the closest candidate when its difflib similarity is at
        least 0.93, otherwise the empty string.
    """
    if candidates is None:
        # Backward-compatible default: the global built in the __main__ section.
        candidates = choices
    candidates = list(candidates)
    known = set(candidates)  # O(1) membership instead of scanning the list per text
    text_dict = {}
    for context in texts:
        if context in known:
            continue
        sim_list = difflib.get_close_matches(context, candidates, n=1, cutoff=0.93)
        # '' marks "no sufficiently similar label".
        text_dict[context] = sim_list[0] if sim_list else ''
    return text_dict
|
25 |
+
|
26 |
+
|
27 |
+
def get_sim(text, label):
    """Return, for each label string, its best similarity against the texts.

    Args:
        text: Re-iterable collection of predicted strings.
        label: Iterable of reference strings.

    Returns:
        A list with one entry per element of ``label``: the highest
        ``Levenshtein.ratio`` between that label and any element of ``text``,
        rounded to 3 decimals, or 0 when ``text`` is empty.
    """
    return [
        round(max((Levenshtein.ratio(x, y) for y in text), default=0), 3)
        for x in label
    ]
|
40 |
+
|
41 |
+
|
42 |
+
def txt_map(x, txt_dict):
    """Replace each text in ``x`` by its mapped value, dropping empty strings.

    Args:
        x: A list of texts, or the string representation of such a list.
        txt_dict: Mapping from a text to its replacement.

    Returns:
        A new list in the original order where every non-empty element is
        replaced by ``txt_dict[element]`` when present and kept unchanged
        otherwise. Replacement values are not filtered, so a mapped ``''``
        is kept.
    """
    import ast  # local import so the block does not depend on a file-level change

    if isinstance(x, str):
        # literal_eval only parses Python literals; eval() would execute
        # whatever expression the string happens to contain.
        x = ast.literal_eval(x)
    return [txt_dict.get(i, i) for i in x if i != '']
|
54 |
+
|
55 |
+
|
56 |
+
def go_map(t, mapping=None):
    """Translate a function description into its GO identifier.

    Args:
        t: Description text to look up.
        mapping: Optional ``{description: GO id}`` dict. Defaults to the
            module-level ``GO_dict`` built in the __main__ section.

    Returns:
        The mapped GO id, or ``None`` (silently) when ``t`` is unknown.
    """
    if mapping is None:
        mapping = GO_dict
    return mapping.get(t)
|
62 |
+
|
63 |
+
|
64 |
+
def get_term(df, out_path='/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl'):
    """Collect the distinct propagated GO terms of ``df`` and pickle them.

    Args:
        df: DataFrame with a ``prop_annotations`` column holding an iterable
            of GO ids per row.
        out_path: Destination of the pickled one-column (``gos``) DataFrame.
            Defaults to the location the script has always written to.

    Returns:
        None. The result is written to ``out_path``.
    """
    # dict.fromkeys drops duplicates while keeping first-seen order.
    seen = dict.fromkeys(
        term for annotations in df['prop_annotations'] for term in annotations
    )
    # Remove the top-level terms so they never become prediction targets.
    for top_term in ('GO:0005575', 'GO:0003674', 'GO:0008150'):
        seen.pop(top_term, None)
    terms_df = pd.DataFrame({'gos': list(seen)})
    terms_df.to_pickle(out_path)
|
77 |
+
|
78 |
+
|
79 |
+
if __name__ == "__main__":
    # Ontology graph; only needed by the (currently disabled) ancestor step.
    go = Ontology('/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)

    # Description table: column 0 is the GO id (with '_' rewritten to ':'
    # below), column 1 the textual description.
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    # .copy() so the column assignments below act on an independent frame
    # instead of a filtered view of the original.
    go_des = go_des[go_des['function'].notnull()].copy()
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))

    # Model output: name | predicted text | label (the label column is not
    # used; labels are taken from the reference csv below).
    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_go_train.txt', sep='|', header=None, on_bad_lines='skip')
    data.columns = ['name', 'pred', 'label']
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
    data['pred_list'] = data['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))

    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/train_exp.csv', sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
    test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])

    data = pd.merge(data, test[['name', 'function']], on='name', how='left')
    # The left merge leaves NaN for names missing from the reference csv;
    # treat those as "no labels" instead of crashing in the mapping below.
    data['label_list'] = data['function'].apply(lambda x: x if isinstance(x, list) else [])

    test_dict = dict()
    for x, y in zip(test['function'], test['GO_label']):
        test_dict.update(dict(zip(x, y)))
    GO_dict.update(test_dict)

    # Global read by fuzzy_match() inside the worker processes.
    choices = list(test_dict.keys())

    # A predicted text that is not a known GO label text counts as the most
    # similar known label.
    print("找到与预测文本最相似的GO标签......")
    t0 = time.time()

    all_txt = list({txt for txts in data['pred_list'] for txt in txts})

    n = len(all_txt)
    thread = 40
    # At least 1: a zero step makes range() raise when there are fewer
    # unique texts than workers.
    size = max(1, n // thread)
    all_txt_sep = [all_txt[i: i + size] for i in range(0, n, size)]

    # pool.map blocks until every chunk is done, so no explicit
    # close()/join() is needed inside the context manager.
    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)

    txt_dict = {}
    for d in result:
        txt_dict.update(d)

    data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))

    print("calculating f1 score ......")
    data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
    data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])

    # Both columns were just built as lists, so the former
    # "if it is a str, eval it" branches were dead (one of them also stored
    # its result in an unused variable) and are gone.
    labels = list({go_id for row in data['label_list_go'] for go_id in row})
    total = len(labels)
    pred_labels = []

    # Per-label true-positive / false-positive / false-negative counters.
    tp_dict = dict.fromkeys(labels, 0)
    fp_dict = dict.fromkeys(labels, 0)
    fn_dict = dict.fromkeys(labels, 0)
    for preds, label in zip(data['pred_list_go'], data['label_list_go']):
        for t in label:
            if t in preds:
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            if p not in label:
                # Predictions outside the label vocabulary get their own entry.
                fp_dict[p] = fp_dict.get(p, 0) + 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))

    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    # NOTE(review): precision sums over the label vocabulary but is averaged
    # over the number of distinct predictions (which may include None for
    # unmapped texts) -- confirm this is the intended metric.
    r = recall / total if total else 0.
    p = pr / p_total if p_total else 0.
    f1 = 2 * p * r / (p + r) if (p + r) else 0.

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("recall:{}; percision:{}; f1 score: {}".format(r, p, f1))

    # Prepare data: GO labels predicted by blip2 are the features, labels
    # expanded with their ancestors are the target Y.
    prepare_ancestors = False
    if prepare_ancestors:
        print("准备加入祖先后的数据......")

        def prop(df):
            """Add a ``prop_annotations`` column: GO_label plus all ancestors."""
            prop_annotations = []
            for _, row in df.iterrows():
                # Propagate annotations
                annot_set = set()
                for go_id in row['GO_label']:
                    annot_set |= go.get_anchestors(go_id)
                prop_annotations.append(list(annot_set))
            df['prop_annotations'] = prop_annotations
            return df

        def remove_nan(x):
            """Drop the '' placeholder (no fuzzy match found) from ``x`` in place."""
            if '' in x:
                x.remove('')
            return x

        def pred_text_to_go(df):
            """Turn the raw ``pred`` text of ``df`` into GO ids (``pred_list_go``)."""
            df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
            df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))

            # A predicted text that is not a known GO label text counts as
            # the most similar known label.
            t0 = time.time()
            all_txt = {txt for txts in df['pred_list'] for txt in txts}
            all_txt.discard('')
            all_txt = list(all_txt)

            n = len(all_txt)
            thread = 40
            size = max(1, n // thread)  # never 0, see above
            all_txt_sep = [all_txt[i: i + size] for i in range(0, n, size)]

            with Pool(processes=thread) as pool:
                result = pool.map(fuzzy_match, all_txt_sep)

            txt_dict = {}
            for d in result:
                txt_dict.update(d)

            df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
            df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
            df['pred_list'] = df['pred_list'].apply(lambda x: remove_nan(x))
            print("fuzzy matching time: {}".format(time.time() - t0))

            df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
            return df

        test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/pretrain/output_pretrain.txt', sep='|', header=None)
        test_pred.columns = ['protein', 'pred', 'GO_label']
        test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
        # The original line was `test_pred = test_pred(test)`, which calls a
        # DataFrame and raises TypeError. prop() is what the sibling scripts
        # apply here.
        test_pred = prop(test_pred)
        # NOTE(review): get_term() needs a 'prop_annotations' column, which
        # `test` never received in the original; propagating it here is an
        # assumption -- confirm which frame the term list should come from.
        test = prop(test)
        get_term(test)
        test_pred = pred_text_to_go(test_pred)

        # NOTE(review): the '{}' in this path is never formatted, so a
        # directory literally named '{}' is used -- confirm the intended
        # location before enabling this branch.
        test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_pretrain.pkl')
|
282 |
+
|
data/evaluate_data/evaluate_with_ancestors.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import Levenshtein
|
5 |
+
import numpy as np
|
6 |
+
import difflib
|
7 |
+
# from torchmetrics.text import BLEUScore
|
8 |
+
import time
|
9 |
+
from multiprocessing import Pool, Queue, Process
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from data.evaluate_data.utils import Ontology
|
12 |
+
# bleu = BLEUScore(n_gram=1)
|
13 |
+
|
14 |
+
def fuzzy_match(texts, candidates=None):
    """Map every text that is not a known label to its closest known label.

    Args:
        texts: Iterable of predicted text fragments.
        candidates: Optional collection of known label strings. When omitted,
            the module-level ``choices`` list is used, which is what the
            single-argument ``pool.map(fuzzy_match, ...)`` callers rely on.

    Returns:
        A dict ``{text: closest_label}`` containing only the texts that are
        not already an exact member of the candidates. Texts are left out
        when there is no candidate at all to match against.
    """
    if candidates is None:
        # Backward-compatible default: the global built in the __main__ section.
        candidates = choices
    candidates = list(candidates)
    known = set(candidates)  # O(1) membership instead of scanning the list per text
    text_dict = {}
    for context in texts:
        if context in known:
            continue
        # cutoff=0. accepts any similarity, so only an empty candidate list
        # can produce an empty result here.
        closest = difflib.get_close_matches(context, candidates, n=1, cutoff=0.)
        if closest:
            text_dict[context] = closest[0]
    return text_dict
|
21 |
+
|
22 |
+
|
23 |
+
def get_sim(text, label):
    """Return, for each label string, its best similarity against the texts.

    Args:
        text: Re-iterable collection of predicted strings.
        label: Iterable of reference strings.

    Returns:
        A list with one entry per element of ``label``: the highest
        ``Levenshtein.ratio`` between that label and any element of ``text``,
        rounded to 3 decimals, or 0 when ``text`` is empty.
    """
    return [
        round(max((Levenshtein.ratio(x, y) for y in text), default=0), 3)
        for x in label
    ]
|
36 |
+
|
37 |
+
|
38 |
+
def txt_map(x, txt_dict):
    """Replace each text in ``x`` by its mapped value, dropping empty strings.

    Args:
        x: A list of texts, or the string representation of such a list.
        txt_dict: Mapping from a text to its replacement.

    Returns:
        A new list in the original order where every non-empty element is
        replaced by ``txt_dict[element]`` when present and kept unchanged
        otherwise.
    """
    import ast  # local import so the block does not depend on a file-level change

    if isinstance(x, str):
        # literal_eval only parses Python literals; eval() would execute
        # whatever expression the string happens to contain.
        x = ast.literal_eval(x)
    return [txt_dict.get(i, i) for i in x if i != '']
|
50 |
+
|
51 |
+
|
52 |
+
def go_map(t, mapping=None):
    """Translate a function description into its GO identifier.

    Args:
        t: Description text to look up.
        mapping: Optional ``{description: GO id}`` dict. Defaults to the
            module-level ``GO_dict`` built in the __main__ section.

    Returns:
        The mapped GO id, or ``None`` when ``t`` is unknown; unknown texts
        are printed so they can be inspected.
    """
    if mapping is None:
        mapping = GO_dict
    try:
        return mapping[t]
    except KeyError:
        print(t)
        return None
|
57 |
+
|
58 |
+
|
59 |
+
def get_term(df, out_path=None):
    """Collect the distinct propagated GO terms of ``df`` and pickle them.

    Args:
        df: DataFrame with a ``prop_annotations`` column holding an iterable
            of GO ids per row.
        out_path: Destination of the pickled one-column (``gos``) DataFrame.
            When omitted, the path is derived from the module-level ``cat``
            set in the __main__ section, as the script has always done.

    Returns:
        None. The result is written to ``out_path``.
    """
    if out_path is None:
        out_path = f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl'
    # dict.fromkeys drops duplicates while keeping first-seen order.
    seen = dict.fromkeys(
        term for annotations in df['prop_annotations'] for term in annotations
    )
    # Remove the top-level terms so they never become prediction targets.
    for top_term in ('GO:0005575', 'GO:0003674', 'GO:0008150'):
        seen.pop(top_term, None)
    terms_df = pd.DataFrame({'gos': list(seen)})
    terms_df.to_pickle(out_path)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
cat = 'mf'
|
76 |
+
|
77 |
+
go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
|
78 |
+
go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
|
79 |
+
go_des.columns = ['GO', 'function']
|
80 |
+
go_des = go_des[go_des['function'].notnull()]
|
81 |
+
go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
|
82 |
+
go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
|
83 |
+
GO_dict = dict(zip(go_des['function'], go_des['GO']))
|
84 |
+
|
85 |
+
|
86 |
+
data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')
|
87 |
+
|
88 |
+
data['label'] = data['label'].apply(lambda x: x.lower())
|
89 |
+
data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))
|
90 |
+
|
91 |
+
data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
|
92 |
+
data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
|
93 |
+
|
94 |
+
train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
|
95 |
+
train = train.drop_duplicates()
|
96 |
+
train['function'] = train['function'].apply(lambda x: x.lower().strip())
|
97 |
+
train_dict = dict(zip(train['function'], train['GO_label']))
|
98 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
|
99 |
+
test = test.drop_duplicates()
|
100 |
+
test['function'] = test['function'].apply(lambda x: x.lower().strip())
|
101 |
+
test_dict = dict(zip(test['function'], test['GO_label']))
|
102 |
+
GO_dict.update(train_dict)
|
103 |
+
GO_dict.update(test_dict)
|
104 |
+
|
105 |
+
choices = []
|
106 |
+
for x in data['label_list'].tolist() + train['function'].tolist():
|
107 |
+
choices.extend(x)
|
108 |
+
choices = list(set(choices))
|
109 |
+
|
110 |
+
|
111 |
+
### 预测的文本如果不在GO标签词中,则算作最相似的GO标签
|
112 |
+
print("找到与预测文本最相似的GO标签......")
|
113 |
+
t0 = time.time()
|
114 |
+
txt_dict = {}
|
115 |
+
|
116 |
+
all_txt = []
|
117 |
+
for txt in data['pred_list']:
|
118 |
+
if type(txt) == str:
|
119 |
+
all_txt.extend(eval(txt))
|
120 |
+
else:
|
121 |
+
all_txt.extend(txt)
|
122 |
+
all_txt = list(set(all_txt))
|
123 |
+
|
124 |
+
n = len(all_txt)
|
125 |
+
thread = 40
|
126 |
+
size = int(n/thread)
|
127 |
+
inds = list(range(0, n, size))
|
128 |
+
inds.append(n)
|
129 |
+
all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]
|
130 |
+
|
131 |
+
with Pool(processes=thread) as pool:
|
132 |
+
result = pool.map(fuzzy_match, all_txt_sep)
|
133 |
+
pool.close()
|
134 |
+
pool.join()
|
135 |
+
for d in result:
|
136 |
+
txt_dict.update(d)
|
137 |
+
|
138 |
+
# for txt in all_txt[:10]:
|
139 |
+
# fuzzy_match(txt)
|
140 |
+
|
141 |
+
data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
|
142 |
+
data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
|
143 |
+
print("fuzzy matching time: {}".format(time.time() - t0))
|
144 |
+
|
145 |
+
|
146 |
+
# sims = []
|
147 |
+
# for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
|
148 |
+
# a = get_sim(text, label)
|
149 |
+
# sims.append(a)
|
150 |
+
#
|
151 |
+
# data['sim'] = sims
|
152 |
+
# data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
|
153 |
+
# print("simlarity: {}".format(data['avg_sim'].mean()))
|
154 |
+
|
155 |
+
|
156 |
+
print("calculating f1 score ......")
|
157 |
+
data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
|
158 |
+
data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
|
159 |
+
|
160 |
+
|
161 |
+
labels = []
|
162 |
+
pred_labels = []
|
163 |
+
for l in data['label_list_go']:
|
164 |
+
if type(l) == str:
|
165 |
+
l = eval(l)
|
166 |
+
labels.extend(l)
|
167 |
+
|
168 |
+
label_count = {}
|
169 |
+
for x in labels:
|
170 |
+
if x not in label_count:
|
171 |
+
label_count[x] = 1
|
172 |
+
else:
|
173 |
+
label_count[x] += 1
|
174 |
+
|
175 |
+
labels = list(set(labels))
|
176 |
+
total = len(labels)
|
177 |
+
recalls = []
|
178 |
+
precisions = []
|
179 |
+
tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
|
180 |
+
for preds, label in zip(data['pred_list_go'], data['label_list_go']):
|
181 |
+
if type(label) == str:
|
182 |
+
label = eval(label)
|
183 |
+
if type(preds) == str:
|
184 |
+
txts = eval(preds)
|
185 |
+
ll = len(label)
|
186 |
+
for t in label:
|
187 |
+
supgo = go.get_anchestors(t)
|
188 |
+
if supgo.intersection(set(preds)):
|
189 |
+
tp_dict[t] += 1
|
190 |
+
else:
|
191 |
+
fn_dict[t] += 1
|
192 |
+
for p in preds:
|
193 |
+
supgo = go.get_anchestors(p)
|
194 |
+
if not supgo.intersection(set(label)):
|
195 |
+
if p in fp_dict:
|
196 |
+
fp_dict[p] += 1
|
197 |
+
else:
|
198 |
+
fp_dict[p] = 1
|
199 |
+
pred_labels.extend(preds)
|
200 |
+
p_total = len(set(pred_labels))
|
201 |
+
recall, pr = 0., 0.
|
202 |
+
for x in labels:
|
203 |
+
recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
|
204 |
+
pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
|
205 |
+
r = recall / total
|
206 |
+
p = pr / p_total
|
207 |
+
f1 = 2 * p * r / (p + r)
|
208 |
+
|
209 |
+
print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
|
210 |
+
print("f1 score: {}".format(f1))
|
211 |
+
|
212 |
+
'''
|
213 |
+
cat_f1 = {}
|
214 |
+
for x in labels:
|
215 |
+
if tp_dict[x] + fn_dict[x] > 0:
|
216 |
+
re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
|
217 |
+
pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
|
218 |
+
cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)
|
219 |
+
|
220 |
+
plt.xlabel('f score')
|
221 |
+
plt.ylabel('count')
|
222 |
+
print(np.mean(list(cat_f1.values())))
|
223 |
+
plt.hist(list(cat_f1.values()), color='red', bins=30)
|
224 |
+
plt.show()
|
225 |
+
|
226 |
+
xs, ys = [], []
|
227 |
+
for x in labels:
|
228 |
+
xs.append(label_count[x])
|
229 |
+
ys.append(cat_f1[x])
|
230 |
+
df_count = pd.DataFrame({'xs': xs, 'ys': ys})
|
231 |
+
df_count['xs'].loc[df_count['xs'] > 10] = 11
|
232 |
+
df_count['xs'] = df_count['xs'].astype(str)
|
233 |
+
df_count1 = df_count.groupby('xs').mean().reset_index()
|
234 |
+
df_count2 = df_count.groupby('xs').count().reset_index()
|
235 |
+
|
236 |
+
plt.xlabel('label count')
|
237 |
+
plt.ylabel('f score mean')
|
238 |
+
df_count1['xs'] = df_count1['xs'].astype(int)
|
239 |
+
plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
|
240 |
+
plt.show()
|
241 |
+
|
242 |
+
plt.xlabel('label count')
|
243 |
+
plt.ylabel('protein num')
|
244 |
+
df_count2['xs'] = df_count2['xs'].astype(int)
|
245 |
+
plt.bar(df_count2['xs'], df_count2['ys'], color='red')
|
246 |
+
plt.show()
|
247 |
+
'''
|
248 |
+
|
249 |
+
|
250 |
+
# 准备数据:blip2预测的Go标签作为feature,label加入祖先后作为预测的Y
|
251 |
+
print("准备加入祖先后的数据......")
|
252 |
+
train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
|
253 |
+
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
|
254 |
+
train = train.groupby('name').agg({'GO_label': list}).reset_index()
|
255 |
+
test = test.groupby('name').agg({'GO_label': list}).reset_index()
|
256 |
+
|
257 |
+
def prop(df):
    """Add a ``prop_annotations`` column: each row's GO terms plus their ancestors.

    For every row the ids in ``GO_label`` are expanded with
    ``go.get_anchestors`` (``go`` is the ontology object created by the
    surrounding script) and the union is stored as a list.  ``df`` is modified
    in place and also returned.
    """
    expanded = []
    for _, record in df.iterrows():
        closure = set()
        for term in record['GO_label']:
            closure.update(go.get_anchestors(term))
        expanded.append(list(closure))
    df['prop_annotations'] = expanded
    return df
|
269 |
+
|
270 |
+
train = prop(train)
|
271 |
+
test = prop(test)
|
272 |
+
|
273 |
+
train_test = pd.concat([train, test])
|
274 |
+
get_term(train_test)
|
275 |
+
del train_test
|
276 |
+
|
277 |
+
def pred_text_to_go(df):
    """Turn the free-text predictions in ``df['pred']`` into GO identifiers.

    The ``</s>`` marker is stripped, each prediction is split on ``;``, every
    distinct phrase is mapped by ``fuzzy_match`` (run in a worker pool) and the
    mapped phrases are translated with ``go_map``.  The columns ``pred_list``
    and ``pred_list_go`` are added to ``df`` in place; ``df`` is returned.
    """
    df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
    df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    # Predicted text that is not a known GO label is replaced by the most
    # similar label.
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in df['pred_list']:
        if type(txt) == str:
            # NOTE(review): eval() executes whatever expression the cell
            # contains; only safe for prediction files you trust.
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)

    all_txt = list(set(all_txt))
    if '' in all_txt:
        all_txt.remove('')

    n = len(all_txt)
    thread = 40
    # Keep at least one phrase per chunk.  With fewer phrases than workers the
    # previous int(n / thread) was 0 and range() raised on the zero step.
    size = max(1, n // thread)
    all_txt_sep = [all_txt[i:i + size] for i in range(0, n, size)]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))

    df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
    return df
|
319 |
+
|
320 |
+
|
321 |
+
train_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_train{}.csv'.format(cat), sep='|')
|
322 |
+
test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')
|
323 |
+
|
324 |
+
train_pred = pred_text_to_go(train_pred)
|
325 |
+
test_pred = pred_text_to_go(test_pred)
|
326 |
+
|
327 |
+
train_data = pd.merge(train[['name', 'prop_annotations']],
|
328 |
+
train_pred[['name', 'pred_list_go']],
|
329 |
+
on='name', how='inner')
|
330 |
+
train_data = train_data.drop_duplicates('name')
|
331 |
+
train_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/train_data.pkl'.format(cat))
|
332 |
+
|
333 |
+
test_data = pd.merge(test[['name', 'prop_annotations']],
|
334 |
+
test_pred[['name', 'pred_list_go']],
|
335 |
+
on='name', how='inner')
|
336 |
+
test_data = test_data.drop_duplicates('name')
|
337 |
+
test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_data.pkl'.format(cat))
|
338 |
+
test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/valid_data.pkl'.format(cat))
|
339 |
+
|
data/evaluate_data/evaluate_with_ancestors_exp.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import Levenshtein
|
5 |
+
import numpy as np
|
6 |
+
import difflib
|
7 |
+
# from torchmetrics.text import BLEUScore
|
8 |
+
import time
|
9 |
+
from multiprocessing import Pool, Queue, Process
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from data.evaluate_data.utils import Ontology
|
12 |
+
# bleu = BLEUScore(n_gram=1)
|
13 |
+
|
14 |
+
def fuzzy_match(texts):
    """Map each phrase that is not in ``choices`` to its closest entry.

    ``choices`` is the module-level candidate collection built by the script.
    Phrases already present in ``choices`` are left out of the result.  The
    cutoff is 0, so the best-scoring candidate is taken however poor the match.

    Returns a dict ``{phrase: closest_choice}``.
    """
    return {
        phrase: difflib.get_close_matches(phrase, choices, n=1, cutoff=0.)[0]
        for phrase in texts
        if phrase not in choices
    }
|
21 |
+
|
22 |
+
|
23 |
+
def get_sim(text, label):
    """Return, for each entry of ``label``, its best similarity to any entry of ``text``.

    Similarity is ``Levenshtein.ratio``.  The score starts at 0, so an empty
    ``text`` yields 0 for every label.  Scores are rounded to 3 decimals and
    returned in the order of ``label``.
    """
    # The leading 0 reproduces the starting value of the running maximum.
    return [
        round(max([0] + [Levenshtein.ratio(target, candidate) for candidate in text]), 3)
        for target in label
    ]
|
36 |
+
|
37 |
+
|
38 |
+
def txt_map(x, txt_dict):
    """Replace phrases in ``x`` by their ``txt_dict`` mapping, dropping empty strings.

    ``x`` is a list of phrases, or the string form of such a list.  Phrases
    without an entry in ``txt_dict`` are kept unchanged.
    """
    if type(x) == str:
        # NOTE(review): eval() executes the string; only safe on trusted data.
        x = eval(x)
    return [txt_dict.get(phrase, phrase) for phrase in x if phrase != '']
|
50 |
+
|
51 |
+
|
52 |
+
def go_map(t):
    """Look up ``t`` in the module-level ``GO_dict``.

    Returns the mapped value.  When ``t`` is missing it is printed and the
    function returns ``None``.
    """
    if t not in GO_dict:
        print(t)
        return None
    return GO_dict[t]
|
57 |
+
|
58 |
+
|
59 |
+
def get_term(df):
    """Collect the distinct terms in ``df.prop_annotations`` and pickle them.

    The three ids listed in ``roots`` are removed.  The remaining ids are
    written, in first-seen order, as a one-column frame (``gos``) to the
    ``terms.pkl`` file for the module-level ``cat``.
    """
    from collections import Counter

    counts = Counter(
        term
        for row in df.itertuples()
        for term in row.prop_annotations
    )
    # remove top
    roots = ('GO:0005575', 'GO:0003674', 'GO:0008150')
    kept = [term for term in counts if term not in roots]
    pd.DataFrame({'gos': kept}).to_pickle(
        f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl')
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
cat = 'mf'

# Ontology object (used later for ancestor look-ups) and the table that maps
# lower-cased GO descriptions to GO ids.
go = Ontology('/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
go_des.columns = ['GO', 'function']
go_des = go_des[go_des['function'].notnull()]
go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
GO_dict = dict(zip(go_des['function'], go_des['GO']))

# Model output: one row per protein, with ';'-separated label and prediction text.
data = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{}.csv'.format(cat), sep='|')

data['label'] = data['label'].apply(lambda x: x.lower())
data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

# Descriptions seen in the train/test splits are added to (and override) the
# description -> id table.
train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/train_{}.csv'.format(cat), sep='|')
train = train.drop_duplicates()
train['function'] = train['function'].apply(lambda x: x.lower().strip())
train_dict = dict(zip(train['function'], train['GO_label']))
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/test_{}.csv'.format(cat), sep='|')
test = test.drop_duplicates()
test['function'] = test['function'].apply(lambda x: x.lower().strip())
test_dict = dict(zip(test['function'], test['GO_label']))
GO_dict.update(train_dict)
GO_dict.update(test_dict)

# Candidate phrases for fuzzy matching: every label phrase of the evaluated
# rows plus every training description.  `train['function']` holds plain
# strings, so they are added whole; the previous `choices.extend(x)` split
# each description into single characters.
choices = set(train['function'])
for x in data['label_list']:
    choices.update(x)
choices = list(choices)
|
109 |
+
|
110 |
+
|
111 |
+
# Predicted text that is not a known GO label is replaced by the most similar label.
print("找到与预测文本最相似的GO标签......")
t0 = time.time()
txt_dict = {}

all_txt = []
for txt in data['pred_list']:
    if type(txt) == str:
        # NOTE(review): eval() executes the cell contents; trusted input only.
        all_txt.extend(eval(txt))
    else:
        all_txt.extend(txt)
all_txt = list(set(all_txt))

n = len(all_txt)
thread = 40
# Keep at least one phrase per chunk.  With fewer phrases than workers the
# previous int(n/thread) was 0 and range() raised ValueError on the zero step.
size = max(1, n // thread)
inds = list(range(0, n, size))
inds.append(n)
all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]

with Pool(processes=thread) as pool:
    result = pool.map(fuzzy_match, all_txt_sep)
    pool.close()
    pool.join()
for d in result:
    txt_dict.update(d)

data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
print("fuzzy matching time: {}".format(time.time() - t0))
|
144 |
+
|
145 |
+
|
146 |
+
# sims = []
|
147 |
+
# for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
|
148 |
+
# a = get_sim(text, label)
|
149 |
+
# sims.append(a)
|
150 |
+
#
|
151 |
+
# data['sim'] = sims
|
152 |
+
# data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
|
153 |
+
# print("simlarity: {}".format(data['avg_sim'].mean()))
|
154 |
+
|
155 |
+
|
156 |
+
print("calculating f1 score ......")
data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])

labels = []
pred_labels = []
for l in data['label_list_go']:
    if type(l) == str:
        l = eval(l)
    labels.extend(l)

# Number of rows in which each label id occurs.
label_count = {}
for x in labels:
    label_count[x] = label_count.get(x, 0) + 1

labels = list(set(labels))
total = len(labels)
tp_dict = dict.fromkeys(labels, 0)
fp_dict = dict.fromkeys(labels, 0)
fn_dict = dict.fromkeys(labels, 0)
for preds, label in zip(data['pred_list_go'], data['label_list_go']):
    if type(label) == str:
        label = eval(label)
    if type(preds) == str:
        # The parsed list was previously assigned to an unused name, which
        # left `preds` as a raw string for the rest of the iteration.
        preds = eval(preds)
    pred_set = set(preds)
    label_set = set(label)
    # A label is a hit when any id in its get_anchestors() set was predicted.
    for t in label:
        if go.get_anchestors(t).intersection(pred_set):
            tp_dict[t] += 1
        else:
            fn_dict[t] += 1
    # A prediction is a false positive when its get_anchestors() set shares
    # nothing with the row's labels.
    for pred_go in preds:
        if not go.get_anchestors(pred_go).intersection(label_set):
            fp_dict[pred_go] = fp_dict.get(pred_go, 0) + 1
    pred_labels.extend(preds)
p_total = len(set(pred_labels))
recall, pr = 0., 0.
for x in labels:
    recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
    pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
# Guard each ratio so an empty label set, an empty prediction set or zero
# hits gives 0 instead of ZeroDivisionError.
r = recall / total if total else 0.
p = pr / p_total if p_total else 0.
f1 = 2 * p * r / (p + r) if (p + r) else 0.

print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
print("f1 score: {}".format(f1))
|
211 |
+
|
212 |
+
'''
|
213 |
+
cat_f1 = {}
|
214 |
+
for x in labels:
|
215 |
+
if tp_dict[x] + fn_dict[x] > 0:
|
216 |
+
re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
|
217 |
+
pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
|
218 |
+
cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)
|
219 |
+
|
220 |
+
plt.xlabel('f score')
|
221 |
+
plt.ylabel('count')
|
222 |
+
print(np.mean(list(cat_f1.values())))
|
223 |
+
plt.hist(list(cat_f1.values()), color='red', bins=30)
|
224 |
+
plt.show()
|
225 |
+
|
226 |
+
xs, ys = [], []
|
227 |
+
for x in labels:
|
228 |
+
xs.append(label_count[x])
|
229 |
+
ys.append(cat_f1[x])
|
230 |
+
df_count = pd.DataFrame({'xs': xs, 'ys': ys})
|
231 |
+
df_count['xs'].loc[df_count['xs'] > 10] = 11
|
232 |
+
df_count['xs'] = df_count['xs'].astype(str)
|
233 |
+
df_count1 = df_count.groupby('xs').mean().reset_index()
|
234 |
+
df_count2 = df_count.groupby('xs').count().reset_index()
|
235 |
+
|
236 |
+
plt.xlabel('label count')
|
237 |
+
plt.ylabel('f score mean')
|
238 |
+
df_count1['xs'] = df_count1['xs'].astype(int)
|
239 |
+
plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
|
240 |
+
plt.show()
|
241 |
+
|
242 |
+
plt.xlabel('label count')
|
243 |
+
plt.ylabel('protein num')
|
244 |
+
df_count2['xs'] = df_count2['xs'].astype(int)
|
245 |
+
plt.bar(df_count2['xs'], df_count2['ys'], color='red')
|
246 |
+
plt.show()
|
247 |
+
'''
|
248 |
+
|
249 |
+
|
250 |
+
# Prepare data: the GO labels predicted by blip2 are the features, and the
# labels extended with their ancestors are the target Y.
print("准备加入祖先后的数据......")
train = pd.read_csv(f'/cluster/home/wenkai/LAVIS/data/sim_exp/train_{cat}.csv', sep='|')
test = pd.read_csv(f'/cluster/home/wenkai/LAVIS/data/sim_exp/test_{cat}.csv', sep='|')
# One row per protein name, with its GO ids collected into a list.
train = train.groupby('name').agg({'GO_label': list}).reset_index()
test = test.groupby('name').agg({'GO_label': list}).reset_index()


def prop(df):
    """Add ``prop_annotations``: each row's ``GO_label`` ids plus their ancestors.

    ``df`` is modified in place and returned.
    """
    expanded = []
    for _, record in df.iterrows():
        # Propagate annotations
        closure = set()
        for term in record['GO_label']:
            closure.update(go.get_anchestors(term))
        expanded.append(list(closure))
    df['prop_annotations'] = expanded
    return df


train = prop(train)
test = prop(test)

# The term vocabulary is computed over both splits together.
get_term(pd.concat([train, test]))
|
276 |
+
|
277 |
+
def pred_text_to_go(df):
    """Turn the free-text predictions in ``df['pred']`` into GO identifiers.

    The ``</s>`` marker is stripped, each prediction is split on ``;``, every
    distinct phrase is mapped by ``fuzzy_match`` (run in a worker pool) and the
    mapped phrases are translated with ``go_map``.  The columns ``pred_list``
    and ``pred_list_go`` are added to ``df`` in place; ``df`` is returned.
    """
    df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
    df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    # Predicted text that is not a known GO label is replaced by the most
    # similar label.
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in df['pred_list']:
        if type(txt) == str:
            # NOTE(review): eval() executes whatever expression the cell
            # contains; only safe for prediction files you trust.
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)

    all_txt = list(set(all_txt))
    if '' in all_txt:
        all_txt.remove('')

    n = len(all_txt)
    thread = 40
    # Keep at least one phrase per chunk.  With fewer phrases than workers the
    # previous int(n / thread) was 0 and range() raised on the zero step.
    size = max(1, n // thread)
    all_txt_sep = [all_txt[i:i + size] for i in range(0, n, size)]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))

    df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
    return df
|
319 |
+
|
320 |
+
|
321 |
+
# Both prediction files are read before the (slow) phrase matching starts.
train_pred = pd.read_csv(f'/cluster/home/wenkai/LAVIS/output_exp/predict_concat_train{cat}.csv', sep='|')
test_pred = pd.read_csv(f'/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{cat}.csv', sep='|')

train_pred = pred_text_to_go(train_pred)
test_pred = pred_text_to_go(test_pred)

# Join propagated labels with predicted GO ids, keeping one row per protein name.
train_data = (
    train[['name', 'prop_annotations']]
    .merge(train_pred[['name', 'pred_list_go']], on='name', how='inner')
    .drop_duplicates('name')
)
train_data.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/train_data.pkl')

test_data = (
    test[['name', 'prop_annotations']]
    .merge(test_pred[['name', 'pred_list_go']], on='name', how='inner')
    .drop_duplicates('name')
)
test_data.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/test_data.pkl')
# NOTE(review): the validation file is written from the test split - confirm this is intended.
test_data.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/valid_data.pkl')
|
339 |
+
|
data/evaluate_data/pretrain_output_to_deepgozero.py
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import pandas as pd
|
3 |
+
import time
|
4 |
+
from multiprocessing import Pool
|
5 |
+
import difflib
|
6 |
+
from utils import Ontology
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
def filter(x_list):
    """Join the entries of ``x_list`` that occur in ``filter_go`` with ``'; '``.

    ``filter_go`` is the module-level collection of accepted ids.  Order is
    preserved; the result is ``''`` when nothing is accepted.
    NOTE(review): the name shadows the built-in ``filter`` inside this module.
    """
    return '; '.join(go_id for go_id in x_list if go_id in filter_go)
|
17 |
+
|
18 |
+
|
19 |
+
def fuzzy_match(texts):
    """Map each phrase to an entry of the module-level ``choices``.

    A phrase found in ``choices`` maps to itself.  Otherwise its lower-cased
    form is compared against ``choices`` and the best candidate with a
    similarity of at least 0.9 is used.  Phrases with no such candidate are
    left out of the returned dict.
    """
    matched = {}
    for phrase in texts:
        if phrase in choices:
            matched[phrase] = phrase
            continue
        candidates = difflib.get_close_matches(phrase.lower(), choices, n=1, cutoff=0.9)
        if len(candidates) > 0:
            matched[phrase] = candidates[0]
    return matched
|
33 |
+
|
34 |
+
|
35 |
+
def txt_map(x, txt_dict):
    """Translate the phrases in ``x`` through ``txt_dict``.

    ``x`` is a list of phrases, or the string form of such a list.  Empty
    strings and phrases without an entry in ``txt_dict`` are dropped.
    """
    if type(x) == str:
        # NOTE(review): eval() executes the string; only safe on trusted data.
        x = eval(x)
    return [txt_dict[phrase] for phrase in x if phrase != '' and phrase in txt_dict]
|
48 |
+
|
49 |
+
|
50 |
+
def go_map_prob(x, GO_dict):
    """Translate element 0 of every pair in ``x`` through ``GO_dict``.

    Each item is indexed as ``item[0]`` (lookup key) and ``item[1]`` (carried
    over unchanged).  Items whose key is not in ``GO_dict`` are dropped.
    """
    return [(GO_dict[item[0]], item[1]) for item in x if item[0] in GO_dict]
|
59 |
+
|
60 |
+
|
61 |
+
def txt_map_prob(x, txt_dict):
    """Translate ``(phrase, value)`` pairs through ``txt_dict``, one pair per target.

    ``x`` is a list of pairs, or the string form of one.  Pairs with an empty
    phrase or a phrase missing from ``txt_dict`` are skipped.  When several
    phrases map to the same target only the first pair is kept.  The target
    is lower-cased in the output.
    """
    if type(x) == str:
        # NOTE(review): eval() executes the string; only safe on trusted data.
        x = eval(x)
    kept = []
    seen_targets = set()
    for item in x:
        phrase = item[0]
        if phrase == '' or phrase not in txt_dict:
            continue
        target = txt_dict[phrase]
        if target in seen_targets:
            continue
        kept.append((target.lower(), item[1]))
        seen_targets.add(target)
    return kept
|
78 |
+
|
79 |
+
|
80 |
+
def go_map(x, GO_dict):
    """Translate every entry of ``x`` through ``GO_dict``.

    Entries without a mapping are reported on stdout and left out of the
    returned list.
    """
    mapped = []
    for phrase in x:
        if phrase not in GO_dict:
            print("{} not in GO_dict".format(phrase))
            continue
        mapped.append(GO_dict[phrase])
    return mapped
|
89 |
+
|
90 |
+
|
91 |
+
def prop(df):
    """Add a ``prop_annotations`` column: each row's GO terms plus their ancestors.

    For every row the ids in ``GO_label`` are expanded with
    ``godb.get_anchestors`` (``godb`` is the module-level ontology object) and
    the union is stored as a list.  ``df`` is modified in place and returned.
    """
    expanded = []
    for _, record in df.iterrows():
        # Propagate annotations
        closure = set()
        for term in record['GO_label']:
            closure.update(godb.get_anchestors(term))
        expanded.append(list(closure))
    df['prop_annotations'] = expanded
    return df
|
103 |
+
|
104 |
+
|
105 |
+
def pred_text_to_go(df, with_prob=False):
    """Normalise the predictions in ``df['pred']`` and map them to GO ids.

    Args:
        df: frame with a ``pred`` column of ``;``-separated items.
        with_prob: when true, each item is evaluated as a Python expression
            and indexed as ``item[0]`` (phrase) and ``item[1]`` - presumably a
            ``(text, probability)`` pair; confirm against the producer.  When
            false, each item is a bare phrase.

    Returns:
        The rows that kept at least one mapped GO id, after ``dropna()``.
        With ``with_prob`` the ids are in ``pred_list_go_prob``; otherwise
        they are in ``pred_list_go``.
    """
    # df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
    if with_prob:
        # NOTE(review): eval() executes whatever the prediction file contains;
        # only safe for files you trust.
        df['pred_list_prob'] = df['pred'].apply(lambda x: [eval(i.strip()) for i in x.split(';')])
        df['pred_list'] = df['pred_list_prob'].apply(lambda x: [i[0] for i in x])
    else:
        df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))

    # Predicted text that is not a known GO label is replaced by the most
    # similar label.
    t0 = time.time()
    txt_dict = {}
    all_txt = []
    for txt in df['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))
    if '' in all_txt:
        all_txt.remove('')

    n = len(all_txt)
    thread = 10
    # Keep at least one phrase per chunk.  With fewer phrases than workers the
    # previous int(n / thread) was 0 and range() raised on the zero step.
    size = max(1, n // thread)
    all_txt_sep = [all_txt[i:i + size] for i in range(0, n, size)]
    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    if with_prob:
        df['pred_list_prob'] = df['pred_list_prob'].apply(lambda x: txt_map_prob(x, txt_dict))
        print("fuzzy matching time: {}".format(time.time() - t0))
        df['pred_list_go_prob'] = df['pred_list_prob'].apply(lambda x: go_map_prob(x, GO_dict))
        go_col = 'pred_list_go_prob'
    else:
        df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
        df['pred_list'] = df['pred_list'].apply(lambda x: [i.lower() for i in list(set(x))])
        print("fuzzy matching time: {}".format(time.time() - t0))
        df['pred_list_go'] = df['pred_list'].apply(lambda x: go_map(x, GO_dict))
        go_col = 'pred_list_go'

    # Both modes share the same filtering: drop rows with no mapped GO id.
    n0 = df.shape[0]
    df['len'] = df[go_col].apply(lambda x: len(x))
    df = df[df['len'] > 0]
    df = df.drop('len', axis=1)
    df = df.dropna()
    print('{}条数据,不为空的预测有{}条'.format(n0, df.shape[0]))
    return df
|
162 |
+
|
163 |
+
|
164 |
+
def cal_f1(df):
    """Print label-averaged recall, precision and F1 for exact-match predictions.

    ``df['label']`` and ``df['pred_list']`` are ``;``-separated strings.  They
    are split into the columns ``label_list_go`` and ``pred_list_go``, which
    are added to ``df`` in place.  Recall is summed over the distinct labels
    and divided by their number.  Precision is summed over the same labels and
    divided by the number of distinct predicted values.  Nothing is returned;
    the figures are printed.
    """
    df['label_list_go'] = df['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    df['pred_list_go'] = df['pred_list'].apply(lambda x: [i.strip() for i in x.split(';')])

    labels = []
    pred_labels = []
    for l in df['label_list_go']:
        labels.extend(l)

    labels = list(set(labels))
    total = len(labels)
    tp_dict = dict.fromkeys(labels, 0)
    fp_dict = dict.fromkeys(labels, 0)
    fn_dict = dict.fromkeys(labels, 0)
    for preds, label in zip(df['pred_list_go'], df['label_list_go']):
        for t in label:
            if t in preds:
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            if p not in label:
                # Predictions never seen as a label get their own entry.
                fp_dict[p] = fp_dict.get(p, 0) + 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))
    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    # Guard each ratio: an empty frame or a run without a single hit used to
    # end in ZeroDivisionError instead of reporting 0.
    r = recall / total if total else 0.
    p = pr / p_total if p_total else 0.
    f1 = 2 * p * r / (p + r) if (p + r) else 0.

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("recall:{}; percision:{}; f1 score: {}".format(r, p, f1))
|
212 |
+
|
213 |
+
|
214 |
+
def cat_go(x):
    """Return ``'mf'``, ``'bp'`` or ``'cc'`` for the GO id ``x``.

    The namespace comes from ``godb.get_namespace`` and is compared with the
    module-level ``NAMESPACES`` table.  Returns ``None`` when the lookup fails
    (a message is printed) or when the namespace matches none of the three.
    """
    try:
        namespace = godb.get_namespace(x)
    except Exception:
        # Not a bare `except:`, so KeyboardInterrupt and SystemExit still
        # propagate instead of being reported as a missing term.
        print("{} not found".format(x))
        return None
    for short in ('mf', 'bp', 'cc'):
        if namespace == NAMESPACES[short]:
            return short
    return None
|
227 |
+
|
228 |
+
|
229 |
+
def remove_root(x):
    """Remove the three root namespace names from ``x`` in place and return it.

    Names that are absent are ignored.  ``remove`` is called once per name.
    """
    for root_name in ('molecular_function', 'biological_process', 'cellular_component'):
        if root_name in x:
            x.remove(root_name)
    return x
|
237 |
+
|
238 |
+
if __name__ == "__main__":
|
239 |
+
NAMESPACES = {
|
240 |
+
'cc': 'cellular_component',
|
241 |
+
'mf': 'molecular_function',
|
242 |
+
'bp': 'biological_process'
|
243 |
+
}
|
244 |
+
#if not os.path.exists('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl'):
|
245 |
+
if 1==1:
|
246 |
+
data = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/swissprot_domain_and_train_exp_prompt_new.csv', sep='|')
|
247 |
+
print('数据规模:{}'.format(data.shape[0]))
|
248 |
+
# data['function'] = data['function'].apply(lambda x: re.sub('[FPC]:', '', x))
|
249 |
+
# data.to_csv('swissprot_domain_and_train_exp.csv', sep='|', index=False)
|
250 |
+
|
251 |
+
godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
|
252 |
+
go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions1.4.txt', sep='|', header=None)
|
253 |
+
go_des.columns = ['id', 'text']
|
254 |
+
go_des = go_des.dropna()
|
255 |
+
go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
|
256 |
+
go_des['ont'] = go_des['id'].apply(lambda x: cat_go(x))
|
257 |
+
go_des = go_des.dropna()
|
258 |
+
go_obo_set = set(go_des['id'].tolist())
|
259 |
+
go_des['text'] = go_des['text'].apply(lambda x: x.lower())
|
260 |
+
|
261 |
+
data['GO_label'] = data['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
|
262 |
+
data = prop(data)
|
263 |
+
|
264 |
+
# 加入父节点,得到完整的terms,映射表等等
|
265 |
+
go_dict = {}
|
266 |
+
for x_list in data['prop_annotations']:
|
267 |
+
for goid in x_list:
|
268 |
+
if goid in go_dict:
|
269 |
+
go_dict[goid] += 1
|
270 |
+
else:
|
271 |
+
go_dict[goid] = 1
|
272 |
+
df_stat = pd.DataFrame({'id': list(go_dict.keys()), 'count': list(go_dict.values())})
|
273 |
+
data_gos = set(df_stat['id'].tolist())
|
274 |
+
go_des = go_des[go_des['id'].isin(data_gos)]
|
275 |
+
filter_go = data_gos.intersection(go_obo_set)
|
276 |
+
print(f"包括父节点的GO有{len(data_gos)}个,其中在go1.4.obo中出现的GO有{len(filter_go)}个")
|
277 |
+
|
278 |
+
go_des.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
|
279 |
+
id2text_dict = dict(zip(go_des['id'], go_des['text']))
|
280 |
+
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
281 |
+
|
282 |
+
choices_mf = list(set(go_des[go_des['ont'] == 'mf']['text']))
|
283 |
+
choices_bp = list(set(go_des[go_des['ont'] == 'bp']['text']))
|
284 |
+
choices_cc = list(set(go_des[go_des['ont'] == 'cc']['text']))
|
285 |
+
|
286 |
+
choices_mf = {x.lower(): x for x in choices_mf}
|
287 |
+
choices_bp = {x.lower(): x for x in choices_bp}
|
288 |
+
choices_cc = {x.lower(): x for x in choices_cc}
|
289 |
+
|
290 |
+
data['GO_label'] = data['GO_label'].apply(lambda x: filter(x))
|
291 |
+
data = data[data['GO_label'] != '']
|
292 |
+
data['function'] = data['GO_label'].apply(lambda x: [id2text_dict[i.strip()] for i in x.split(';')])
|
293 |
+
data['function'] = data['function'].apply(lambda x: '; '.join(x))
|
294 |
+
|
295 |
+
terms = pd.DataFrame({'gos': list(filter_go)})
|
296 |
+
terms.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
|
297 |
+
terms.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/terms.pkl')
|
298 |
+
|
299 |
+
terms_mf = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'mf']['id']))})
|
300 |
+
terms_mf.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/terms.pkl')
|
301 |
+
terms_mf.to_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
|
302 |
+
terms_bp = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'bp']['id']))})
|
303 |
+
terms_bp.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/bp/terms.pkl')
|
304 |
+
terms_bp.to_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
|
305 |
+
terms_cc = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'cc']['id']))})
|
306 |
+
terms_cc.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/cc/terms.pkl')
|
307 |
+
terms_cc.to_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')
|
308 |
+
else:
|
309 |
+
godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
|
310 |
+
terms = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
|
311 |
+
filter_go = set(terms['gos'].tolist())
|
312 |
+
|
313 |
+
terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
|
314 |
+
terms_bp = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
|
315 |
+
terms_cc = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')
|
316 |
+
|
317 |
+
choices_mf = {x.lower(): x for x in terms_mf['gos'].tolist()}
|
318 |
+
choices_bp = {x.lower(): x for x in terms_bp['gos'].tolist()}
|
319 |
+
choices_cc = {x.lower(): x for x in terms_cc['gos'].tolist()}
|
320 |
+
|
321 |
+
go_des = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
|
322 |
+
id2text_dict = dict(zip(go_des['id'], go_des['text']))
|
323 |
+
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
324 |
+
|
325 |
+
# 对于预测文件,进行GO筛选,并用相似度算法匹配到filter_go;对于train test val 文件,进行GO筛选、加入祖先、加入interPro特征
|
326 |
+
# 加入interpro特征
|
327 |
+
df_interpro = pd.read_csv('/cluster/home/wenkai/LAVIS/data/uniprot_sprot_blip2_func_data.txt', sep='|',
|
328 |
+
nrows=546389,
|
329 |
+
header=None)
|
330 |
+
df_interpro.columns = ['name', 'seq', 'go', 'text', 'evi', 'ipr']
|
331 |
+
df_interpro = df_interpro[df_interpro['ipr'].notnull()]
|
332 |
+
df_interpro['ipr'] = df_interpro['ipr'].apply(lambda x: [i.strip() for i in x.split(';')])
|
333 |
+
|
334 |
+
iprs = []
|
335 |
+
for x in df_interpro['ipr'].tolist():
|
336 |
+
if len(x) > 0:
|
337 |
+
iprs.extend(x)
|
338 |
+
iprs = list(set(iprs))
|
339 |
+
print("ipr个数:{}".format(len(iprs)))
|
340 |
+
df_ipr = pd.DataFrame({'interpros': iprs})
|
341 |
+
df_ipr.to_pickle('/cluster/home/wenkai/LAVIS/data/interpros.pkl')
|
342 |
+
df_ipr.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/interpros.pkl')
|
343 |
+
|
344 |
+
|
345 |
+
'''
|
346 |
+
# test cases
|
347 |
+
df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test_2000.csv', sep='|')
|
348 |
+
df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
|
349 |
+
#df_real[col] = df_real[col].apply(lambda x: filter(x))
|
350 |
+
df_real = df_real[df_real[col] != '']
|
351 |
+
print(df_real.shape)
|
352 |
+
#df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
|
353 |
+
#df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
|
354 |
+
df_real = prop(df_real)
|
355 |
+
#df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
|
356 |
+
#df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
|
357 |
+
#df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
|
358 |
+
for ont in ['mf', 'bp', 'cc']:
|
359 |
+
file_name = 'output_{}_test_2000'.format(ont)
|
360 |
+
if ont == 'mf':
|
361 |
+
choices = choices_mf
|
362 |
+
elif ont == 'bp':
|
363 |
+
choices = choices_bp
|
364 |
+
elif ont == 'cc':
|
365 |
+
choices = choices_cc
|
366 |
+
print("对{}预测文本进行标准化...".format(file_name))
|
367 |
+
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file_name), sep='|', header=None, on_bad_lines='skip')
|
368 |
+
df_pred.columns = ['name', 'pred', 'label']
|
369 |
+
n0 = df_pred.shape[0]
|
370 |
+
df_pred = pred_text_to_go(df_pred, with_prob=True)
|
371 |
+
print("{}中有{}条数据未能找到相似度高的GO描述".format(file_name, n0-df_pred.shape[0]))
|
372 |
+
#df_pred['pred_list'] = df_pred['pred_list'].apply(lambda x: '; '.join(x))
|
373 |
+
#cal_f1(df_pred)
|
374 |
+
df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file_name), sep='|', index=False)
|
375 |
+
|
376 |
+
df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
|
377 |
+
df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
|
378 |
+
ipr_and_pred = []
|
379 |
+
for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
|
380 |
+
try:
|
381 |
+
ipr_and_pred.append(x + y)
|
382 |
+
except:
|
383 |
+
ipr_and_pred.append(y)
|
384 |
+
df_pred['ipr_and_pred'] = ipr_and_pred
|
385 |
+
print(df_real.isnull().sum())
|
386 |
+
df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
|
387 |
+
#df_pred = df_pred.dropna()
|
388 |
+
print(df_pred.shape)
|
389 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
390 |
+
'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{}/test_2000_data.pkl'.format(ont))
|
391 |
+
'''
|
392 |
+
|
393 |
+
'''
|
394 |
+
df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/nextprot_mf.csv', sep='|')
|
395 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
|
396 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
|
397 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
|
398 |
+
df_real = prop(df_real)
|
399 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
|
400 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
|
401 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
|
402 |
+
|
403 |
+
file = 'output_nextprot'
|
404 |
+
choices = choices_mf
|
405 |
+
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file), sep='|', header=None, on_bad_lines='skip')
|
406 |
+
df_pred.columns = ['name', 'pred', 'label']
|
407 |
+
df_pred = pred_text_to_go(df_pred, with_prob=True)
|
408 |
+
df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file), sep='|', index=False)
|
409 |
+
|
410 |
+
df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
|
411 |
+
df_pred['ipr'] = [[] for _ in range(df_pred.shape[0])]
|
412 |
+
df_pred['ipr_and_pred'] = df_pred['pred_list_go_prob']
|
413 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
414 |
+
'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/nextprot_data.pkl')
|
415 |
+
'''
|
416 |
+
# '''
|
417 |
+
cat_id = {'mf': '445772', 'bp': '496359', 'cc': '505955'}
|
418 |
+
col = 'GO_label'
|
419 |
+
for ont in ['mf', 'bp', 'cc']:
|
420 |
+
#for ont in ['mf']:
|
421 |
+
if ont == 'mf':
|
422 |
+
choices = choices_mf
|
423 |
+
elif ont == 'bp':
|
424 |
+
choices = choices_bp
|
425 |
+
elif ont == 'cc':
|
426 |
+
choices = choices_cc
|
427 |
+
for split in ['train', 'val', 'test']:
|
428 |
+
#for split in ['test']:
|
429 |
+
df_real = pd.read_csv(f'/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/{split}_exp_{ont}_new.csv',
|
430 |
+
sep='|')
|
431 |
+
df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
|
432 |
+
df_real[col] = df_real[col].apply(lambda x: filter(x))
|
433 |
+
df_real = df_real[df_real[col] != '']
|
434 |
+
print(df_real.shape)
|
435 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
|
436 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
|
437 |
+
df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
|
438 |
+
df_real = prop(df_real)
|
439 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
|
440 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
|
441 |
+
df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
|
442 |
+
|
443 |
+
# Convert predicted text to GO terms
|
444 |
+
df_pred = pd.read_csv(
|
445 |
+
f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_exp_{cat_id[ont]}.txt', sep='|',
|
446 |
+
header=None, on_bad_lines='skip')
|
447 |
+
df_pred.columns = ['name', 'pred', 'label']
|
448 |
+
n0 = df_pred.shape[0]
|
449 |
+
df_pred = pred_text_to_go(df_pred, with_prob=True)
|
450 |
+
print("{}中有{}条数据未能找到相似度高的GO描述".format(ont, n0 - df_pred.shape[0]))
|
451 |
+
df_pred[['name', 'pred_list_prob', 'label']].to_csv(
|
452 |
+
f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_{cat_id[ont]}_standard.csv', sep='|',
|
453 |
+
index=False)
|
454 |
+
|
455 |
+
df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
|
456 |
+
df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
|
457 |
+
ipr_and_pred = []
|
458 |
+
for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
|
459 |
+
try:
|
460 |
+
ipr_and_pred.append(x + y)
|
461 |
+
except:
|
462 |
+
ipr_and_pred.append(y)
|
463 |
+
df_pred['ipr_and_pred'] = ipr_and_pred
|
464 |
+
|
465 |
+
df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
|
466 |
+
df_pred = df_pred.dropna()
|
467 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
468 |
+
f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/{split}_data_{cat_id[ont]}.pkl')
|
469 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
470 |
+
f'/cluster/home/wenkai/deepgo2/data/{ont}/{split}_data_{cat_id[ont]}.pkl')
|
471 |
+
if split == 'val':
|
472 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
473 |
+
f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/valid_data_{cat_id[ont]}.pkl')
|
474 |
+
df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
|
475 |
+
f'/cluster/home/wenkai/deepgo2/data/{ont}/valid_data_{cat_id[ont]}.pkl')
|
476 |
+
print(f"{ont} {split} deepgozero propagation data completed")
|
477 |
+
# '''
|
data/evaluate_data/process_case.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from utils import Ontology
|
3 |
+
|
4 |
+
|
5 |
+
def prop(df):
    """Add a ``prop_annotations`` column with each row's propagated GO terms.

    For every row, the GO ids listed in ``GO_label`` are expanded through
    ``godb.get_anchestors`` and the union of the results is stored as a list.
    The frame is modified in place and also returned.
    """

    def _closure(go_ids):
        # Union of the ancestor sets of all labels on one row.
        collected = set()
        for term in go_ids:
            collected |= godb.get_anchestors(term)
        return list(collected)

    df['prop_annotations'] = [_closure(labels) for labels in df['GO_label']]
    return df
|
17 |
+
|
18 |
+
# Ontology graph used by prop(). with_rels=True makes the loader also record
# the targets of `relationship:` lines as parents, so propagation follows them.
godb = Ontology('/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)

# Columns written to the pickled case files.
_CASE_COLUMNS = ['name', 'protein', 'function', 'GO_label', 'id', 'prompt', 'prop_annotations']


def _propagate_cases(cases):
    """Split ``GO_label``, propagate it with prop(), and re-join both columns.

    ``GO_label`` arrives as a ';'-separated string; after the call both
    ``GO_label`` and ``prop_annotations`` are '; '-joined strings.
    """
    cases['GO_label'] = cases['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
    cases = prop(cases)
    cases['GO_label'] = cases['GO_label'].apply(lambda x: '; '.join(x))
    cases['prop_annotations'] = cases['prop_annotations'].apply(lambda x: '; '.join(x))
    return cases


case_mf = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_mf.csv', sep='|')

# BP cases, including the capsaicin receptor
case_bp = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_bp.csv', sep='|')
case_bp = _propagate_cases(case_bp)
case_bp[_CASE_COLUMNS].to_pickle('/cluster/home/wenkai/deepgo2/data/bp/cases_data.pkl')

# MF cases. case_bp is deliberately not touched again here: it has already
# been written out, so re-propagating it would be wasted work.
case_mf = _propagate_cases(case_mf)
case_mf[_CASE_COLUMNS].to_pickle('/cluster/home/wenkai/deepgo2/data/mf/cases_data_445772.pkl')
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
data/evaluate_data/utils.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import deque, Counter
|
2 |
+
import warnings
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from xml.etree import ElementTree as ET
|
6 |
+
import math
|
7 |
+
|
8 |
+
# Root term ids of the three GO sub-ontologies.
BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'

# Short ontology code -> root term id.
FUNC_DICT = {
    'cc': CELLULAR_COMPONENT,
    'mf': MOLECULAR_FUNCTION,
    'bp': BIOLOGICAL_PROCESS,
}

# Short ontology code -> OBO namespace string.
NAMESPACES = {
    'cc': 'cellular_component',
    'mf': 'molecular_function',
    'bp': 'biological_process',
}

# Evidence codes accepted by is_exp_code().
EXP_CODES = {
    'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC',
    'HTP', 'HDA', 'HMP', 'HGI', 'HEP',
}

# CAFA4 Targets (string ids; presumably NCBI taxonomy ids -- TODO confirm)
CAFA_TARGETS = {
    '287', '3702', '4577', '6239', '7227', '7955', '9606', '9823', '10090',
    '10116', '44689', '83333', '99287', '226900', '243273', '284812', '559292',
}


def is_cafa_target(org):
    """Return True when ``org`` is one of the ids in ``CAFA_TARGETS``."""
    return org in CAFA_TARGETS


def is_exp_code(code):
    """Return True when ``code`` is one of the evidence codes in ``EXP_CODES``."""
    return code in EXP_CODES
|
38 |
+
|
39 |
+
|
40 |
+
def get_goplus_defs(filename='data/definitions.txt'):
    """Read ``<id>: <definition>`` lines into ``{id: set(definition parts)}``.

    Underscores in both the id and the definition are replaced by ':' and the
    definition is split on ' and ' into a set of parts.

    Blank lines are skipped. Only the first ': ' separates id from definition,
    so a definition that itself contains ': ' is kept whole. A non-blank line
    without ': ' still raises ``ValueError``.
    """
    plus_defs = {}
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            go_id, definition = line.split(': ', 1)
            go_id = go_id.replace('_', ':')
            definition = definition.replace('_', ':')
            plus_defs[go_id] = set(definition.split(' and '))
    return plus_defs
|
50 |
+
|
51 |
+
|
52 |
+
class Ontology(object):
    """In-memory view of the ``[Term]`` stanzas of an OBO ontology file.

    ``self.ont`` maps term id -> term dict. Each term dict carries the lists
    ``is_a``, ``part_of``, ``regulates`` and ``alt_ids``, the flag
    ``is_obsolete``, a ``children`` set, and ``id`` / ``name`` / ``namespace``
    as read from the file. Alternative ids map to the same dict object as the
    primary id; obsolete terms are removed.
    """

    def __init__(self, filename='data/go.obo', with_rels=False):
        """Load ``filename``; see :meth:`load` for the meaning of ``with_rels``."""
        self.ont = self.load(filename, with_rels)
        self.ic = None       # {term_id: information content}, set by calculate_ic()
        self.ic_norm = 0.0   # largest IC seen by the last calculate_ic() call

    def has_term(self, term_id):
        """Return True when ``term_id`` is a known (non-obsolete) id or alt id."""
        return term_id in self.ont

    def get_term(self, term_id):
        """Return the term dict for ``term_id``, or None when it is unknown."""
        if self.has_term(term_id):
            return self.ont[term_id]
        return None

    def calculate_ic(self, annots):
        """Compute information content from an iterable of annotation collections.

        Every term is counted once per collection it occurs in. A term's IC is
        ``log2(min_parent_count / own_count)``; terms without known parents use
        their own count as numerator and therefore get IC 0.
        """
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        # Start from a clean maximum so a recomputation on different
        # annotations does not keep a stale normaliser.
        self.ic_norm = 0.0
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min([cnt[x] for x in parents])

            self.ic[go_id] = math.log(min_n / n, 2)
            self.ic_norm = max(self.ic_norm, self.ic[go_id])

    def get_ic(self, go_id):
        """Return the IC of ``go_id`` (0.0 when the term was never counted).

        Raises ``Exception`` when :meth:`calculate_ic` has not been called.
        """
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    def get_norm_ic(self, go_id):
        """Return the IC of ``go_id`` divided by the largest IC.

        Returns 0.0 instead of dividing by zero when the largest IC is 0.
        """
        ic = self.get_ic(go_id)
        if self.ic_norm == 0.0:
            return 0.0
        return ic / self.ic_norm

    def load(self, filename, with_rels):
        """Parse ``filename`` and return the ``{term_id: term_dict}`` mapping.

        Only ``[Term]`` stanzas are read; lines before the first stanza and
        inside ``[Typedef]`` stanzas are ignored. When ``with_rels`` is true
        the target of every ``relationship:`` line is appended to ``is_a``, so
        all relationship types are treated as parent links.
        """
        ont = dict()
        obj = None
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line == '[Term]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = dict()
                    obj['is_a'] = list()
                    obj['part_of'] = list()
                    obj['regulates'] = list()
                    obj['alt_ids'] = list()
                    obj['is_obsolete'] = False
                    continue
                elif line == '[Typedef]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = None
                else:
                    if obj is None:
                        continue
                    # Split on the first ': ' only, so values that contain
                    # ': ' themselves (e.g. term names) are kept whole.
                    fields = line.split(": ", 1)
                    tag = fields[0]
                    if tag == 'id':
                        obj['id'] = fields[1]
                    elif tag == 'alt_id':
                        obj['alt_ids'].append(fields[1])
                    elif tag == 'namespace':
                        obj['namespace'] = fields[1]
                    elif tag == 'is_a':
                        obj['is_a'].append(fields[1].split(' ! ')[0])
                    elif with_rels and tag == 'relationship':
                        it = fields[1].split()
                        # add all types of relationships
                        obj['is_a'].append(it[1])
                    elif tag == 'name':
                        obj['name'] = fields[1]
                    elif tag == 'is_obsolete' and fields[1] == 'true':
                        obj['is_obsolete'] = True
            if obj is not None:
                ont[obj['id']] = obj
        # Register alt ids as aliases of the same dict, then drop obsolete terms.
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                del ont[term_id]
        # Build the reverse (parent -> children) links.
        for term_id, val in ont.items():
            if 'children' not in val:
                val['children'] = set()
            for p_id in val['is_a']:
                if p_id in ont:
                    if 'children' not in ont[p_id]:
                        ont[p_id]['children'] = set()
                    ont[p_id]['children'].add(term_id)

        return ont

    def get_anchestors(self, term_id):
        """Return ``term_id`` plus every term reachable through ``is_a`` links.

        Returns an empty set for an unknown term. Parents missing from the
        ontology are skipped.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    # Correctly spelled alias; the original name is kept for existing callers.
    get_ancestors = get_anchestors

    def get_prop_terms(self, terms):
        """Return the union of :meth:`get_anchestors` over all ``terms``."""
        prop_terms = set()

        for term_id in terms:
            prop_terms |= self.get_anchestors(term_id)
        return prop_terms

    def get_parents(self, term_id):
        """Return the direct parents of ``term_id`` that exist in the ontology."""
        if term_id not in self.ont:
            return set()
        term_set = set()
        for parent_id in self.ont[term_id]['is_a']:
            if parent_id in self.ont:
                term_set.add(parent_id)
        return term_set

    def get_namespace_terms(self, namespace):
        """Return the ids (alt ids included) whose namespace equals ``namespace``."""
        terms = set()
        for go_id, obj in self.ont.items():
            if obj['namespace'] == namespace:
                terms.add(go_id)
        return terms

    def get_namespace(self, term_id):
        """Return the namespace of ``term_id``; raises ``KeyError`` if unknown."""
        return self.ont[term_id]['namespace']

    def get_term_set(self, term_id):
        """Return ``term_id`` plus every term reachable through ``children`` links.

        Returns an empty set for an unknown term.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set
|
206 |
+
|
207 |
+
|
208 |
+
def read_fasta(filename):
    """Read a FASTA file and return ``(info, seqs)`` as two parallel lists.

    ``info[i]`` is the first whitespace-delimited token of the i-th header
    line (without the leading '>'), and ``seqs[i]`` is the concatenation of
    the stripped sequence lines that follow it.

    A header that is followed directly by another header (no sequence lines)
    is dropped, while a trailing header without sequence is kept with an empty
    sequence. A file with no records returns two empty lists, and a bare '>'
    header yields an empty name instead of raising ``IndexError``.
    """
    seqs = list()
    info = list()
    chunks = []   # sequence lines of the record currently being read
    inf = ''
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if chunks:
                    seqs.append(''.join(chunks))
                    info.append(inf)
                    chunks = []
                header_fields = line[1:].split()
                inf = header_fields[0] if header_fields else ''
            elif line:
                chunks.append(line)
    # Flush the last record, but do not invent one for an empty file.
    if chunks or inf:
        seqs.append(''.join(chunks))
        info.append(inf)
    return info, seqs
|
227 |
+
|
228 |
+
|
229 |
+
class DataGenerator(object):
    """Endless mini-batch iterator over row-indexable arrays.

    Call :meth:`fit` with the inputs (a single array or a list/tuple of
    arrays) and optional targets, then call ``next()`` repeatedly. After the
    last batch the generator wraps around to the first row. With
    ``is_sparse=True`` every sliced batch is densified through ``.toarray()``.
    """

    def __init__(self, batch_size, is_sparse=False):
        self.batch_size = batch_size
        self.is_sparse = is_sparse

    def fit(self, inputs, targets=None):
        """Attach the data to iterate over and rewind to the first row."""
        self.start = 0
        self.inputs = inputs
        self.targets = targets
        if isinstance(self.inputs, (tuple, list)):
            # Row count is taken from the first input array.
            self.size = self.inputs[0].shape[0]
        else:
            self.size = self.inputs.shape[0]
        self.has_targets = targets is not None

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind to the first row."""
        self.start = 0

    def _slice(self, data, batch_index):
        """Select the batch rows from ``data``, densifying when ``is_sparse``."""
        batch = data[batch_index, :]
        return batch.toarray() if self.is_sparse else batch

    def next(self):
        """Return the next batch, wrapping around at the end of the data.

        Returns ``(inputs, labels)`` when targets were given to :meth:`fit`,
        otherwise just the inputs. Multi-array inputs come back as a list.

        Raises ``StopIteration`` when the fitted data has zero rows; without
        this guard the wrap-around would never find a batch to return.
        """
        if self.size <= 0:
            raise StopIteration
        if self.start >= self.size:
            self.reset()
        batch_index = np.arange(
            self.start, min(self.size, self.start + self.batch_size))
        if isinstance(self.inputs, (tuple, list)):
            res_inputs = [self._slice(inp, batch_index) for inp in self.inputs]
        else:
            res_inputs = self._slice(self.inputs, batch_index)
        self.start += self.batch_size
        if self.has_targets:
            labels = self._slice(self.targets, batch_index)
            return (res_inputs, labels)
        return res_inputs
|
279 |
+
|
280 |
+
|
data/fasta/example.fasta
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
>P18281
|
2 |
+
MNPELQSAIGQGAALKHAETVDKSAPQIENVTVKKVDRSSFLEEVAKPHELKHAETVDKSGPAIPEDVHVKKVDRGAFLSEIEKAAKQ
|
data/fasta/prepare_custom_fasta.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Write the example proteins below to example.fasta, one
# ">name" header line followed by one sequence line per protein.
name_list = ['P18281']
sequence_list = ['MNPELQSAIGQGAALKHAETVDKSAPQIENVTVKKVDRSSFLEEVAKPHELKHAETVDKSGPAIPEDVHVKKVDRGAFLSEIEKAAKQ']
with open('example.fasta', 'w') as f:
    for i, j in zip(name_list, sequence_list):
        # Header and sequence are emitted together as a single record.
        f.write('>{}\n{}\n'.format(i, j))
|
data/go1.4-basic.obo
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3da20cc774d666b4338446bc81341eaf536885dc10ccb667480a79f6b964aa3c
|
3 |
+
size 31134256
|
data/go_descriptions1.4.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/test_exp_prompt_bp_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/test_exp_prompt_cc_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/test_exp_prompt_mf_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/train_exp_prompt_bp_new.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:12359211ab95f1ce1962b69f033b55e9f502a7527f49414792d1c117ec50b0be
|
3 |
+
size 28503657
|
data/swissprot_exp/train_exp_prompt_cc_new.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:01c6144b0e338d3ce8ce98adfd4f9d09f56dc58cd347f4fbaafb6782d694ffd1
|
3 |
+
size 23292609
|
data/swissprot_exp/train_exp_prompt_mf_new.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca3eee941dfc0ee37f59adec6abf8a7276441f04c484a9275274c7003ef4145e
|
3 |
+
size 18791760
|
data/swissprot_exp/val_exp_prompt_bp_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/val_exp_prompt_cc_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/swissprot_exp/val_exp_prompt_mf_new.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/terms/bp_terms.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4952f3551e4fe205640b81f9a1816c15c14cc889bbe55f57d378fb3c6d57f2f7
|
3 |
+
size 274892
|
data/terms/cc_terms.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:20992c211336c4f876c920c2995ae85c1422e8742b7094c997aa70ddec7fc8fd
|
3 |
+
size 39440
|
data/terms/mf_terms.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:192861bad821ef3523ab2dcdd1db5eac093364e9b9b4869f75587d656864d29b
|
3 |
+
size 107802
|