File size: 4,526 Bytes
8503206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from utils import call, permutate

import os
import json
import pandas as pd
import re

# Load variables from a .env file into os.environ (python-dotenv) so that API
# keys read later, e.g. PERPLEXITY_API_KEY in Validation.__init__, are available
# before any model client is constructed.
load_dotenv()

class Validation:
    """Clean a gene/SNP/disease table and validate it with public APIs and an LLM.

    The table passed to :meth:`validate` must contain at least the columns
    'Genes', 'SNPs' and 'Diseases'.
    """

    # Separators that indicate several gene symbols packed into one 'Genes'
    # cell. They are tried in this order; the first one present is split on.
    GENE_SEPARATORS = (',', '/', '|', '-')

    # Number of table rows sent to the LLM in a single request.
    LLM_BATCH_SIZE = 50

    def __init__(self, llm):
        """Create the chat model used for LLM validation.

        Args:
            llm: Model name. Names starting with 'gpt' use ``ChatOpenAI``;
                names starting with 'gemini' use ``ChatGoogleGenerativeAI``;
                any other name goes through ``ChatOpenAI`` pointed at
                https://api.perplexity.ai with the PERPLEXITY_API_KEY
                environment variable (``KeyError`` if it is unset).
                Every model is created with temperature 0.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")

    def validate(self, df, api):
        """Clean *df* and validate its gene/SNP/disease rows.

        Steps:
          1. Blank missing cells; strip spaces from and upper-case 'Genes';
             lower-case 'SNPs'.
          2. Expand cells holding several genes (see ``_split_genes``).
          3. Normalize SNP ids to 'rs<digits>' (blank ids are kept) and drop
             rows whose id cannot be normalized.
          4. If *api* is truthy, keep only rows whose gene, or a variant from
             ``permutate``, is listed for the SNP by the GWAS Catalog or NCBI
             lookups.
          5. Send the remaining 'Genes'/'SNPs'/'Diseases' rows to the LLM in
             batches of ``LLM_BATCH_SIZE`` and collect the list it returns.

        Args:
            df: Input table; it is not modified.
            api: Whether to run the web-API check of step 4.

        Returns:
            ``(df_llm, df_no_llm, df_clean)``:

            * ``df_llm`` is a frame built from the items the LLM returned,
              cross-joined with the other columns of the first row of
              ``df_no_llm``.
            * ``df_no_llm`` is the table after step 4.
            * ``df_clean`` is the table after step 3.
        """
        df = df.fillna('')
        # astype(str): .str accessors yield NaN for non-string cells, which
        # later string operations would crash on.
        df['Genes'] = df['Genes'].astype(str).str.replace(' ', '', regex=False).str.upper()
        df['SNPs'] = df['SNPs'].astype(str).str.lower()

        df = self._split_genes(df)
        df = self._clean_snps(df)
        df_clean = df.copy()

        if api:
            df = self._filter_by_apis(df)
        df_no_llm = df.copy()

        # Validate genes and diseases with the LLM, one batch at a time. An
        # empty table makes no request at all.
        table = df[['Genes', 'SNPs', 'Diseases']]
        results = []
        for start in range(0, len(table), self.LLM_BATCH_SIZE):
            batch = table.iloc[start:start + self.LLM_BATCH_SIZE]
            json_table = batch.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)

            # NOTE(review): prompt_validation is presumably provided by
            # `from prompt import *` — confirm.
            result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
            print('val', start)
            print(result)

            results.extend(self._parse_llm_list(result))

        df = pd.DataFrame(results)
        df = df.merge(df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), how='cross')

        return df, df_no_llm, df_clean

    @staticmethod
    def _take_rows(df, positions, column, values):
        """Return the rows of *df* at *positions*, re-indexed 0..n-1, with *column* set to *values*."""
        out = df.iloc[positions].copy().reset_index(drop=True)
        out[column] = values
        return out

    @classmethod
    def _expand_gene(cls, gene):
        """Return *gene* followed by the recursive expansion of its pieces.

        Pieces are produced by the first separator from ``GENE_SEPARATORS``
        that occurs in *gene*. They are the text before its first occurrence
        and the remainder, re-joined with that separator; each piece is
        expanded in turn.

        The combined value itself is kept as the first entry. That preserves
        symbols that legitimately contain a separator (e.g. hyphenated names)
        and leaves bogus combinations to the later validation steps.

        >>> Validation._expand_gene('A,B,C')
        ['A,B,C', 'A', 'B,C', 'B', 'C']
        """
        for sep in cls.GENE_SEPARATORS:
            if sep in gene:
                head, *tail = gene.split(sep)
                return [gene, *cls._expand_gene(head), *cls._expand_gene(sep.join(tail))]
        return [gene]

    @classmethod
    def _split_genes(cls, df):
        """Return a copy of *df* with one row per entry of ``_expand_gene(Genes)``.

        Each expanded row copies every other column of its source row. Rows
        keep their original relative order, and the result has a fresh 0..n-1
        index. Every row is examined, whatever the index of *df*.
        """
        positions, genes = [], []
        for pos, gene in enumerate(df['Genes']):
            for part in cls._expand_gene(gene):
                positions.append(pos)
                genes.append(part)
        return cls._take_rows(df, positions, 'Genes', genes)

    @staticmethod
    def _normalize_snp(snp):
        """Return *snp* coerced to 'rs<digits>' form, '' if blank, or None if unrecoverable.

        Expects an already lower-cased string. The letter 'l' is replaced by
        the digit '1' (presumably a common transcription error). Ids missing
        their prefix ('123', 's123') or with a 't' typo ('ts123') are repaired.
        """
        snp = snp.replace('l', '1')
        if re.fullmatch(r'rs\d+|', snp):
            return snp
        if re.fullmatch(r'ts\d+', snp):
            return 'r' + snp[1:]
        if re.fullmatch(r's\d+', snp):
            return 'r' + snp
        if re.fullmatch(r'\d+', snp):
            return 'rs' + snp
        return None

    @classmethod
    def _clean_snps(cls, df):
        """Return *df* with 'SNPs' normalized and rows with unrecoverable ids removed (0..n-1 index)."""
        positions, snps = [], []
        for pos, raw in enumerate(df['SNPs']):
            snp = cls._normalize_snp(raw)
            if snp is not None:
                positions.append(pos)
                snps.append(snp)
        return cls._take_rows(df, positions, 'SNPs', snps)

    @staticmethod
    def _fetch_snp_genes(snp):
        """Return the set of gene names the GWAS Catalog and NCBI esummary (db=snp) give for *snp*.

        Parsing failures in either lookup are printed and treated as "no
        genes from that source". A failure inside the first ``call`` itself
        propagates, as before.
        """
        genes = set()

        res = call(f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/')
        try:
            res = res.json()
            genes.update([r['gene']['geneName'] for r in res['genomicContexts']])
        except Exception as e:
            print("Error at first API", e)

        try:
            # snp[2:] strips the 'rs' prefix to get the numeric id.
            res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
            if 'error' not in res:
                genes.update([r['name'] for r in res['genes']])
        except Exception as e:
            print("Error at second API", e)

        return genes

    def _filter_by_apis(self, df):
        """Keep rows whose gene the web APIs associate with their SNP.

        A gene that is not listed directly is replaced by the first variant
        from ``permutate(gene)`` that is listed. The candidates are
        presumably alternative spellings — confirm in utils. Rows with no
        match are dropped. Lookups are cached per SNP. Returns a new frame
        with a 0..n-1 index.
        """
        df = df.copy()
        cache = {}
        to_drop = []
        for i in df.index:
            snp = df.loc[i, 'SNPs']
            gene = df.loc[i, 'Genes']

            if snp not in cache:
                cache[snp] = self._fetch_snp_genes(snp)
            known = cache[snp]

            if gene in known:
                continue
            for other in permutate(gene):
                if other in known:
                    df.loc[i, 'Genes'] = other
                    print(f'{gene} corrected to {other}')
                    break
            else:
                to_drop.append(i)

        # Drop once at the end instead of rebuilding the frame per row.
        return df.drop(index=to_drop).reset_index(drop=True)

    @staticmethod
    def _parse_llm_list(text):
        """Parse the outermost '[...]' span of an LLM reply; return [] if it is not a list.

        The reply is untrusted model output, so it is never ``eval``-ed.
        JSON is tried first (so true/false/null work), then Python literal
        syntax (e.g. single-quoted strings) via ``ast.literal_eval``.
        """
        import ast

        start, end = text.find('['), text.rfind(']')
        if start == -1 or end < start:
            return []
        snippet = text[start:end + 1]
        try:
            parsed = json.loads(snippet)
        except (ValueError, RecursionError):
            try:
                parsed = ast.literal_eval(snippet)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                return []
        return parsed if isinstance(parsed, list) else []