# llama-aa-fine-tuned / qanda_gen_script.py
# Uploaded by Lukeam with huggingface_hub (commit 1bf5b03, verified).
import json
from pathlib import Path
from datetime import datetime
import hashlib
class QAGenerator:
    """Turn raw PDF text extractions into question/answer pairs.

    Reads ``*_raw.json`` files from ``processed_data/raw_extractions`` (next to
    this script) and writes, into ``processed_data/qa_pairs``:

    * one ``<name>_qa.jsonl`` per source,
    * ``combined_qa.jsonl`` (a metadata line followed by every pair),
    * ``qa_generation_manifest.json`` (provenance of each transformation).
    """

    # Lower-case substrings that trigger an extra AA-specific question.
    AA_KEYWORDS = ('step', 'tradition', 'recovery', 'sobriety')

    def __init__(self):
        # Keep every path inside the directory that contains this script (aa_book).
        self.base_dir = Path(__file__).parent
        self.output_dir = self.base_dir / 'processed_data'
        self.qa_dir = self.output_dir / 'qa_pairs'
        self.raw_dir = self.output_dir / 'raw_extractions'
        print(f"Looking for raw extractions in: {self.raw_dir}")
        self.qa_dir.mkdir(parents=True, exist_ok=True)
        # One dict per input -> output transformation; see add_to_manifest().
        self.manifest = []

    def add_to_manifest(self, input_file, output_file, process_type, metadata):
        """Record one input -> output transformation in ``self.manifest``.

        Args:
            input_file: Path (or str) of the file that was read.
            output_file: Path (or str) of the file that was written.
            process_type: Short label for the transformation, e.g. 'qa_generation'.
            metadata: Arbitrary JSON-serialisable details about the run.
        """
        self.manifest.append({
            'timestamp': datetime.now().isoformat(),
            'input_file': str(input_file),
            'output_file': str(output_file),
            'process_type': process_type,
            'metadata': metadata,
        })

    def generate_qa_pairs(self, text, source_info, min_length=100):
        """Build Q&A pairs from the blank-line-separated sections of *text*.

        Each section whose stripped text has at least *min_length* characters
        yields a 'main_points' pair and a 'summary' pair. Sections that also
        contain one of ``AA_KEYWORDS`` (case-insensitive substring match) get
        an extra 'aa_specific' pair. All pairs for a section use the stripped
        section text as their answer.

        Args:
            text: Full document text; sections are separated by a double newline.
            source_info: Provenance dict that must contain 'title'; it is
                embedded as-is in every pair.
            min_length: Minimum stripped section length to keep (default 100).

        Returns:
            A list of Q&A dicts in section order.
        """
        qa_pairs = []
        for index, section in enumerate(text.split('\n\n')):
            answer = section.strip()
            if len(answer) < min_length:  # too short to be a meaningful passage
                continue

            title = source_info['title']
            questions = [
                (f"What are the main points discussed in this section of {title}?", 'main_points'),
                (f"Can you summarize the key concepts from this passage in {title}?", 'summary'),
            ]
            lowered = answer.lower()
            if any(word in lowered for word in self.AA_KEYWORDS):
                questions.append((
                    f"What recovery principles or concepts are discussed in this section of {title}?",
                    'aa_specific',
                ))

            for question, qa_type in questions:
                qa_pairs.append({
                    'question': question,
                    'answer': answer,
                    'source': source_info,
                    'section_index': index,
                    'qa_type': qa_type,
                    'timestamp': datetime.now().isoformat(),
                })
        return qa_pairs

    @staticmethod
    def _write_jsonl(path, records):
        """Write each record as one JSON line to *path* (UTF-8, overwriting)."""
        with open(path, 'w', encoding='utf-8') as f:
            for record in records:
                f.write(json.dumps(record) + '\n')

    def process_all_sources(self):
        """Convert every ``*_raw.json`` in ``self.raw_dir`` into Q&A JSONL files.

        Writes one ``<name>_qa.jsonl`` per source, a ``combined_qa.jsonl`` whose
        first line is run metadata, and ``qa_generation_manifest.json``.

        Raises:
            FileNotFoundError: if ``self.raw_dir`` does not exist.
            KeyError: if a raw file lacks 'filename', 'extraction_date',
                'total_pages' or 'pages'.
        """
        raw_dir = self.raw_dir
        if not raw_dir.exists():
            raise FileNotFoundError(f"Directory not found: {raw_dir}. Please run extract_pdfs.py first.")

        all_qa_pairs = []
        sources_processed = []
        # Sort so the combined output and manifest do not depend on filesystem order.
        for raw_file in sorted(raw_dir.glob('*_raw.json')):
            print(f"\nProcessing {raw_file.name}...")
            with open(raw_file, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            source_info = {
                'title': raw_data['filename'],
                'extraction_date': raw_data['extraction_date'],
                'total_pages': raw_data['total_pages'],
            }

            # Skip pages with no text at all (missing key or explicit null).
            # NOTE(review): pages are joined with a single space, so a page break
            # does not start a new section in generate_qa_pairs(); only double
            # newlines inside a page's text do.
            full_text = ' '.join(
                page['text'] for page in raw_data['pages']
                if page.get('text') is not None
            )

            qa_pairs = self.generate_qa_pairs(full_text, source_info)

            # Drop only the trailing "_raw" (guaranteed by the glob), never an
            # occurrence elsewhere in the file name.
            source_name = raw_file.stem[:-len('_raw')]
            source_output = self.qa_dir / f"{source_name}_qa.jsonl"
            self._write_jsonl(source_output, qa_pairs)

            self.add_to_manifest(
                raw_file,
                source_output,
                'qa_generation',
                {
                    'pairs_generated': len(qa_pairs),
                    'source': source_info['title'],
                },
            )
            all_qa_pairs.extend(qa_pairs)
            sources_processed.append(source_info)
            print(f"Generated {len(qa_pairs)} Q&A pairs")

        # Combined file: one metadata line first, then every pair.
        combined_output = self.qa_dir / 'combined_qa.jsonl'
        metadata = {
            'timestamp': datetime.now().isoformat(),
            'total_pairs': len(all_qa_pairs),
            'sources': sources_processed,
        }
        self._write_jsonl(combined_output, [metadata, *all_qa_pairs])

        manifest_file = self.qa_dir / 'qa_generation_manifest.json'
        with open(manifest_file, 'w', encoding='utf-8') as f:
            json.dump(self.manifest, f, indent=2)

        print("\nQ&A Generation Summary:")
        print(f"Total sources processed: {len(sources_processed)}")
        print(f"Total Q&A pairs generated: {len(all_qa_pairs)}")
        print(f"Individual source files saved in: {self.qa_dir}")
        print(f"Combined Q&A pairs saved as: {combined_output}")
        print(f"Provenance data saved as: {manifest_file}")
if __name__ == "__main__":
    # Script entry point: constructing the generator prepares the output
    # directory, then every raw extraction is converted in one pass.
    QAGenerator().process_all_sources()