File size: 4,526 Bytes
8503206 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from utils import call, permutate
import os
import json
import pandas as pd
import re
# Load variables from a .env file into the process environment at import time;
# Validation.__init__ reads PERPLEXITY_API_KEY from os.environ.
load_dotenv()
class Validation():
    """Clean and validate a table of gene / SNP / disease rows.

    The table is a pandas DataFrame with at least the columns ``Genes``,
    ``SNPs`` and ``Diseases``.  Validation runs in three stages:

    1. rule-based clean-up (gene splitting, SNP id normalisation),
    2. optional cross-check of each gene against two public SNP web APIs,
    3. a final pass through a chat LLM in batches of rows.
    """

    # Separators that may join several gene names in one cell.  They are
    # tried in this order and only the first one found is used for a split.
    _GENE_SEPARATORS = (',', '/', '|', '-')

    # Number of rows sent to the LLM per request.
    _LLM_BATCH_SIZE = 50

    def __init__(self, llm):
        """Create the chat client used by :meth:`validate`.

        Args:
            llm: Model name.  Names starting with ``gpt`` use ``ChatOpenAI``,
                names starting with ``gemini`` use ``ChatGoogleGenerativeAI``;
                anything else is sent to the Perplexity endpoint through
                ``ChatOpenAI``.

        Raises:
            KeyError: If the Perplexity branch is taken and
                ``PERPLEXITY_API_KEY`` is not set in the environment.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )

    @classmethod
    def _expand_gene(cls, gene):
        """Return ``gene`` followed by the pieces obtained by splitting it.

        The original value is kept as the first candidate (a separator such
        as ``-`` can be part of a real name).  The text before the first
        separator and the remainder after it are expanded recursively.
        """
        expanded = [gene]
        for sep in cls._GENE_SEPARATORS:
            if sep in gene:
                first, _, rest = gene.partition(sep)
                expanded.extend(cls._expand_gene(first))
                expanded.extend(cls._expand_gene(rest))
                break
        return expanded

    @classmethod
    def _split_genes(cls, df):
        """Duplicate rows so every gene candidate gets a row of its own.

        Rows are selected by position, so every input row is examined no
        matter how many rows earlier splits have added.
        """
        positions = []
        genes = []
        for pos, gene in enumerate(df['Genes']):
            for candidate in cls._expand_gene(gene):
                positions.append(pos)
                genes.append(candidate)
        df = df.iloc[positions].reset_index(drop=True)
        df['Genes'] = genes
        return df

    @staticmethod
    def _normalize_snp(snp):
        """Return ``snp`` repaired to the ``rs<digits>`` form, or ``None``.

        A lowercase ``l`` is read as the digit ``1``.  A ``ts`` prefix, a
        bare ``s`` prefix and a missing prefix are repaired.  The empty
        string is accepted unchanged; any other shape yields ``None``.
        """
        snp = snp.replace('l', '1')
        if re.fullmatch(r'rs\d+|', snp):
            return snp
        if re.fullmatch(r'ts\d+', snp):
            return 'r' + snp[1:]
        if re.fullmatch(r's\d+', snp):
            return 'r' + snp
        if re.fullmatch(r'\d+', snp):
            return 'rs' + snp
        return None

    @staticmethod
    def _lookup_snp_genes(snp):
        """Collect the gene names both web APIs report for ``snp``.

        A failure while reading either response is printed and treated as
        "no genes from that source" so one bad record does not abort the run.
        """
        genes = set()
        res = call(f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/')
        try:
            res = res.json()
            genes.update(r['gene']['geneName'] for r in res['genomicContexts'])
        except Exception as e:
            print("Error at first API", e)
        try:
            res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
            if 'error' not in res:
                genes.update(r['name'] for r in res['genes'])
        except Exception as e:
            print("Error at second API", e)
        return genes

    def _validate_with_apis(self, df):
        """Keep only rows whose gene is reported for the row's SNP.

        When the gene itself is not reported, the alternatives produced by
        ``permutate(gene)`` are tried and the first reported one replaces it;
        rows with no reported alternative are dropped.
        """
        genes_by_snp = {}  # one lookup per distinct SNP
        unmatched = []
        for i in df.index:
            snp = df.loc[i, 'SNPs']
            gene = df.loc[i, 'Genes']
            if snp not in genes_by_snp:
                genes_by_snp[snp] = self._lookup_snp_genes(snp)
            known = genes_by_snp[snp]
            if gene in known:
                continue
            for other in permutate(gene):
                if other in known:
                    df.loc[i, 'Genes'] = other
                    print(f'{gene} corrected to {other}')
                    break
            else:
                unmatched.append(i)
        return df.drop(unmatched).reset_index(drop=True)

    @staticmethod
    def _parse_llm_rows(text):
        """Extract the list of records from an LLM reply.

        The text between the first ``[`` and the last ``]`` is parsed as
        JSON, falling back to a Python literal (single-quoted dicts).
        Unparseable replies give an empty list.

        SECURITY: the reply is untrusted model output, so it is never passed
        to ``eval``; only data literals are accepted.
        """
        import ast  # local import: only needed for the literal fallback

        payload = text[text.find('['):text.rfind(']') + 1]
        try:
            rows = json.loads(payload)
        except ValueError:
            try:
                rows = ast.literal_eval(payload)
            except (ValueError, SyntaxError, TypeError):
                return []
        return rows if isinstance(rows, list) else []

    def validate(self, df, api):
        """Run the full validation pipeline on ``df``.

        Args:
            df: DataFrame with ``Genes``, ``SNPs`` and ``Diseases`` columns;
                any further columns are carried along.
            api: When truthy, genes are cross-checked against the SNP web
                APIs before the LLM pass.

        Returns:
            A tuple ``(df, df_no_llm, df_clean)``: the LLM-validated rows
            (with the non-key columns of the first pre-LLM row attached to
            every row), the rows before the LLM pass, and the rows after the
            rule-based clean-up only.
        """
        df = df.fillna('')
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['SNPs'] = df['SNPs'].str.lower()

        # Check if there is two gene names
        df = self._split_genes(df)

        # Check if there is SNPs without 'rs'; unrepairable ids are dropped
        df['SNPs'] = df['SNPs'].map(self._normalize_snp)
        df = df.dropna(subset=['SNPs']).reset_index(drop=True)
        df_clean = df.copy()

        # Validate genes and SNPs with APIs
        if api:
            df = self._validate_with_apis(df)
        df_no_llm = df.copy()

        # Validate genes and diseases with LLM (for each 50 rows)
        results = []
        for idx in range(0, len(df), self._LLM_BATCH_SIZE):
            batch = df[['Genes', 'SNPs', 'Diseases']][idx:idx + self._LLM_BATCH_SIZE]
            json_table = batch.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)
            result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
            print('val', idx)
            print(result)
            results.extend(self._parse_llm_rows(result))

        df = pd.DataFrame(results)
        # Re-attach the remaining columns, taken from the first pre-LLM row.
        df = df.merge(df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
        return df, df_no_llm, df_clean