"""Extract gene / SNP / disease entities with LLMs and validate them.

The module exposes :class:`Process`, which runs map-reduce entity prompts
over document chunks, extracts tables from PDF pages, and validates the
resulting gene/SNP/disease rows against public APIs and a second LLM.
"""
import ast
import io
import json
import os
import re
import time
from datetime import datetime

import pandas as pd
import requests
import torch
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path

from prompt import (
    prompt_entity_gsd_chunk,
    prompt_entity_gsd_combine,
    prompt_entity_summ_chunk,
    prompt_entity_summ_combine,
    prompt_entities_chunk,
    prompt_entities_combine,
    prompt_entity_one_chunk,
    prompt_table,
    prompt_validation,
)
from table_detector import detection_transform, device, model, ocr, outputs_to_objects

load_dotenv()

# Maps an entity type to its [map prompt, reduce prompt] pair.
prompts = {
    'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    'summ': [prompt_entity_summ_chunk, prompt_entity_summ_combine],
    'all': [prompt_entities_chunk, prompt_entities_combine],
}

# First brace-delimited span that contains no nested closing brace.
_OBJECT_PATTERN = re.compile(r'(\{[^}]+\})')
# Minimum detection score for a table crop to be kept.
_TABLE_SCORE_THRESHOLD = 0.9
# Number of rows sent to the validation LLM per request.
_VALIDATION_BATCH_SIZE = 50
# Seconds before an HTTP request is abandoned and retried.
_REQUEST_TIMEOUT_SECONDS = 30
# Separators that indicate several gene names in one cell.
_GENE_SEPARATORS = (',', '/', '|')


def _build_llm(model_name):
    """Return the chat model client matching the prefix of *model_name*.

    Names starting with ``gpt`` use OpenAI, names starting with ``gemini``
    use Google Generative AI, and anything else is sent to the Perplexity
    endpoint (requires the ``PERPLEXITY_API_KEY`` environment variable).
    """
    if model_name.startswith('gpt'):
        return ChatOpenAI(temperature=0, model_name=model_name)
    if model_name.startswith('gemini'):
        return ChatGoogleGenerativeAI(temperature=0, model=model_name)
    return ChatOpenAI(
        temperature=0,
        model_name=model_name,
        api_key=os.environ['PERPLEXITY_API_KEY'],
        base_url="https://api.perplexity.ai",
    )


def _elapsed_minutes(start_time):
    """Return the minutes elapsed since *start_time*, rounded to 2 decimals."""
    return round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2)


def _parse_literal(text):
    """Parse *text* as JSON, falling back to a Python literal.

    This replaces ``eval``: model output is untrusted text, so only data
    literals are accepted and no code is ever executed.

    Raises:
        ValueError, SyntaxError: if *text* is neither JSON nor a literal.
    """
    try:
        return json.loads(text)
    except ValueError:
        return ast.literal_eval(text)


def _parse_first_object(text):
    """Parse the first ``{...}`` span found in *text*.

    Raises:
        IndexError: if *text* contains no ``{...}`` span.
        ValueError, SyntaxError: if the span cannot be parsed.
    """
    match = _OBJECT_PATTERN.search(text)
    if match is None:
        raise IndexError('no {...} object found in LLM response')
    return _parse_literal(match.group(1))


def _parse_llm_list(text):
    """Parse the outermost ``[...]`` span of *text*; return ``[]`` on failure."""
    snippet = text[text.find('['):text.rfind(']') + 1]
    try:
        parsed = _parse_literal(snippet)
    except (ValueError, SyntaxError, TypeError):
        return []
    return parsed if isinstance(parsed, list) else []


def _merge_wrapped_rows(df_table):
    """Merge consecutive dataframe rows that form one visual table row.

    A row without any missing cell starts a new logical row; the rows that
    follow it and contain missing cells are appended to it cell by cell.
    """
    # The first row always starts a group, so it must not contain NaN.
    df_table.loc[0] = df_table.loc[0].fillna('')

    groups = []
    current = []
    for label in df_table.index:
        is_complete = not df_table.loc[label].isna().any()
        if is_complete and current:
            groups.append(current)
            current = []
        current.append(label)
    groups.append(current)

    merged = pd.DataFrame(columns=df_table.columns)
    for group in groups:
        row_str = df_table.loc[group[0]]
        for label in group[1:]:
            row_str += ' ' + df_table.loc[label].fillna('')
        merged.loc[len(merged)] = row_str.str.strip()
    return merged


def _normalize_snp(snp):
    """Return *snp* repaired into ``rs<digits>`` form, or ``None`` if invalid.

    An empty string is accepted unchanged, as in the original pattern.
    """
    snp = snp.replace('l', '1')  # presumably an OCR confusion of 'l' and '1'
    if re.fullmatch(r'rs\d+|', snp):
        return snp
    if re.fullmatch(r'ts\d+', snp):
        return 'r' + snp[1:]
    if re.fullmatch(r's\d+', snp):
        return 'r' + snp
    if re.fullmatch(r'\d+', snp):
        return 'rs' + snp
    return None


def _gene_variants(word, mistakes):
    """Return *word* plus every spelling obtained by applying *mistakes*.

    *mistakes* maps a character to its replacement.  The unchanged spelling
    comes first, so callers that take the first hit prefer fewer edits.
    """
    if len(word) == 0:
        return ['']
    tails = _gene_variants(word[1:], mistakes)
    variants = [word[0] + tail for tail in tails]
    if word[0] in mistakes:
        variants += [mistakes[word[0]] + tail for tail in tails]
    return variants


def _get_with_retry(url):
    """GET *url*, retrying until a response is received.

    One second is slept after every attempt, successful or not, so remote
    services are never hit in a tight loop.
    """
    while True:
        try:
            return requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
        except requests.RequestException as exc:
            print(exc)
        finally:
            time.sleep(1)


def _lookup_snp_genes(snp):
    """Return the set of gene names that GWAS Catalog and NCBI list for *snp*.

    A malformed or unexpected response from either service contributes no
    gene names instead of aborting the whole validation.
    """
    genes = []

    response = _get_with_retry(
        f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/'
    )
    try:
        contexts = response.json()['genomicContexts']
        genes = [context['gene']['geneName'] for context in contexts]
    except (KeyError, TypeError, ValueError):
        genes = []

    snp_id = snp[2:]  # NCBI expects the number without the 'rs' prefix
    response = _get_with_retry(
        'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi'
        f'?db=snp&retmode=json&id={snp_id}'
    )
    try:
        summary = response.json()['result'][snp_id]
        if 'error' not in summary:
            genes.extend([entry['name'] for entry in summary['genes']])
    except (KeyError, TypeError, ValueError) as exc:
        print(f'NCBI lookup failed for {snp}: {exc!r}')

    return set(genes)


class Process:
    """Entity extraction and validation pipeline backed by two chat models."""

    def __init__(self, llm, llm_val):
        """Create the extraction (*llm*) and validation (*llm_val*) clients.

        Both arguments are model names; see :func:`_build_llm` for how the
        provider is chosen from the name prefix.
        """
        self.llm = _build_llm(llm)
        self.llm_val = _build_llm(llm_val)

    def get_entity(self, data):
        """Run the map-reduce prompt pair for an entity type over chunks.

        Args:
            data: A ``(chunks, types)`` pair, where ``types`` is a key of
                the module-level ``prompts`` dict.

        Returns:
            The raw model text when ``types == 'summ'``; otherwise the
            parsed first ``{...}`` object of the model output.

        Raises:
            IndexError: if a non-summary output contains no ``{...}`` span.
            ValueError, SyntaxError: if that span cannot be parsed.
        """
        chunks, types = data
        map_template, reduce_template = prompts[types][0], prompts[types][1]

        map_chain = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate.from_template(map_template),
        )
        reduce_chain = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate.from_template(reduce_template),
        )
        combine_chain = StuffDocumentsChain(
            llm_chain=reduce_chain,
            document_variable_name="doc_summaries",
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=combine_chain,
            token_max=100000,
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        result = map_reduce_chain.invoke(chunks)['output_text']
        print(types)
        print(result)
        if types != 'summ':
            return _parse_first_object(result)
        return result

    def get_entity_one(self, chunks):
        """Ask the extraction model about *chunks* with the single-chunk prompt.

        Returns:
            The parsed first ``{...}`` object of the model output.

        Raises:
            IndexError: if the output contains no ``{...}`` span.
            ValueError, SyntaxError: if that span cannot be parsed.
        """
        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
        print('One')
        print(result)
        return _parse_first_object(result)

    def _detect_tables(self, images):
        """Return the cropped table images detected on the page *images*."""
        # Work on a copy so the shared model config is not grown on every page.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"

        tables = []
        # Loop pages
        for image in images:
            pixel_values = detection_transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(pixel_values)
            detected_tables = outputs_to_objects(outputs, image.size, id2label)

            # Loop table in page (if any)
            for detected in detected_tables:
                # TODO: what is the perfect threshold?
                if detected['score'] <= _TABLE_SCORE_THRESHOLD:
                    continue
                cropped_table = image.crop(detected["bbox"])
                if detected["label"] == 'table rotated':
                    cropped_table = cropped_table.rotate(270, expand=True)
                print(detected)
                tables.append(cropped_table)
        return tables

    def get_table(self, path):
        """Extract gene/SNP/disease triples from the tables of a PDF.

        Each page is rendered to an image, tables are detected and cropped,
        OCR turns each crop into a dataframe, and the extraction model is
        asked to read the rows.

        Args:
            path: Path of the PDF file, passed to ``convert_from_path``.

        Returns:
            Three parallel lists ``(genes, snps, diseases)`` with one entry
            per SNP reported by the model.
        """
        start_time = datetime.now()
        images = convert_from_path(path)
        print('PDF to Image', _elapsed_minutes(start_time), "minutes")

        tables = self._detect_tables(images)
        print('Detect table from image', _elapsed_minutes(start_time), "minutes")

        genes = []
        snps = []
        diseases = []
        # Loop tables
        for table in tables:
            buffer = io.BytesIO()
            table.save(buffer, format='PNG')
            image = Image(buffer)

            # Extract to dataframe
            extracted_tables = image.extract_tables(
                ocr=ocr,
                implicit_rows=True,
                borderless_tables=True,
                min_confidence=0,
            )
            if len(extracted_tables) == 0:
                continue

            # Combine multiple dataframe
            df_table = extracted_tables[0].df
            for extracted_table in extracted_tables[1:]:
                df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)

            # Identify multiple rows (in dataframe) as one row (in image)
            df_table_cleaned = _merge_wrapped_rows(df_table)

            # Ask LLM with JSON data
            json_table = df_table_cleaned.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)
            result = self.llm.invoke(prompt_table.format(str_json_table)).content
            print('table')
            print(result)

            for res in _parse_llm_list(result):
                res_gene = res['Genes']
                res_disease = res['Diseases']
                # One output row per SNP; gene and disease are repeated.
                for snp in res['SNPs']:
                    genes.append(res_gene)
                    snps.append(snp)
                    diseases.append(res_disease)

        print('OCR table to extract', _elapsed_minutes(start_time), "minutes")
        print(genes, snps, diseases)
        return genes, snps, diseases

    def validate(self, df):
        """Clean and validate a table of gene / SNP / disease rows.

        Steps, in order:

        1. Normalise text and split cells holding several gene names.
        2. Repair SNP ids into ``rs<digits>`` form and drop unrepairable rows.
        3. Check each gene against the genes that GWAS Catalog and NCBI
           report for its SNP, correcting common character confusions and
           dropping rows that cannot be confirmed.
        4. Send the remaining rows to the validation model in batches.

        Args:
            df: DataFrame with at least ``Genes``, ``SNPs`` and ``Diseases``
                columns.

        Returns:
            ``(df, df_no_llm, df_clean)``: the rows returned by the
            validation model (cross-joined with the extra columns of the
            first API-validated row), the rows after step 3, and the rows
            after step 2.
        """
        # A positional index is required by the row insertion below.
        df = df.fillna('').reset_index(drop=True)
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['SNPs'] = df['SNPs'].str.lower()

        # Check if there is two gene names.  Walk by position and re-read
        # the length each time, because every split inserts a row; the
        # remainder of a split lands in the next row and is examined next.
        pos = 0
        while pos < len(df):
            gene = df.loc[pos, 'Genes']
            for sep in _GENE_SEPARATORS:
                if sep in gene:
                    first, rest = gene.split(sep, 1)
                    # A fractional label sorts right after the current row.
                    df.loc[pos + 0.5] = df.loc[pos]
                    df = df.sort_index().reset_index(drop=True)
                    df.loc[pos, 'Genes'] = first
                    df.loc[pos + 1, 'Genes'] = rest
                    break
            pos += 1

        # Check if there is SNPs without 'rs'
        invalid_labels = []
        for label in df.index:
            fixed = _normalize_snp(df.loc[label, 'SNPs'])
            if fixed is None:
                invalid_labels.append(label)
            else:
                df.loc[label, 'SNPs'] = fixed
        df = df.drop(invalid_labels).reset_index(drop=True)
        df_clean = df.copy()

        # Validate genes and SNPs with APIs
        mistakes = {'I': '1', 'O': '0'}  # Common mistakes need to be maintained
        dbsnp = {}  # SNP id -> gene names, so each SNP is looked up once
        unconfirmed_labels = []
        for label in df.index:
            snp = df.loc[label, 'SNPs']
            gene = df.loc[label, 'Genes']
            if snp not in dbsnp:
                dbsnp[snp] = _lookup_snp_genes(snp)
            known_genes = dbsnp[snp]
            if gene in known_genes:
                continue
            corrected = next(
                (other for other in _gene_variants(gene, mistakes) if other in known_genes),
                None,
            )
            if corrected is None:
                unconfirmed_labels.append(label)
            else:
                df.loc[label, 'Genes'] = corrected
                print(f'{gene} corrected to {corrected}')
        # Labels are intentionally not reset here, so df_no_llm keeps gaps.
        df = df.drop(unconfirmed_labels)
        df_no_llm = df.copy()

        # Validate genes and diseases with LLM (for each 50 rows).  Batches
        # are taken by position because the index has gaps after the drops.
        results = []
        for start in range(0, len(df), _VALIDATION_BATCH_SIZE):
            batch = df[['Genes', 'SNPs', 'Diseases']].iloc[start:start + _VALIDATION_BATCH_SIZE]
            json_table = batch.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)
            result = self.llm_val.invoke(input=prompt_validation.format(str_json_table)).content
            print('val', start)
            print(result)
            results.extend(_parse_llm_list(result))

        df = pd.DataFrame(results)
        df = df.merge(df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
        return df, df_no_llm, df_clean