fadliaulawi committed on
Commit
8503206
1 Parent(s): 9c8e6da

Separate validation code

Browse files
Files changed (2) hide show
  1. utils.py +30 -0
  2. validate.py +130 -0
utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import time
3
+
4
# Characters that are commonly misread in gene names, mapped to the character
# they are usually confused with. This table has to be maintained by hand as
# new confusions are observed.
mistakes = {'I': '1', 'O': '0'}


def permutate(word):
    """Return every spelling of *word* obtained by applying `mistakes`.

    Each character that is a key of `mistakes` may either be kept or replaced
    by its mapped value, independently of the other characters. The unchanged
    word is always the first element; variants that substitute an earlier
    character come after all variants that keep it.

    Note that the result doubles in size for every substitutable character.

    Args:
        word: The string to expand.

    Returns:
        A list of strings; `['']` for an empty word.
    """
    # Build the variants from the last character backwards so the loop needs
    # neither recursion (which hits the interpreter's recursion limit on long
    # input) nor repeated slicing of the word.
    variants = ['']
    for char in reversed(word):
        extended = [char + tail for tail in variants]
        if char in mistakes:
            extended += [mistakes[char] + tail for tail in variants]
        variants = extended

    return variants
19
+
20
def call(url, *, timeout=30, retry_delay=1, max_retries=None):
    """Send a GET request to *url*, retrying until a response is received.

    Args:
        url: The address to fetch.
        timeout: Seconds to wait for the server before the attempt is treated
            as failed. Without a timeout a stalled connection would block
            forever instead of being retried.
        retry_delay: Seconds to wait after a failed attempt before retrying,
            so that a failing server is not hit in a tight loop.
        max_retries: Maximum number of retries after the first failed attempt.
            `None` (the default) retries indefinitely.

    Returns:
        The `requests` response object of the first attempt that did not
        raise. The HTTP status code is not inspected.

    Raises:
        Exception: The last error raised by `requests.get`, only when
            `max_retries` is set and has been exhausted.
    """
    failures = 0

    while True:
        try:
            res = requests.get(url, timeout=timeout)
        except Exception as e:
            # Deliberately broad: with the default unlimited retries this
            # function never raises, and callers rely on that.
            print(e)
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            time.sleep(retry_delay)
        else:
            # Pause after every successful request to throttle the call rate.
            time.sleep(1)
            return res
validate.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ from langchain_openai import ChatOpenAI
4
+ from prompt import *
5
+ from utils import call, permutate
6
+
7
+ import os
8
+ import json
9
+ import pandas as pd
10
+ import re
11
+
12
+ load_dotenv()
13
+
14
class Validation():
    """Clean and validate a table of gene / SNP / disease associations.

    The table is cleaned with local rules, optionally checked against two
    public SNP databases, and finally passed to a chat LLM in batches.
    """

    # Characters that indicate several gene names were captured in one cell.
    # They are tried in this order; only the first one found is split on.
    _GENE_SEPARATORS = (',', '/', '|', '-')

    # Number of rows sent to the LLM per request.
    _LLM_BATCH_SIZE = 50

    def __init__(self, llm):
        """Create the chat model used by `validate`.

        Args:
            llm: Model name. Names starting with 'gpt' use `ChatOpenAI`,
                names starting with 'gemini' use `ChatGoogleGenerativeAI`,
                and anything else is sent through `ChatOpenAI` pointed at the
                Perplexity endpoint.

        Raises:
            KeyError: If the Perplexity branch is taken and the
                `PERPLEXITY_API_KEY` environment variable is not set.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")

    @classmethod
    def _expand_gene(cls, gene):
        """Return *gene* followed by the names obtained by splitting it.

        The cell is split once on the first separator (in
        `_GENE_SEPARATORS` order) it contains; the part before and the part
        after that separator are then expanded the same way. The combined
        name itself is kept as the first element.
        """
        for sep in cls._GENE_SEPARATORS:
            if sep in gene:
                head, _, tail = gene.partition(sep)
                return [gene] + cls._expand_gene(head) + cls._expand_gene(tail)
        return [gene]

    @classmethod
    def _split_multi_gene_rows(cls, df):
        """Duplicate every row whose 'Genes' cell holds several gene names.

        Each such row is followed by copies carrying the individual names
        (see `_expand_gene`). All rows are examined, including the ones that
        come after an expanded row. The result has a fresh 0..n-1 index.
        """
        expansions = [cls._expand_gene(gene) for gene in df['Genes']]
        # Repeat each source row once per name it expands to.
        positions = [pos for pos, names in enumerate(expansions) for _ in names]
        out = df.iloc[positions].reset_index(drop=True)
        out['Genes'] = [name for names in expansions for name in names]
        return out

    @staticmethod
    def _normalize_snp(snp):
        """Return *snp* as an 'rs<digits>' id, or None if it cannot be repaired.

        'l' is read as '1'. A leading 't' instead of 'r', a missing 'r' and a
        missing 'rs' prefix are repaired. The empty string is accepted
        unchanged.
        """
        snp = snp.replace('l', '1')
        if re.fullmatch(r'rs(\d)+|', snp):
            return snp
        if re.fullmatch(r'ts(\d)+', snp):
            return 'r' + snp[1:]
        if re.fullmatch(r's(\d)+', snp):
            return 'r' + snp
        if re.fullmatch(r'(\d)+', snp):
            return 'rs' + snp
        return None

    @classmethod
    def _clean_snps(cls, df):
        """Normalize the 'SNPs' column and drop rows that cannot be repaired."""
        normalized = [cls._normalize_snp(snp) for snp in df['SNPs']]
        keep = pd.Series([snp is not None for snp in normalized], index=df.index, dtype=bool)
        out = df.loc[keep].copy()
        out['SNPs'] = [snp for snp in normalized if snp is not None]
        return out.reset_index(drop=True)

    @staticmethod
    def _fetch_snp_genes(snp):
        """Return the set of gene names the two SNP services report for *snp*.

        Queries the EBI GWAS REST API and the NCBI E-utilities `esummary`
        endpoint. A failure of either lookup is printed and treated as "no
        genes from that source".
        """
        try:
            res = call(f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/').json()
            genes = [r['gene']['geneName'] for r in res['genomicContexts']]
        except Exception as e:
            print("Error at first API", e)
            genes = []

        try:
            # The NCBI endpoint expects the numeric id without the 'rs' prefix.
            res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
            if 'error' not in res:
                genes.extend([r['name'] for r in res['genes']])
        except Exception as e:
            print("Error at second API", e)

        return set(genes)

    def _validate_with_apis(self, df):
        """Keep only rows whose gene is reported for the row's SNP.

        If the gene itself is not reported but one of its `permutate`
        spellings is, the gene is corrected to that spelling. Lookups are
        cached per SNP for the duration of the call.
        """
        genes_by_snp = {}
        rejected = []

        for i in df.index:
            snp = df.loc[i, 'SNPs']
            gene = df.loc[i, 'Genes']

            if snp not in genes_by_snp:
                genes_by_snp[snp] = self._fetch_snp_genes(snp)
            candidates = genes_by_snp[snp]

            if gene in candidates:
                continue

            for other in permutate(gene):
                if other in candidates:
                    df.loc[i, 'Genes'] = other
                    print(f'{gene} corrected to {other}')
                    break
            else:
                rejected.append(i)

        return df.drop(rejected)

    @staticmethod
    def _parse_llm_rows(text):
        """Extract the outermost '[...]' list from an LLM reply.

        The reply is untrusted text, so it is parsed as JSON first and as a
        Python literal second; it is never evaluated as code. Anything that
        is not a list under either syntax yields an empty list.
        """
        import ast

        start, end = text.find('['), text.rfind(']')
        if start == -1 or end < start:
            return []
        snippet = text[start:end + 1]

        try:
            parsed = json.loads(snippet)
        except ValueError:
            try:
                parsed = ast.literal_eval(snippet)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return []

        return parsed if isinstance(parsed, list) else []

    def validate(self, df, api):
        """Clean *df* and validate it with the SNP services and the LLM.

        Args:
            df: DataFrame with at least the columns 'Genes', 'SNPs' and
                'Diseases'. It is not modified.
            api: If truthy, rows are additionally checked against the SNP
                web services (one or two HTTP requests per distinct SNP).

        Returns:
            A tuple `(df, df_no_llm, df_clean)`:
            `df_clean` is the table after local cleaning only, `df_no_llm`
            after the optional API check, and `df` is built from the rows the
            LLM returned, combined with the remaining columns of the first
            `df_no_llm` row.
        """
        df = df.fillna('')
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['SNPs'] = df['SNPs'].str.lower()

        # Check if a cell holds more than one gene name
        df = self._split_multi_gene_rows(df)

        # Repair SNP ids that are missing (part of) the 'rs' prefix
        df = self._clean_snps(df)
        df_clean = df.copy()

        # Validate genes and SNPs with APIs
        if api:
            df = self._validate_with_apis(df)

        df = df.reset_index(drop=True)
        df_no_llm = df.copy()

        # Validate genes and diseases with LLM, one batch of rows at a time.
        # The LLM is queried at least once, even for an empty table.
        idx = 0
        results = []

        while True:
            batch = df[['Genes', 'SNPs', 'Diseases']][idx:idx + self._LLM_BATCH_SIZE]
            str_json_table = json.dumps(json.loads(batch.to_json(orient='records')), indent=2)

            # NOTE(review): prompt_validation is presumably provided by the
            # star import from the prompt module - confirm.
            result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
            print('val', idx)
            print(result)

            results.extend(self._parse_llm_rows(result))
            idx += self._LLM_BATCH_SIZE
            if idx >= len(df):
                break

        df = pd.DataFrame(results)
        df = df.merge(df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')

        return df, df_no_llm, df_clean