# Smart_AAS / unified_document_processor.py
# Source: Hugging Face file view — author TahaRasouli, commit 1abb1bd
# ("Create unified_document_processor.py", verified), 28 kB.
from typing import List, Dict, Union
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer
# NOTE(review): CustomEmbeddingFunction is defined again later in this file;
# that later definition replaces this one when the module is imported.
class CustomEmbeddingFunction:
    """Chroma embedding function backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the embedding model.

        Args:
            model_name: sentence-transformers model id; defaults to the
                previously hard-coded 'all-MiniLM-L6-v2'.
        """
        self.model = SentenceTransformer(model_name)

    # The parameter is named ``input`` (shadowing the builtin) because that is
    # the name Chroma's embedding-function interface uses.
    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed a batch of texts and return one list of floats per text."""
        embeddings = self.model.encode(input)
        return embeddings.tolist()
# NOTE(review): this file defines CustomEmbeddingFunction and
# UnifiedDocumentProcessor twice (with a repeated import block in between);
# the later definitions replace these ones when the module is imported.
# Keep the copies in sync or delete one.
class UnifiedDocumentProcessor:
    """Summarise XML and PDF files with an LLM and index them in ChromaDB.

    Files are split into chunks, each chunk is turned into a natural-language
    summary by a Groq chat model, and the summary is stored in a single Chroma
    collection together with metadata (source file name, content type, chunk
    index). ``ask_question_selective`` answers questions restricted to a chosen
    set of source files.
    """

    # Default Groq chat model used for summaries and answers; override per
    # instance with the ``llm_model`` constructor argument.
    # NOTE(review): Groq retires model ids over time — confirm this one is
    # still available on the account in use.
    llm_model = "llama3-8b-8192"

    def __init__(self, groq_api_key, collection_name="unified_content", llm_model=None):
        """Initialize the processor with necessary clients.

        Args:
            groq_api_key: API key for the Groq client.
            collection_name: Name of the Chroma collection shared by all
                document types; reused if it already exists.
            llm_model: Optional Groq model id overriding the class default
                for this instance.
        """
        self.groq_client = Groq(api_key=groq_api_key)
        if llm_model:
            self.llm_model = llm_model

        # XML-specific settings. Not read anywhere in this class; XML chunk
        # size is controlled by ``chunk_xml_text(max_chunk_size=...)``.
        self.max_elements_per_chunk = 50

        # PDF-specific settings, both measured in words.
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        self._initialize_nltk()

        # One collection holds every document type; entries are distinguished
        # by their ``content_type`` metadata.
        self.chroma_client = chromadb.Client()
        # Depending on the Chroma version, list_collections() yields Collection
        # objects or bare names; normalise to a set of names.
        existing_names = {
            getattr(col, "name", col)
            for col in self.chroma_client.list_collections()
        }
        if collection_name in existing_names:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Make sure NLTK's sentence-tokenizer data is installed.

        Both ``punkt`` and ``punkt_tab`` are checked, because different NLTK
        releases load different ones for ``sent_tokenize``. Each is downloaded
        only if missing. Failure is not fatal: ``chunk_text`` falls back to
        ``_basic_sentence_split``.
        """
        for resource in ("punkt", "punkt_tab"):
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError:
                print(f"Downloading NLTK '{resource}' tokenizer...")
                try:
                    nltk.download(resource)
                except Exception as e:
                    print(f"Warning: Error downloading NLTK resources: {str(e)}")
                    print("Falling back to basic sentence splitting...")

    def _basic_sentence_split(self, text: str) -> List[str]:
        """Split text after every '.', '!' or '?' (fallback tokenizer).

        Each piece is stripped of surrounding whitespace; trailing text with no
        terminator becomes a final sentence if it is not blank.
        """
        sentences = []
        start = 0
        for i, char in enumerate(text):
            if char in ".!?":
                sentences.append(text[start:i + 1].strip())
                start = i + 1
        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return sentences

    def process_file(self, file_path: str) -> Dict:
        """Dispatch a file to the XML or PDF pipeline based on its extension.

        Returns:
            The pipeline's result dict, or ``{'success': False, 'error': msg}``
            for unsupported extensions and unexpected errors.
        """
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Return the text of every page of a PDF, joined by single spaces.

        Raises:
            Exception: wrapping (and chained to) any error raised while opening
                or reading the file.
        """
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # extract_text() may yield None/empty (e.g. image-only pages,
                # depending on the PyPDF2 version); treat that as "no text".
                pages = [page.extract_text() or "" for page in pdf_reader.pages]
            return " ".join(pages).strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def chunk_text(self, text: str) -> List[str]:
        """Split text into word-bounded chunks on sentence boundaries.

        Sentences are packed into a chunk until adding the next one would
        exceed ``self.pdf_chunk_size`` words. The next chunk then starts with
        the last ``self.pdf_overlap`` words of the previous one. A single
        sentence longer than the limit is not split, so a chunk can exceed it.
        """
        try:
            sentences = sent_tokenize(text)
        except Exception as e:
            # Best effort: NLTK data may be unavailable; see _initialize_nltk.
            print(f"Warning: Using fallback sentence splitting: {str(e)}")
            sentences = self._basic_sentence_split(text)

        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)
            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    # Guard needed: a slice of [-0:] would copy the whole list.
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def flatten_xml_to_text(self, element, depth=0) -> str:
        """Render an element subtree as one line per element, in document order.

        Each line reads ``Element: <tag>[, Attributes: <json>][, Text: <text>]``.
        Text following a child element (its ``tail``) is not included.
        ``depth`` is accepted for API compatibility but not used.
        """
        element_info = f"Element: {element.tag}"
        if element.attrib:
            # ensure_ascii=False keeps non-ASCII values readable for the LLM.
            element_info += f", Attributes: {json.dumps(element.attrib, ensure_ascii=False)}"
        text = (element.text or "").strip()
        if text:
            element_info += f", Text: {text}"
        lines = [element_info]
        lines.extend(self.flatten_xml_to_text(child, depth + 1) for child in element)
        return "\n".join(lines)

    def chunk_xml_text(self, text: str, max_chunk_size: int = 2000) -> List[str]:
        """Group the lines of flattened XML into chunks of at most max_chunk_size chars.

        The size includes the newline separators of the joined chunk. A single
        line longer than the limit still forms its own (oversized) chunk.
        """
        chunks = []
        current_chunk = []
        current_size = 0
        for line in text.split('\n'):
            # +1 for the '\n' that will join this line to the previous one.
            added = len(line) + (1 if current_chunk else 0)
            if current_chunk and current_size + added > max_chunk_size:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
                added = len(line)
            current_chunk.append(line)
            current_size += added
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks

    def generate_natural_language(self, content: Union[List[Dict], str], content_type: str) -> Union[str, None]:
        """Ask the LLM to summarise one chunk.

        Args:
            content: The chunk (its ``str()`` form is embedded in the prompt).
            content_type: ``"xml"`` selects the XML prompt; anything else uses
                the plain-text summary prompt.

        Returns:
            The model's reply, or None if the request failed.
        """
        try:
            if content_type == "xml":
                prompt = f"Convert this XML structure description to a natural language summary: {content}"
            else:  # pdf
                prompt = f"Summarize this text while preserving key information: {content}"
            max_prompt_length = 4000
            if len(prompt) > max_prompt_length:
                prompt = prompt[:max_prompt_length] + "..."
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error generating natural language: {str(e)}")
            # Best-effort retry on a shorter input. Only the first half is
            # summarised, so the rest of this chunk is dropped from the summary.
            # The recursive call handles its own errors and returns None.
            if len(content) > 2000:
                return self.generate_natural_language(content[:len(content) // 2], content_type)
            return None

    def store_in_vector_db(self, natural_language: str, metadata: Dict) -> str:
        """Add one summary to the collection and return its generated id.

        ``metadata`` must contain 'source_file', 'content_type' and 'chunk_id'.
        The id ends in a microsecond timestamp, so re-processing the same file
        within one second does not reuse an id.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{timestamp}"
        self.collection.add(
            documents=[natural_language],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Flatten, chunk, summarise and store an XML file.

        Returns:
            ``{'success': True, 'total_chunks': n, 'results': [...]}`` with one
            entry per chunk, or ``{'success': False, 'error': msg}`` if the file
            cannot be read or parsed.
        """
        try:
            # NOTE(review): xml.etree is not hardened against malicious XML
            # (e.g. entity-expansion attacks); if files come from untrusted
            # users, consider a hardened parser.
            tree = ET.parse(xml_file_path)
            root = tree.getroot()
            flattened_text = self.flatten_xml_to_text(root)
            chunks = self.chunk_xml_text(flattened_text)
            print(f"Split XML into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing XML chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "xml")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(xml_file_path),
                            'content_type': 'xml',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language'
                        })
                except Exception as e:
                    # Isolate per-chunk failures so one bad chunk does not
                    # discard results for the others.
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Extract, chunk, summarise and store a PDF file.

        Returns:
            ``{'success': True, 'total_chunks': n, 'results': [...]}`` with one
            entry per chunk, or ``{'success': False, 'error': msg}`` if text
            extraction or chunking fails.
        """
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing PDF chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "pdf")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(pdf_file_path),
                            'content_type': 'pdf',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now()),
                            'chunk_size': len(chunk.split())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        # Only mark the preview as truncated when it really is.
                        preview = chunk if len(chunk) <= 200 else chunk[:200] + "..."
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language,
                            'original_text': preview
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language summary'
                        })
                except Exception as e:
                    # Same per-chunk isolation as the XML pipeline.
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_available_files(self) -> Dict[str, List[str]]:
        """List the distinct source file names stored, grouped as 'pdf' and 'xml'.

        Entries with missing metadata or another content type are skipped.
        On error, both lists are returned empty.
        """
        try:
            all_entries = self.collection.get(include=['metadatas'])
            files = {'pdf': set(), 'xml': set()}
            for metadata in all_entries['metadatas'] or []:
                if not metadata:
                    continue
                bucket = files.get(metadata.get('content_type'))
                file_name = metadata.get('source_file')
                if bucket is not None and file_name:
                    bucket.add(file_name)
            return {
                'pdf': sorted(files['pdf']),
                'xml': sorted(files['xml'])
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Answer a question using only summaries from the selected source files.

        Args:
            question: The user's question; also used as the similarity query.
            selected_files: Source file names (as stored in metadata) to search.
            n_results: Maximum number of summaries passed to the LLM.

        Returns:
            The model's answer, a "no content" message, or an error message.
        """
        # An empty $in filter would be rejected by Chroma; nothing can match anyway.
        if not selected_files:
            return "No relevant content found in the selected files."
        try:
            filter_dict = {
                'source_file': {'$in': list(selected_files)}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )
            documents = results['documents'][0] if results['documents'] else []
            if not documents:
                return "No relevant content found in the selected files."
            context = "\n\n".join(documents)
            prompt = (
                "Based on the following content from the selected files, please answer "
                f"this question: {question}\n"
                "Content:\n"
                f"{context}\n"
                "Please provide a direct answer based only on the information provided above."
            )
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"
from typing import List, Dict, Union
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer
# NOTE(review): an earlier definition of CustomEmbeddingFunction appears above
# in this file; this later one is the definition in effect after import.
class CustomEmbeddingFunction:
    """Chroma embedding function backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the embedding model.

        Args:
            model_name: sentence-transformers model id; defaults to the
                previously hard-coded 'all-MiniLM-L6-v2'.
        """
        self.model = SentenceTransformer(model_name)

    # The parameter is named ``input`` (shadowing the builtin) because that is
    # the name Chroma's embedding-function interface uses.
    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed a batch of texts and return one list of floats per text."""
        embeddings = self.model.encode(input)
        return embeddings.tolist()
# NOTE(review): an earlier definition of UnifiedDocumentProcessor appears above
# in this file; this later one is the definition in effect after import.
class UnifiedDocumentProcessor:
    """Summarise XML and PDF files with an LLM and index them in ChromaDB.

    Files are split into chunks, each chunk is turned into a natural-language
    summary by a Groq chat model, and the summary is stored in a single Chroma
    collection together with metadata (source file name, content type, chunk
    index). ``ask_question_selective`` answers questions restricted to a chosen
    set of source files.
    """

    # Default Groq chat model used for summaries and answers; override per
    # instance with the ``llm_model`` constructor argument.
    # NOTE(review): Groq retires model ids over time — confirm this one is
    # still available on the account in use.
    llm_model = "llama3-8b-8192"

    def __init__(self, groq_api_key, collection_name="unified_content", llm_model=None):
        """Initialize the processor with necessary clients.

        Args:
            groq_api_key: API key for the Groq client.
            collection_name: Name of the Chroma collection shared by all
                document types; reused if it already exists.
            llm_model: Optional Groq model id overriding the class default
                for this instance.
        """
        self.groq_client = Groq(api_key=groq_api_key)
        if llm_model:
            self.llm_model = llm_model

        # XML-specific settings. Not read anywhere in this class; XML chunk
        # size is controlled by ``chunk_xml_text(max_chunk_size=...)``.
        self.max_elements_per_chunk = 50

        # PDF-specific settings, both measured in words.
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        self._initialize_nltk()

        # One collection holds every document type; entries are distinguished
        # by their ``content_type`` metadata.
        self.chroma_client = chromadb.Client()
        # Depending on the Chroma version, list_collections() yields Collection
        # objects or bare names; normalise to a set of names.
        existing_names = {
            getattr(col, "name", col)
            for col in self.chroma_client.list_collections()
        }
        if collection_name in existing_names:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Make sure NLTK's sentence-tokenizer data is installed.

        Both ``punkt`` and ``punkt_tab`` are checked, because different NLTK
        releases load different ones for ``sent_tokenize``. Each is downloaded
        only if missing, rather than on every construction. Failure is not
        fatal: ``chunk_text`` falls back to ``_basic_sentence_split``.
        """
        for resource in ("punkt", "punkt_tab"):
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError:
                print(f"Downloading NLTK '{resource}' tokenizer...")
                try:
                    nltk.download(resource)
                except Exception as e:
                    print(f"Warning: Error downloading NLTK resources: {str(e)}")
                    print("Falling back to basic sentence splitting...")

    def _basic_sentence_split(self, text: str) -> List[str]:
        """Split text after every '.', '!' or '?' (fallback tokenizer).

        Each piece is stripped of surrounding whitespace; trailing text with no
        terminator becomes a final sentence if it is not blank. Slices the
        input instead of growing a string one character at a time.
        """
        sentences = []
        start = 0
        for i, char in enumerate(text):
            if char in ".!?":
                sentences.append(text[start:i + 1].strip())
                start = i + 1
        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return sentences

    def process_file(self, file_path: str) -> Dict:
        """Dispatch a file to the XML or PDF pipeline based on its extension.

        Returns:
            The pipeline's result dict, or ``{'success': False, 'error': msg}``
            for unsupported extensions and unexpected errors.
        """
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Return the text of every page of a PDF, joined by single spaces.

        Raises:
            Exception: wrapping (and chained to) any error raised while opening
                or reading the file.
        """
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # extract_text() may yield None/empty (e.g. image-only pages,
                # depending on the PyPDF2 version); treat that as "no text".
                pages = [page.extract_text() or "" for page in pdf_reader.pages]
            return " ".join(pages).strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def chunk_text(self, text: str) -> List[str]:
        """Split text into word-bounded chunks on sentence boundaries.

        Sentences are packed into a chunk until adding the next one would
        exceed ``self.pdf_chunk_size`` words. The next chunk then starts with
        the last ``self.pdf_overlap`` words of the previous one. A single
        sentence longer than the limit is not split, so a chunk can exceed it.
        """
        try:
            sentences = sent_tokenize(text)
        except Exception as e:
            # Best effort: NLTK data may be unavailable; see _initialize_nltk.
            print(f"Warning: Using fallback sentence splitting: {str(e)}")
            sentences = self._basic_sentence_split(text)

        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)
            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    # Guard needed: a slice of [-0:] would copy the whole list.
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def flatten_xml_to_text(self, element, depth=0) -> str:
        """Render an element subtree as one line per element, in document order.

        Each line reads ``Element: <tag>[, Attributes: <json>][, Text: <text>]``.
        Text following a child element (its ``tail``) is not included.
        ``depth`` is accepted for API compatibility but not used.
        """
        element_info = f"Element: {element.tag}"
        if element.attrib:
            # ensure_ascii=False keeps non-ASCII values readable for the LLM.
            element_info += f", Attributes: {json.dumps(element.attrib, ensure_ascii=False)}"
        text = (element.text or "").strip()
        if text:
            element_info += f", Text: {text}"
        lines = [element_info]
        lines.extend(self.flatten_xml_to_text(child, depth + 1) for child in element)
        return "\n".join(lines)

    def chunk_xml_text(self, text: str, max_chunk_size: int = 2000) -> List[str]:
        """Group the lines of flattened XML into chunks of at most max_chunk_size chars.

        The size includes the newline separators of the joined chunk. A single
        line longer than the limit still forms its own (oversized) chunk.
        """
        chunks = []
        current_chunk = []
        current_size = 0
        for line in text.split('\n'):
            # +1 for the '\n' that will join this line to the previous one.
            added = len(line) + (1 if current_chunk else 0)
            if current_chunk and current_size + added > max_chunk_size:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
                added = len(line)
            current_chunk.append(line)
            current_size += added
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks

    def generate_natural_language(self, content: Union[List[Dict], str], content_type: str) -> Union[str, None]:
        """Ask the LLM to summarise one chunk.

        Args:
            content: The chunk (its ``str()`` form is embedded in the prompt).
            content_type: ``"xml"`` selects the XML prompt; anything else uses
                the plain-text summary prompt.

        Returns:
            The model's reply, or None if the request failed.
        """
        try:
            if content_type == "xml":
                prompt = f"Convert this XML structure description to a natural language summary: {content}"
            else:  # pdf
                prompt = f"Summarize this text while preserving key information: {content}"
            max_prompt_length = 4000
            if len(prompt) > max_prompt_length:
                prompt = prompt[:max_prompt_length] + "..."
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error generating natural language: {str(e)}")
            # Best-effort retry on a shorter input. Only the first half is
            # summarised, so the rest of this chunk is dropped from the summary.
            # The recursive call handles its own errors and returns None.
            if len(content) > 2000:
                return self.generate_natural_language(content[:len(content) // 2], content_type)
            return None

    def store_in_vector_db(self, natural_language: str, metadata: Dict) -> str:
        """Add one summary to the collection and return its generated id.

        ``metadata`` must contain 'source_file', 'content_type' and 'chunk_id'.
        The id ends in a microsecond timestamp, so re-processing the same file
        within one second does not reuse an id.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{timestamp}"
        self.collection.add(
            documents=[natural_language],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Flatten, chunk, summarise and store an XML file.

        Returns:
            ``{'success': True, 'total_chunks': n, 'results': [...]}`` with one
            entry per chunk, or ``{'success': False, 'error': msg}`` if the file
            cannot be read or parsed.
        """
        try:
            # NOTE(review): xml.etree is not hardened against malicious XML
            # (e.g. entity-expansion attacks); if files come from untrusted
            # users, consider a hardened parser.
            tree = ET.parse(xml_file_path)
            root = tree.getroot()
            flattened_text = self.flatten_xml_to_text(root)
            chunks = self.chunk_xml_text(flattened_text)
            print(f"Split XML into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing XML chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "xml")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(xml_file_path),
                            'content_type': 'xml',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language'
                        })
                except Exception as e:
                    # Isolate per-chunk failures so one bad chunk does not
                    # discard results for the others.
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Extract, chunk, summarise and store a PDF file.

        Returns:
            ``{'success': True, 'total_chunks': n, 'results': [...]}`` with one
            entry per chunk, or ``{'success': False, 'error': msg}`` if text
            extraction or chunking fails.
        """
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing PDF chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "pdf")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(pdf_file_path),
                            'content_type': 'pdf',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now()),
                            'chunk_size': len(chunk.split())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        # Only mark the preview as truncated when it really is.
                        preview = chunk if len(chunk) <= 200 else chunk[:200] + "..."
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language,
                            'original_text': preview
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language summary'
                        })
                except Exception as e:
                    # Same per-chunk isolation as the XML pipeline.
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_available_files(self) -> Dict[str, List[str]]:
        """List the distinct source file names stored, grouped as 'pdf' and 'xml'.

        Entries with missing metadata or another content type are skipped.
        On error, both lists are returned empty.
        """
        try:
            all_entries = self.collection.get(include=['metadatas'])
            files = {'pdf': set(), 'xml': set()}
            for metadata in all_entries['metadatas'] or []:
                if not metadata:
                    continue
                bucket = files.get(metadata.get('content_type'))
                file_name = metadata.get('source_file')
                if bucket is not None and file_name:
                    bucket.add(file_name)
            return {
                'pdf': sorted(files['pdf']),
                'xml': sorted(files['xml'])
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Answer a question using only summaries from the selected source files.

        Args:
            question: The user's question; also used as the similarity query.
            selected_files: Source file names (as stored in metadata) to search.
            n_results: Maximum number of summaries passed to the LLM.

        Returns:
            The model's answer, a "no content" message, or an error message.
        """
        # An empty $in filter would be rejected by Chroma; nothing can match anyway.
        if not selected_files:
            return "No relevant content found in the selected files."
        try:
            filter_dict = {
                'source_file': {'$in': list(selected_files)}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )
            documents = results['documents'][0] if results['documents'] else []
            if not documents:
                return "No relevant content found in the selected files."
            context = "\n\n".join(documents)
            prompt = (
                "Based on the following content from the selected files, please answer "
                f"this question: {question}\n"
                "Content:\n"
                f"{context}\n"
                "Please provide a direct answer based only on the information provided above."
            )
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"