fadliaulawi's picture
Initial commit
fb4710e
raw
history blame
No virus
7.3 kB
import ast
import io
import json
import os
import re
from datetime import datetime

import pandas as pd
import torch
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path

from prompt import (
    prompt_entities_chunk,
    prompt_entities_combine,
    prompt_entity_gsd_chunk,
    prompt_entity_gsd_combine,
    prompt_entity_one_chunk,
    prompt_entity_summ_chunk,
    prompt_entity_summ_combine,
    prompt_table,
)
from table_detector import detection_transform, device, model, ocr, outputs_to_objects
# Populate os.environ from a local .env file before any key is read below.
load_dotenv()

# Deterministic (temperature 0) GPT-4 chat client.
llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)

# Perplexity serves an OpenAI-compatible API, so the same client class is
# pointed at its base URL. The key is looked up eagerly: a missing
# PERPLEXITY_API_KEY raises KeyError when this module is imported.
llm_p = ChatOpenAI(
    temperature=0,
    api_key=os.environ['PERPLEXITY_API_KEY'],
    base_url="https://api.perplexity.ai",
)

# Extraction mode -> [per-chunk (map) prompt, combine (reduce) prompt].
prompts = dict(
    gsd=[prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    summ=[prompt_entity_summ_chunk, prompt_entity_summ_combine],
    all=[prompt_entities_chunk, prompt_entities_combine],
)
def get_entity(data):
    """Extract entities from document chunks with a map-reduce LLM chain.

    Args:
        data: ``(chunks, types)`` pair. ``chunks`` is passed unchanged to
            ``MapReduceDocumentsChain.invoke`` (presumably a list of LangChain
            ``Document`` objects -- confirm against the caller). ``types`` is
            a key of the module-level ``prompts`` mapping (``'gsd'``,
            ``'summ'`` or ``'all'``) selecting the map and reduce templates.

    Returns:
        For ``types == 'summ'``: the reduce step's raw output text.
        Otherwise: the first ``{...}`` span of that text, parsed with
        ``ast.literal_eval``.

    Raises:
        KeyError: ``types`` is not a key of ``prompts``.
        IndexError: the model reply contains no ``{...}`` span.
        ValueError, SyntaxError: that span is not a valid Python literal.
    """
    chunks, types = data
    map_template, reduce_template = prompts[types]

    # Map step: one LLM call per chunk.
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=llm, prompt=map_prompt)

    # Reduce step: the per-chunk outputs are "stuffed" into the reduce prompt
    # under the doc_summaries variable. The same chain is used both to
    # collapse intermediate results and to produce the final combination.
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
    combine_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="doc_summaries"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_chain,
        collapse_documents_chain=combine_chain,
        token_max=100000,
    )

    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )
    result = map_reduce_chain.invoke(chunks)['output_text']
    print(types)
    print(result)

    if types == 'summ':
        return result

    # [^}]+ stops at the first closing brace, so nested braces are not
    # supported; only the first top-level {...} span is used.
    result = re.findall(r'(\{[^}]+\})', result)[0]
    # The reply is model output shaped by the document text, so it is parsed
    # as a literal rather than executed with eval().
    return ast.literal_eval(result)
def get_entity_one(chunks):
    """Extract entities from *chunks* with a single LLM call (no map-reduce).

    Args:
        chunks: value substituted positionally into ``prompt_entity_one_chunk``.

    Returns:
        The first ``{...}`` span of the model reply, parsed with
        ``ast.literal_eval``.

    Raises:
        IndexError: the reply contains no ``{...}`` span.
        ValueError, SyntaxError: that span is not a valid Python literal.
    """
    result = llm.invoke(prompt_entity_one_chunk.format(chunks)).content
    print('One')
    print(result)
    # [^}]+ stops at the first closing brace, so nested braces are not
    # supported; only the first {...} span is used.
    result = re.findall(r'(\{[^}]+\})', result)[0]
    # The reply is model output shaped by the document text, so it is parsed
    # as a literal rather than executed with eval().
    return ast.literal_eval(result)
def _elapsed_minutes(start_time):
    """Return the minutes elapsed since *start_time*, rounded to 2 decimals."""
    return round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2)


def _detect_tables(images, threshold=0.9):
    """Return cropped table images whose detection score exceeds *threshold*.

    Detections labelled ``'table rotated'`` are rotated by 270 degrees
    (PIL rotates counter-clockwise) before being returned.
    """
    # Copy the label map once so the shared model.config.id2label is not
    # mutated; the extra index is the detector's "no object" class.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    tables = []
    for image in images:
        pixel_values = detection_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        for detected in outputs_to_objects(outputs, image.size, id2label):
            # TODO: what is the perfect threshold?
            if detected['score'] <= threshold:
                continue
            cropped_table = image.crop(detected["bbox"])
            if detected["label"] == 'table rotated':
                cropped_table = cropped_table.rotate(270, expand=True)
            print(detected)
            tables.append(cropped_table)
    return tables


def _table_to_dataframe(table):
    """OCR one cropped table image into a single DataFrame, or None if empty."""
    buffer = io.BytesIO()
    table.save(buffer, format='PNG')
    extracted_tables = Image(buffer).extract_tables(
        ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0
    )
    if not extracted_tables:
        return None
    # Stack every table img2table found in the crop into one frame.
    df_table = pd.concat([extracted.df for extracted in extracted_tables]).reset_index(drop=True)
    if df_table.empty:
        return None
    return df_table


def _merge_wrapped_rows(df_table):
    """Join OCR rows that are continuation lines of one visual table row.

    A row with no missing cells starts a new logical row. Each following row
    with at least one missing cell is appended to it cell by cell, separated
    by a space, with missing cells treated as ''. Cells are then stripped.
    """
    # object dtype lets '' be written into columns pandas inferred as float.
    df_table = df_table.astype(object)
    # The first row always opens a logical row, so fill its gaps.
    df_table.iloc[0] = df_table.iloc[0].fillna('')

    groups = []
    for idx, is_complete in zip(df_table.index, df_table.notna().all(axis=1)):
        if is_complete or not groups:
            groups.append([idx])
        else:
            groups[-1].append(idx)

    merged = []
    for group in groups:
        row = df_table.loc[group[0]]
        for idx in group[1:]:
            row = row + (' ' + df_table.loc[idx].fillna(''))
        merged.append(row.str.strip().tolist())
    return pd.DataFrame(merged, columns=df_table.columns)


def _extract_associations(df_table):
    """Ask the LLM for associations in one table; return the parsed list or []."""
    json_table = df_table.to_json(orient='records')
    str_json_table = json.dumps(json.loads(json_table), indent=2)
    result = llm.invoke(prompt_table.format(str_json_table)).content
    print('table')
    print(result)
    # Keep only the outermost [...] span of the reply.
    result = result[result.find('['):result.rfind(']') + 1]
    try:
        # Model output shaped by document text: parse, never eval().
        return ast.literal_eval(result)
    except (SyntaxError, ValueError):
        return []


def get_table(path):
    """Extract gene/SNP/disease triples from the tables of a PDF.

    Pipeline: render PDF pages to images, detect tables with the table
    detection model, OCR each table with img2table, merge wrapped rows,
    then ask the LLM to list associations for each table.

    Args:
        path: path of the PDF file, passed to ``convert_from_path``.

    Returns:
        ``(genes, snps, diseases)``: three parallel lists with one entry per
        SNP reported by the LLM; the gene and disease of that LLM record are
        repeated for each of its SNPs.
    """
    start_time = datetime.now()
    images = convert_from_path(path)
    print('PDF to Image', _elapsed_minutes(start_time), "minutes")

    tables = _detect_tables(images)
    print('Detect table from image', _elapsed_minutes(start_time), "minutes")

    genes = []
    snps = []
    diseases = []
    for table in tables:
        df_table = _table_to_dataframe(table)
        if df_table is None:
            continue
        for record in _extract_associations(_merge_wrapped_rows(df_table)):
            # Skip anything that is not a {'Genes', 'SNPs', 'Diseases'} record.
            if not isinstance(record, dict):
                continue
            record_snps = record.get('SNPs') or []
            # A bare string would otherwise be iterated character by character.
            if isinstance(record_snps, str):
                record_snps = [record_snps]
            for snp in record_snps:
                genes.append(record.get('Genes'))
                snps.append(snp)
                diseases.append(record.get('Diseases'))

    print('OCR table to extract', _elapsed_minutes(start_time), "minutes")
    print(genes, snps, diseases)
    return genes, snps, diseases
# Characters treated as separators between gene symbols in a single cell.
# NOTE(review): '-' also splits hyphenated symbols such as HLA-B; confirm.
_GENE_SEPARATORS = re.compile(r'[-/|]')

# Accepted SNP spellings, tried in order, with the prefix that turns a match
# into an "rs<digits>" id. The empty alternative keeps blank SNP cells.
_SNP_FIXES = (
    (re.compile(r'rs\d+|'), ''),
    (re.compile(r's\d+'), 'r'),
    (re.compile(r'\d+'), 'rs'),
)


def _normalize_snp(snp):
    """Return *snp* as ``rs<digits>`` (or ``''``), or None if unrecognised."""
    for pattern, prefix in _SNP_FIXES:
        if pattern.fullmatch(snp):
            return prefix + snp
    return None


def validate(df):
    """Normalise an extracted gene/SNP table.

    Works on a copy of *df* (``fillna('')`` first). *df* needs ``Genes`` and
    ``SNPs`` string columns; all other columns are carried along unchanged.

    * Gene symbols are upper-cased. A cell holding several symbols separated
      by ``-``, ``/`` or ``|`` is split: the row keeps the first symbol and a
      copy of the row is appended at the end of the frame for every other
      symbol. Empty parts are dropped and parts are stripped.
    * SNP ids are lower-cased and repaired to ``rs<digits>`` (``s123`` ->
      ``rs123``, ``123`` -> ``rs123``). Empty SNPs are kept; rows with any
      other SNP value are dropped.

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    df = df.fillna('')
    df['Genes'] = df['Genes'].str.upper()
    df['SNPs'] = df['SNPs'].str.lower()

    # Split cells holding two or more gene names. Row copies are selected by
    # position, so a non-contiguous input index cannot cause overwrites.
    positions, genes = [], []
    extra_positions, extra_genes = [], []
    for pos, gene in enumerate(df['Genes']):
        parts = [part.strip() for part in _GENE_SEPARATORS.split(gene)]
        parts = [part for part in parts if part] or [gene]
        positions.append(pos)
        genes.append(parts[0])
        for part in parts[1:]:
            extra_positions.append(pos)
            extra_genes.append(part)
    df = df.iloc[positions + extra_positions].reset_index(drop=True)
    df['Genes'] = genes + extra_genes

    # Repair SNPs missing the 'rs' prefix; drop rows that cannot be repaired.
    df['SNPs'] = df['SNPs'].map(_normalize_snp)
    df = df[df['SNPs'].notna()].reset_index(drop=True)

    # TODO: How to validate genes and SNPs?
    # TODO: Validate genes and diseases with LLM
    return df