|
import json
|
|
import time
|
|
import logging
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Dict, List, Tuple
|
|
import google.generativeai as genai
|
|
from tqdm import tqdm
|
|
import pandas as pd
|
|
|
|
from config import (
|
|
GEMINI_API_KEY, GEMINI_RATE_LIMIT, PAIRS_PER_PROMPT,
|
|
TARGET_QA_PAIRS, PROCESSED_DIR, FINAL_DIR, LOG_DIR
|
|
)
|
|
|
|
class QAPairGenerator:
    """Generate instruction-response QA pairs about Bloomington, Indiana
    from pre-processed content files, using the Gemini API."""

    def __init__(self):
        # Configure the Gemini client and select the generation model.
        genai.configure(api_key=GEMINI_API_KEY)
        self.model = genai.GenerativeModel('gemini-1.5-flash')

        # Route all log output to a per-run, timestamped file.
        log_file = LOG_DIR / f"generator_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        logging.basicConfig(
            filename=log_file,
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=logging.INFO,
        )

        # Run bookkeeping: accumulated pairs, recorded failures, and a
        # global counter checked against TARGET_QA_PAIRS.
        self.generated_pairs = []
        self.failed_generations = []
        self.total_pairs_generated = 0
|
|
|
|
def _generate_qa_batch(self, content: str, category: str) -> List[Dict]:
|
|
"""Generate a batch of QA pairs from content using regex parsing"""
|
|
prompt = f"""
|
|
Based on the following content about Bloomington, Indiana, generate {PAIRS_PER_PROMPT} different instruction-response pairs.
|
|
The content is related to the category: {category}
|
|
|
|
Focus on creating specific, practical questions that tourists might ask, with detailed, actionable responses.
|
|
Include relevant details like operating hours, costs, locations, and tips when applicable.
|
|
|
|
Format your response EXACTLY as a JSON array with each object containing "instruction", "response", and "category" fields.
|
|
|
|
Example format:
|
|
[
|
|
{{
|
|
"instruction": "What are the peak times to visit the Sample Gardens and how much does it cost?",
|
|
"response": "Sample Gardens is busiest during weekends and holidays. Admission is $10 for adults, $5 for children (5-12), and free for children under 5. To avoid crowds, visit on weekday mornings between 9-11am. Free parking is available.",
|
|
"category": "attractions"
|
|
}}
|
|
]
|
|
|
|
Content: {content}
|
|
"""
|
|
|
|
try:
|
|
response = self.model.generate_content(prompt)
|
|
response_text = response.text.strip()
|
|
|
|
|
|
try:
|
|
|
|
array_pattern = r'\[\s*(\{[^]]*\})\s*(?:,\s*(\{[^]]*\})\s*)*\]'
|
|
array_match = re.search(array_pattern, response_text, re.DOTALL)
|
|
|
|
if array_match:
|
|
json_str = array_match.group(0)
|
|
|
|
|
|
object_pattern = r'\{\s*"instruction":\s*"([^"]*)",\s*"response":\s*"([^"]*)",\s*"category":\s*"([^"]*)"\s*\}'
|
|
objects = re.finditer(object_pattern, json_str)
|
|
|
|
valid_pairs = []
|
|
for obj_match in objects:
|
|
instruction = obj_match.group(1)
|
|
response = obj_match.group(2)
|
|
obj_category = obj_match.group(3)
|
|
|
|
|
|
if len(instruction) >= 20 and len(response) >= 50:
|
|
valid_pairs.append({
|
|
'instruction': instruction,
|
|
'response': response,
|
|
'category': category
|
|
})
|
|
else:
|
|
logging.warning(f"Pair rejected due to length requirements: Q: {len(instruction)} chars, A: {len(response)} chars")
|
|
|
|
if valid_pairs:
|
|
return valid_pairs
|
|
|
|
logging.warning("Regex parsing failed, attempting JSON parsing as fallback")
|
|
|
|
except Exception as regex_error:
|
|
logging.warning(f"Regex parsing error: {str(regex_error)}")
|
|
|
|
|
|
try:
|
|
|
|
start_idx = response_text.find('[')
|
|
end_idx = response_text.rfind(']') + 1
|
|
|
|
if start_idx != -1 and end_idx > start_idx:
|
|
json_str = response_text[start_idx:end_idx]
|
|
pairs = json.loads(json_str)
|
|
else:
|
|
pairs = json.loads(response_text)
|
|
|
|
|
|
valid_pairs = []
|
|
for pair in pairs:
|
|
if (isinstance(pair, dict) and
|
|
'instruction' in pair and
|
|
'response' in pair and
|
|
isinstance(pair['instruction'], str) and
|
|
isinstance(pair['response'], str) and
|
|
len(pair['instruction']) >= 20 and
|
|
len(pair['response']) >= 50):
|
|
|
|
pair['category'] = category
|
|
valid_pairs.append(pair)
|
|
else:
|
|
logging.warning(f"Invalid pair structure or length: {pair}")
|
|
|
|
return valid_pairs
|
|
|
|
except json.JSONDecodeError as json_error:
|
|
logging.error(f"JSON parsing error: {str(json_error)}\nResponse text: {response_text}")
|
|
return []
|
|
|
|
except Exception as e:
|
|
logging.error(f"Error in QA pair generation: {str(e)}")
|
|
self.failed_generations.append({
|
|
'content': content,
|
|
'category': category,
|
|
'error': str(e),
|
|
'response_text': response_text if 'response_text' in locals() else None,
|
|
'timestamp': datetime.now().isoformat()
|
|
})
|
|
return []
|
|
|
|
def generate_pairs_for_category(self, category: str) -> List[Dict]:
|
|
"""Generate QA pairs for a specific category"""
|
|
input_file = PROCESSED_DIR / f"{category}_processed.json"
|
|
|
|
try:
|
|
with open(input_file, 'r') as f:
|
|
processed_data = json.load(f)
|
|
except Exception as e:
|
|
logging.error(f"Error loading {input_file}: {e}")
|
|
return []
|
|
|
|
category_pairs = []
|
|
|
|
|
|
remaining_pairs = TARGET_QA_PAIRS - self.total_pairs_generated
|
|
|
|
if remaining_pairs <= 0:
|
|
logging.info("Target number of QA pairs reached")
|
|
return []
|
|
|
|
for item in tqdm(processed_data, desc=f"Generating pairs for {category}"):
|
|
if self.total_pairs_generated >= TARGET_QA_PAIRS:
|
|
logging.info(f"Target of {TARGET_QA_PAIRS} pairs reached. Stopping generation.")
|
|
break
|
|
|
|
|
|
content = f"{item['title']} {item['snippet']}"
|
|
if 'additional_content' in item:
|
|
content += f" {item['additional_content']}"
|
|
|
|
pairs = self._generate_qa_batch(content, category)
|
|
|
|
|
|
pairs_needed = min(len(pairs), TARGET_QA_PAIRS - self.total_pairs_generated)
|
|
valid_pairs = pairs[:pairs_needed]
|
|
|
|
category_pairs.extend(valid_pairs)
|
|
self.total_pairs_generated += len(valid_pairs)
|
|
|
|
logging.info(f"Progress: {self.total_pairs_generated}/{TARGET_QA_PAIRS} pairs")
|
|
|
|
if self.total_pairs_generated >= TARGET_QA_PAIRS:
|
|
break
|
|
|
|
time.sleep(60/GEMINI_RATE_LIMIT)
|
|
|
|
|
|
output_file = FINAL_DIR / f"{category}_qa_pairs.json"
|
|
with open(output_file, 'w') as f:
|
|
json.dump(category_pairs, f, indent=2)
|
|
|
|
return category_pairs
|
|
|
|
def generate_all_pairs(self) -> None:
|
|
"""Generate QA pairs for all categories until target is reached"""
|
|
categories = [f.stem.replace('_processed', '')
|
|
for f in PROCESSED_DIR.glob('*_processed.json')]
|
|
|
|
all_pairs = []
|
|
|
|
|
|
while self.total_pairs_generated < TARGET_QA_PAIRS and categories:
|
|
for category in categories[:]:
|
|
if self.total_pairs_generated >= TARGET_QA_PAIRS:
|
|
break
|
|
|
|
logging.info(f"Starting generation for category: {category}")
|
|
category_pairs = self.generate_pairs_for_category(category)
|
|
|
|
if not category_pairs:
|
|
categories.remove(category)
|
|
continue
|
|
|
|
all_pairs.extend(category_pairs)
|
|
logging.info(f"Generated {len(category_pairs)} pairs for {category}")
|
|
self._save_progress(all_pairs)
|
|
|
|
if self.total_pairs_generated >= TARGET_QA_PAIRS:
|
|
break
|
|
|
|
|
|
if self.total_pairs_generated < TARGET_QA_PAIRS and not categories:
|
|
logging.warning(f"Exhausted all categories. Generated {self.total_pairs_generated}/{TARGET_QA_PAIRS} pairs")
|
|
break
|
|
|
|
|
|
self._save_final_results(all_pairs)
|
|
|
|
def _save_progress(self, pairs: List[Dict]) -> None:
|
|
"""Save intermediate progress"""
|
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
progress_file = FINAL_DIR / f"qa_pairs_progress_{timestamp}.json"
|
|
|
|
progress_data = {
|
|
'pairs': pairs,
|
|
'stats': {
|
|
'total_pairs': len(pairs),
|
|
'target_pairs': TARGET_QA_PAIRS,
|
|
'completion_percentage': (self.total_pairs_generated / TARGET_QA_PAIRS) * 100,
|
|
'pairs_per_category': pd.DataFrame(pairs)['category'].value_counts().to_dict(),
|
|
'timestamp': timestamp
|
|
}
|
|
}
|
|
|
|
with open(progress_file, 'w') as f:
|
|
json.dump(progress_data, f, indent=2)
|
|
|
|
def _save_final_results(self, pairs: List[Dict]) -> None:
|
|
"""Save final results and statistics"""
|
|
final_file = FINAL_DIR / "final_qa_pairs.json"
|
|
|
|
final_data = {
|
|
'pairs': pairs,
|
|
'stats': {
|
|
'total_pairs_generated': self.total_pairs_generated,
|
|
'target_pairs': TARGET_QA_PAIRS,
|
|
'completion_percentage': (self.total_pairs_generated / TARGET_QA_PAIRS) * 100,
|
|
'pairs_per_category': pd.DataFrame(pairs)['category'].value_counts().to_dict(),
|
|
'avg_instruction_length': pd.DataFrame(pairs)['instruction'].str.len().mean(),
|
|
'avg_response_length': pd.DataFrame(pairs)['response'].str.len().mean(),
|
|
'failed_generations': len(self.failed_generations),
|
|
'completion_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
}
|
|
}
|
|
|
|
|
|
with open(final_file, 'w') as f:
|
|
json.dump(final_data, f, indent=2)
|
|
|
|
|
|
df = pd.DataFrame(pairs)
|
|
df.to_csv(FINAL_DIR / "final_qa_pairs.csv", index=False)
|
|
|
|
|
|
if self.failed_generations:
|
|
with open(FINAL_DIR / "failed_generations.json", 'w') as f:
|
|
json.dump(self.failed_generations, f, indent=2)
|
|
|
|
logging.info(f"""
|
|
Generation completed:
|
|
- Total pairs generated: {self.total_pairs_generated}
|
|
- Target pairs: {TARGET_QA_PAIRS}
|
|
- Categories used: {len(set(pair['category'] for pair in pairs))}
|
|
""") |