# metisllm-dashboard: extraction_pipeline/pdf_to_knowledge_graph_transform.py
import argparse
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterable, Optional

from domain.chunk_d import ChunkD, DocumentD
from domain.entity_d import EntityKnowledgeGraphD
from extraction_pipeline.base_stage import BaseTransform
from extraction_pipeline.document_metadata_extractor.openai_document_metadata_extractor import (
    OpenAIDocumentMetadataExtractor,)
from extraction_pipeline.pdf_process_stage import (
    PdfToPageTransform,
    PdfToParagraphTransform,
    PdfToSentencesTransform,
)
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
    RelationshipExtractor,)
from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
    OpenAIRelationshipExtractor,)
from llm_handler.openai_handler import OpenAIHandler
from storage.domain_dao import InMemDomainDAO
from storage.neo4j_dao import Neo4jDomainDAO


class PdfToKnowledgeGraphTransform(BaseTransform[DocumentD, EntityKnowledgeGraphD]):
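    """Transform a collection of PDF DocumentD objects into one EntityKnowledgeGraphD each.

    The pipeline fills in document metadata, chunks the PDF with the supplied
    chunker, and extracts entity relationships from every chunk; the collected
    relationships make up the document's knowledge graph.
    """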

    _metadata_extractor: OpenAIDocumentMetadataExtractor
    _pdf_chunker: BaseTransform[DocumentD, ChunkD]
    _relationship_extractor: RelationshipExtractor

    def __init__(self, pdf_chunker: BaseTransform[DocumentD, ChunkD]):
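        # One OpenAIHandler instance is shared by both OpenAI-backed extractors.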
        openai_handler = OpenAIHandler()
        self._metadata_extractor = OpenAIDocumentMetadataExtractor(openai_handler=openai_handler)
        self._pdf_chunker = pdf_chunker
        self._relationship_extractor = OpenAIRelationshipExtractor(openai_handler=openai_handler)

    def _process_collection(self,
                            collection: Iterable[DocumentD]) -> Iterable[EntityKnowledgeGraphD]:
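        """Yield one EntityKnowledgeGraphD for each input DocumentD."""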
        # Produce one EntityKnowledgeGraphD per DocumentD.
        for pdf_document in collection:
            # The metadata extractor yields exactly one filled-in DocumentD per input DocumentD.
            document = next(iter(self._metadata_extractor.process_element(pdf_document)))
            entity_relationships = []
            for pdf_chunk in self._pdf_chunker.process_collection([document]):
                for relationship in self._relationship_extractor.process_element(pdf_chunk):
                    entity_relationships.append(relationship)
            yield EntityKnowledgeGraphD(entity_relationships=entity_relationships)
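

# Example usage (a sketch; paths are placeholders and OpenAIHandler is assumed to
# pick up OpenAI credentials from its usual configuration):
#
#   As a library:
#       transform = PdfToKnowledgeGraphTransform(PdfToParagraphTransform())
#       doc = DocumentD(file_path='paper.pdf', authors='', publish_date='')
#       graphs = list(transform.process_collection([doc]))
#
#   As a script (run from the repository root so the package imports resolve):
#       python extraction_pipeline/pdf_to_knowledge_graph_transform.py \
#           --pdf_folder ./papers --chunk_to paragraph --upload_to_neo4j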
if __name__ == '__main__':
    ## CLI Arguments for running transform as multi-threaded script
    parser = argparse.ArgumentParser(description='Extract knowledge graphs from PDF files')
    parser.add_argument('--pdf_folder',
                        type=str,
                        help='Path to folder of PDF files to process',
                        default='')
    parser.add_argument('--pdf_file',
                        type=str,
                        help='Path to a single PDF file to process',
                        default='')
    parser.add_argument('--output_json_file',
                        type=str,
                        help='Path for the output JSON file of knowledge graphs',
                        default='./knowledge_graphs.json')
    parser.add_argument('--log_folder', type=str, help='Path to log folder', default='log')
    parser.add_argument('--chunk_to',
                        type=str,
                        help='Level at which to chunk PDF text',
                        default='page',
                        choices=['page', 'paragraph', 'sentence'])
    parser.add_argument('--verbose', help='Enable DEBUG level logs', action='store_true')
    parser.add_argument('--upload_to_neo4j',
                        help='Enable uploads to Neo4j database',
                        action='store_true')
    args = parser.parse_args()

    ## Setup logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    os.makedirs(args.log_folder, exist_ok=True)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    logging.basicConfig(level=log_level,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        filename=f'{args.log_folder}/{script_name}.log',
                        filemode='w')
    logger = logging.getLogger(__name__)

    ## Setup PDF Chunking
    if args.chunk_to == 'page':
        pdf_transform = PdfToPageTransform()
    elif args.chunk_to == 'paragraph':
        pdf_transform = PdfToParagraphTransform()
    elif args.chunk_to == 'sentence':
        pdf_transform = PdfToSentencesTransform()
    else:
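        # argparse's `choices` already restricts --chunk_to, so this branch is defensive.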
        logging.error('Invalid chunking level: %s', args.chunk_to)
        sys.exit(1)

    ## Process PDF Files
    if not args.pdf_folder and not args.pdf_file:
        logging.error('No PDF file or folder provided')
        sys.exit(1)
    elif args.pdf_folder:
        pdf_folder = Path(args.pdf_folder)
        if not pdf_folder.exists():
            logging.error('PDF folder does not exist: %s', pdf_folder)
            sys.exit(1)
        pdf_files = list(pdf_folder.glob('*.pdf'))
        if len(pdf_files) == 0:
            logging.warning('No PDF files found in folder: %s', pdf_folder)
            sys.exit(0)
        pdfs = [
            DocumentD(file_path=str(pdf_file), authors='', publish_date='')
            for pdf_file in pdf_files
        ]
    else:
        pdf_file = Path(args.pdf_file)
        if not pdf_file.exists():
            logging.error('PDF file does not exist: %s', pdf_file)
            sys.exit(1)
        pdfs = [DocumentD(file_path=str(pdf_file), authors='', publish_date='')]

    pdf_to_kg = PdfToKnowledgeGraphTransform(pdf_transform)

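    # Each worker thread runs one PDF through the full transform. Failures are
    # logged and returned as None so a single bad PDF does not abort the run.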
    def process_pdf(pdf: DocumentD) -> tuple[Optional[EntityKnowledgeGraphD], str]:
        pdf_name = Path(pdf.file_path).name
        try:
            # process_collection yields one knowledge graph per PDF, and only one
            # PDF is passed in at a time, so the first element is the result.
            return list(pdf_to_kg.process_collection([pdf]))[0], pdf_name
        except Exception as e:
            logging.error('Error processing %s: %s', pdf_name, e)
            return None, pdf_name

    results: list[EntityKnowledgeGraphD] = []
    with ThreadPoolExecutor() as executor, Neo4jDomainDAO() as dao:
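        # Note: the Neo4j connection is opened regardless of --upload_to_neo4j;
        # the flag only controls whether each knowledge graph is inserted.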
        futures = [executor.submit(process_pdf, pdf) for pdf in pdfs]
        for future in as_completed(futures):
            kg, pdf_name = future.result()
            if not kg:
                continue
            results.append(kg)
            if args.upload_to_neo4j:
                dao.insert(kg, pdf_name)

    # Persist every extracted knowledge graph to the output JSON file.
    in_mem_dao = InMemDomainDAO()
    in_mem_dao.insert(results)
    in_mem_dao.save_to_file(args.output_json_file)