# generator.py — Bloomington, Indiana QA-pair dataset generator.
# (Hugging Face upload header preserved: krishna3103, "Upload 8 files", commit 0baf78e verified)
import json
import time
import logging
import re
from datetime import datetime
from typing import Dict, List, Tuple
import google.generativeai as genai
from tqdm import tqdm
import pandas as pd
from config import (
GEMINI_API_KEY, GEMINI_RATE_LIMIT, PAIRS_PER_PROMPT,
TARGET_QA_PAIRS, PROCESSED_DIR, FINAL_DIR, LOG_DIR
)
class QAPairGenerator:
    """Generates instruction/response QA pairs about Bloomington via Gemini."""

    def __init__(self):
        """Configure the Gemini client, file logging, and generation state."""
        # Gemini client, authenticated with the project key from config.
        genai.configure(api_key=GEMINI_API_KEY)
        self.model = genai.GenerativeModel('gemini-1.5-flash')

        # Route all log output to a timestamped file under LOG_DIR.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        logging.basicConfig(
            filename=LOG_DIR / f"generator_{stamp}.log",
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
        )

        # Accumulated results and bookkeeping.
        self.generated_pairs = []
        self.failed_generations = []
        self.total_pairs_generated = 0  # running count toward TARGET_QA_PAIRS
def _generate_qa_batch(self, content: str, category: str) -> List[Dict]:
"""Generate a batch of QA pairs from content using regex parsing"""
prompt = f"""
Based on the following content about Bloomington, Indiana, generate {PAIRS_PER_PROMPT} different instruction-response pairs.
The content is related to the category: {category}
Focus on creating specific, practical questions that tourists might ask, with detailed, actionable responses.
Include relevant details like operating hours, costs, locations, and tips when applicable.
Format your response EXACTLY as a JSON array with each object containing "instruction", "response", and "category" fields.
Example format:
[
{{
"instruction": "What are the peak times to visit the Sample Gardens and how much does it cost?",
"response": "Sample Gardens is busiest during weekends and holidays. Admission is $10 for adults, $5 for children (5-12), and free for children under 5. To avoid crowds, visit on weekday mornings between 9-11am. Free parking is available.",
"category": "attractions"
}}
]
Content: {content}
"""
try:
response = self.model.generate_content(prompt)
response_text = response.text.strip()
# Try regex parsing first
try:
# Pattern to match the entire JSON array
array_pattern = r'\[\s*(\{[^]]*\})\s*(?:,\s*(\{[^]]*\})\s*)*\]'
array_match = re.search(array_pattern, response_text, re.DOTALL)
if array_match:
json_str = array_match.group(0)
# Additional regex to validate individual objects
object_pattern = r'\{\s*"instruction":\s*"([^"]*)",\s*"response":\s*"([^"]*)",\s*"category":\s*"([^"]*)"\s*\}'
objects = re.finditer(object_pattern, json_str)
valid_pairs = []
for obj_match in objects:
instruction = obj_match.group(1)
response = obj_match.group(2)
obj_category = obj_match.group(3)
# Validate lengths
if len(instruction) >= 20 and len(response) >= 50:
valid_pairs.append({
'instruction': instruction,
'response': response,
'category': category # Use the passed category instead of the one in response
})
else:
logging.warning(f"Pair rejected due to length requirements: Q: {len(instruction)} chars, A: {len(response)} chars")
if valid_pairs:
return valid_pairs
logging.warning("Regex parsing failed, attempting JSON parsing as fallback")
except Exception as regex_error:
logging.warning(f"Regex parsing error: {str(regex_error)}")
# Fallback to JSON parsing
try:
# Find the first '[' and last ']' to extract JSON array
start_idx = response_text.find('[')
end_idx = response_text.rfind(']') + 1
if start_idx != -1 and end_idx > start_idx:
json_str = response_text[start_idx:end_idx]
pairs = json.loads(json_str)
else:
pairs = json.loads(response_text)
# Validate pairs
valid_pairs = []
for pair in pairs:
if (isinstance(pair, dict) and
'instruction' in pair and
'response' in pair and
isinstance(pair['instruction'], str) and
isinstance(pair['response'], str) and
len(pair['instruction']) >= 20 and
len(pair['response']) >= 50):
pair['category'] = category
valid_pairs.append(pair)
else:
logging.warning(f"Invalid pair structure or length: {pair}")
return valid_pairs
except json.JSONDecodeError as json_error:
logging.error(f"JSON parsing error: {str(json_error)}\nResponse text: {response_text}")
return []
except Exception as e:
logging.error(f"Error in QA pair generation: {str(e)}")
self.failed_generations.append({
'content': content,
'category': category,
'error': str(e),
'response_text': response_text if 'response_text' in locals() else None,
'timestamp': datetime.now().isoformat()
})
return []
def generate_pairs_for_category(self, category: str) -> List[Dict]:
"""Generate QA pairs for a specific category"""
input_file = PROCESSED_DIR / f"{category}_processed.json"
try:
with open(input_file, 'r') as f:
processed_data = json.load(f)
except Exception as e:
logging.error(f"Error loading {input_file}: {e}")
return []
category_pairs = []
# Calculate remaining pairs needed
remaining_pairs = TARGET_QA_PAIRS - self.total_pairs_generated
if remaining_pairs <= 0:
logging.info("Target number of QA pairs reached")
return []
for item in tqdm(processed_data, desc=f"Generating pairs for {category}"):
if self.total_pairs_generated >= TARGET_QA_PAIRS:
logging.info(f"Target of {TARGET_QA_PAIRS} pairs reached. Stopping generation.")
break
# Combine all available content
content = f"{item['title']} {item['snippet']}"
if 'additional_content' in item:
content += f" {item['additional_content']}"
pairs = self._generate_qa_batch(content, category)
# Only take as many pairs as needed
pairs_needed = min(len(pairs), TARGET_QA_PAIRS - self.total_pairs_generated)
valid_pairs = pairs[:pairs_needed]
category_pairs.extend(valid_pairs)
self.total_pairs_generated += len(valid_pairs)
logging.info(f"Progress: {self.total_pairs_generated}/{TARGET_QA_PAIRS} pairs")
if self.total_pairs_generated >= TARGET_QA_PAIRS:
break
time.sleep(60/GEMINI_RATE_LIMIT) # Respect rate limit
# Save category pairs
output_file = FINAL_DIR / f"{category}_qa_pairs.json"
with open(output_file, 'w') as f:
json.dump(category_pairs, f, indent=2)
return category_pairs
def generate_all_pairs(self) -> None:
"""Generate QA pairs for all categories until target is reached"""
categories = [f.stem.replace('_processed', '')
for f in PROCESSED_DIR.glob('*_processed.json')]
all_pairs = []
# Keep generating pairs until we reach the target
while self.total_pairs_generated < TARGET_QA_PAIRS and categories:
for category in categories[:]: # Create a copy to modify safely
if self.total_pairs_generated >= TARGET_QA_PAIRS:
break
logging.info(f"Starting generation for category: {category}")
category_pairs = self.generate_pairs_for_category(category)
if not category_pairs: # If no more pairs can be generated for this category
categories.remove(category)
continue
all_pairs.extend(category_pairs)
logging.info(f"Generated {len(category_pairs)} pairs for {category}")
self._save_progress(all_pairs)
if self.total_pairs_generated >= TARGET_QA_PAIRS:
break
# Check if we need to continue
if self.total_pairs_generated < TARGET_QA_PAIRS and not categories:
logging.warning(f"Exhausted all categories. Generated {self.total_pairs_generated}/{TARGET_QA_PAIRS} pairs")
break
# Save final results
self._save_final_results(all_pairs)
def _save_progress(self, pairs: List[Dict]) -> None:
"""Save intermediate progress"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
progress_file = FINAL_DIR / f"qa_pairs_progress_{timestamp}.json"
progress_data = {
'pairs': pairs,
'stats': {
'total_pairs': len(pairs),
'target_pairs': TARGET_QA_PAIRS,
'completion_percentage': (self.total_pairs_generated / TARGET_QA_PAIRS) * 100,
'pairs_per_category': pd.DataFrame(pairs)['category'].value_counts().to_dict(),
'timestamp': timestamp
}
}
with open(progress_file, 'w') as f:
json.dump(progress_data, f, indent=2)
def _save_final_results(self, pairs: List[Dict]) -> None:
"""Save final results and statistics"""
final_file = FINAL_DIR / "final_qa_pairs.json"
final_data = {
'pairs': pairs,
'stats': {
'total_pairs_generated': self.total_pairs_generated,
'target_pairs': TARGET_QA_PAIRS,
'completion_percentage': (self.total_pairs_generated / TARGET_QA_PAIRS) * 100,
'pairs_per_category': pd.DataFrame(pairs)['category'].value_counts().to_dict(),
'avg_instruction_length': pd.DataFrame(pairs)['instruction'].str.len().mean(),
'avg_response_length': pd.DataFrame(pairs)['response'].str.len().mean(),
'failed_generations': len(self.failed_generations),
'completion_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
}
# Save main results
with open(final_file, 'w') as f:
json.dump(final_data, f, indent=2)
# Save as CSV
df = pd.DataFrame(pairs)
df.to_csv(FINAL_DIR / "final_qa_pairs.csv", index=False)
# Save failed generations for analysis
if self.failed_generations:
with open(FINAL_DIR / "failed_generations.json", 'w') as f:
json.dump(self.failed_generations, f, indent=2)
logging.info(f"""
Generation completed:
- Total pairs generated: {self.total_pairs_generated}
- Target pairs: {TARGET_QA_PAIRS}
- Categories used: {len(set(pair['category'] for pair in pairs))}
""")