from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path
from prompt import *
from table_detector import detection_transform, device, model, ocr, outputs_to_objects
import ast
import io
import json
import os
import pandas as pd
import re
import torch

load_dotenv()

# Extraction mode -> [map-step prompt template, reduce-step prompt template].
# The prompt_* names come from the local `prompt` module (star-imported above).
prompts = {
    'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    'summ': [prompt_entity_summ_chunk, prompt_entity_summ_combine],
    'all': [prompt_entities_chunk, prompt_entities_combine]
}

# First brace-delimited span in an LLM reply. `[^}]` means nested braces are
# NOT supported: the match stops at the first closing brace.
_BRACED_RE = re.compile(r'(\{[^}]+\})')


def _parse_llm_literal(text):
    """Parse *text* as a Python/JSON literal without executing it.

    LLM output can be steered by the document being processed, so it must
    never reach ``eval``. ``ast.literal_eval`` is tried first (it accepts
    Python-style single quotes). ``json.loads`` is the fallback for JSON-only
    tokens such as ``true``/``false``/``null``.

    Raises:
        ValueError / SyntaxError / TypeError: the ``literal_eval`` error, if
            neither parser accepts *text*.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as literal_err:
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            raise literal_err from None


def _first_braced_literal(text):
    """Parse the first ``{...}`` span of *text* as a literal.

    Raises:
        IndexError: *text* contains no ``{...}`` span.
        ValueError / SyntaxError / TypeError: the span is not a valid literal.
    """
    return _parse_llm_literal(_BRACED_RE.findall(text)[0])


def _parse_llm_list(text):
    """Parse the span from the first ``[`` to the last ``]`` of *text*.

    Returns ``[]`` when the span is missing, unparsable, or not a list.
    """
    span = text[text.find('['):text.rfind(']') + 1]
    try:
        parsed = _parse_llm_literal(span)
    except (ValueError, SyntaxError, TypeError):
        return []
    return parsed if isinstance(parsed, list) else []


class Process:
    """Extract genetic entities from documents using an LLM.

    The chat backend is picked from the model name: ``gpt*`` uses OpenAI,
    ``gemini*`` uses Google Generative AI, and anything else uses Perplexity's
    OpenAI-compatible endpoint (reads ``PERPLEXITY_API_KEY``).
    """

    def __init__(self, llm):
        """Create the chat model named *llm* with temperature 0.

        Raises:
            KeyError: a non-gpt/non-gemini model is requested and
                ``PERPLEXITY_API_KEY`` is not set in the environment.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm,
                                  api_key=os.environ['PERPLEXITY_API_KEY'],
                                  base_url="https://api.perplexity.ai")

    def get_entity(self, data):
        """Run a map-reduce extraction over document chunks.

        Args:
            data: ``(chunks, types)`` pair. ``chunks`` is passed unchanged to
                ``MapReduceDocumentsChain.invoke`` (presumably LangChain
                documents — confirm against callers). ``types`` is a key of
                ``prompts``: ``'gsd'``, ``'summ'`` or ``'all'``.

        Returns:
            For ``'summ'``, the raw LLM text. Otherwise, the first ``{...}``
            span of the output parsed as a literal.

        Raises:
            KeyError: ``types`` is not a key of ``prompts``.
            IndexError: the output contains no ``{...}`` span.
            ValueError / SyntaxError / TypeError: the span is not a literal.
        """
        chunks, types = data
        map_template, reduce_template = prompts[types]

        map_chain = LLMChain(llm=self.llm, prompt=PromptTemplate.from_template(map_template))
        reduce_chain = LLMChain(llm=self.llm, prompt=PromptTemplate.from_template(reduce_template))

        # The same "stuff" chain both collapses oversized intermediate results
        # and produces the final combined answer.
        combine_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=combine_chain,
            token_max=100000,
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        result = map_reduce_chain.invoke(chunks)['output_text']
        print(types)
        print(result)
        if types != 'summ':
            return _first_braced_literal(result)
        return result

    def get_entity_one(self, chunks):
        """Extract entities with a single LLM call (no map-reduce).

        Returns:
            The first ``{...}`` span of the reply, parsed as a literal.

        Raises:
            IndexError: the reply contains no ``{...}`` span.
            ValueError / SyntaxError / TypeError: the span is not a literal.
        """
        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
        print('One')
        print(result)
        return _first_braced_literal(result)

    def _detect_tables(self, path, score_threshold):
        """Render each PDF page and return crops of detected tables scoring above *score_threshold*."""
        # Add the extra "no object" label at the next index, on a local copy so
        # the shared model config is not mutated again on every page/call.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"

        tables = []
        for page in convert_from_path(path):
            pixel_values = detection_transform(page).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(pixel_values)
            for detected in outputs_to_objects(outputs, page.size, id2label):
                if detected['score'] > score_threshold:
                    cropped = page.crop(detected["bbox"])
                    if detected["label"] == 'table rotated':
                        cropped = cropped.rotate(270, expand=True)
                    print(detected)
                    tables.append(cropped)
        return tables

    def _table_to_json(self, table):
        """OCR a cropped table image into an indented JSON array of row objects.

        Returns ``None`` when img2table finds no table in the crop.
        """
        buffer = io.BytesIO()
        table.save(buffer, format='PNG')
        extracted_tables = Image(buffer).extract_tables(
            ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0
        )
        if not extracted_tables:
            return None
        # Stack every sub-table img2table found into one frame (a single concat).
        df_table = pd.concat([t.df for t in extracted_tables]).reset_index(drop=True)
        df_table = df_table.fillna('')
        json_table = df_table.to_json(orient='records')
        return json.dumps(json.loads(json_table), indent=2)

    def get_table(self, path, score_threshold=0.9):
        """Detect tables in a PDF and ask the LLM for gene/SNP/disease triples.

        Args:
            path: path of the PDF file to process.
            score_threshold: minimum detector score for a table crop to be
                kept (strictly greater than). TODO: tune; 0.9 is the
                historical value.

        Returns:
            Three parallel lists ``(genes, snps, diseases)`` with one entry per
            SNP. LLM rows that are not dicts with ``Genes``, ``SNPs`` and
            ``Diseases`` keys are skipped. A lone string ``SNPs`` value is
            treated as a single SNP.
        """
        genes = []
        snps = []
        diseases = []

        for table in self._detect_tables(path, score_threshold):
            str_json_table = self._table_to_json(table)
            if str_json_table is None:
                continue

            result = self.llm.invoke(prompt_table.format(str_json_table)).content
            print('table')
            print(result)

            for row in _parse_llm_list(result):
                if not isinstance(row, dict):
                    continue
                try:
                    res_gene, res_snp, res_disease = row['Genes'], row['SNPs'], row['Diseases']
                except KeyError:
                    continue
                if isinstance(res_snp, str):
                    # Avoid iterating a single SNP id character by character.
                    res_snp = [res_snp]
                elif not isinstance(res_snp, (list, tuple, set)):
                    continue
                for snp in res_snp:
                    genes.append(res_gene)
                    snps.append(snp)
                    diseases.append(res_disease)

        print(genes, snps, diseases)
        return genes, snps, diseases