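"""Hogragger: a hybrid retrieval question-answering pipeline.

Dense retrieval (SentenceTransformer embeddings in a FAISS index) is combined with
lexical BM25 retrieval; the question type is predicted with a sequence classifier,
and an extractive QA model produces the final answer together with an evidence list.
"""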
import json
import re
import nltk
from nltk.tokenize import sent_tokenize
import torch
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from rank_bm25 import BM25Okapi  # BM25 for hybrid search
import logging
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # also required by sent_tokenize on newer NLTK releases

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class Hogragger:
    def __init__(self, corpus_path, model_name='sentence-transformers/all-MiniLM-L12-v2',
                 qa_model='deepset/roberta-large-squad2', classifier_model='deepset/roberta-large-squad2'):
        self.corpus = self.load_corpus(corpus_path)
        self.cleaned_passages = self.preprocess_corpus()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info(f"Using device: {self.device}")

        # Initialize embedding model and build FAISS index
        self.model = SentenceTransformer(model_name).to(self.device)
        self.index = self.build_faiss_index()

        # Initialize BM25 for lexical matching
        self.bm25 = self.build_bm25_index()

        # Initialize classifier for question type prediction. Note: the default
        # checkpoint is a QA model, so the 5-way classification head is randomly
        # initialized unless a fine-tuned classifier checkpoint is supplied.
        self.tokenizer = AutoTokenizer.from_pretrained(classifier_model)
        self.classifier = AutoModelForSequenceClassification.from_pretrained(
            classifier_model, num_labels=5).to(self.device)

        # QA model
        self.qa_model = pipeline('question-answering', model=qa_model,
                                 device=0 if self.device == 'cuda' else -1)

    def load_corpus(self, path):
        logging.info(f"Loading corpus from {path}")
        with open(path, "r") as f:
            corpus = json.load(f)
        logging.info(f"Loaded {len(corpus)} documents")
        return corpus

    # Alternative chunking version (disabled): splits each article into passages
    # of at most ~300 words on sentence boundaries.
    # def preprocess_corpus(self):
    #     cleaned_passages = []
    #     for article in self.corpus:
    #         body = article.get('body', '')
    #         clean_body = re.sub(r'<.*?>', '', body)  # Clean HTML tags
    #         clean_body = re.sub(r'\s+', ' ', clean_body).strip()  # Clean extra spaces
    #         sentences = sent_tokenize(clean_body)
    #         chunk = ""
    #         for sentence in sentences:
    #             if len(chunk.split()) + len(sentence.split()) <= 300:
    #                 chunk += " " + sentence
    #             else:
    #                 cleaned_passages.append(self.create_passage(article, chunk))
    #                 chunk = sentence
    #         if chunk:
    #             cleaned_passages.append(self.create_passage(article, chunk))
    #     logging.info(f"Created {len(cleaned_passages)} passages")
    #     return cleaned_passages

    def preprocess_corpus(self):
        cleaned_passages = []
        for article in self.corpus:
            body = article.get('body', '')
            clean_body = re.sub(r'<.*?>', '', body)  # Clean HTML tags
            clean_body = re.sub(r'\s+', ' ', clean_body).strip()  # Clean extra spaces
            # Take the full cleaned text as a single passage, without chunking or sentence splitting
            cleaned_passages.append(self.create_passage(article, clean_body))
        logging.info(f"Created {len(cleaned_passages)} passages")
        return cleaned_passages

    def create_passage(self, article, chunk):
        """Creates a passage dictionary from an article and a chunk of its text."""
        return {
            "title": article['title'],
            "author": article.get('author', 'Unknown'),
            "published_at": article['published_at'],
            "category": article['category'],
            "url": article['url'],
            "source": article['source'],
            "passage": chunk.strip()
        }

    def build_faiss_index(self):
        logging.info("Building FAISS index...")
        embeddings = self.model.encode([p['passage'] for p in self.cleaned_passages],
                                       convert_to_tensor=True, device=self.device)
        embeddings = np.array(embeddings.cpu()).astype('float32')
        logging.info(f"Shape of embeddings: {embeddings.shape}")
        index = faiss.IndexFlatL2(embeddings.shape[1])  # Initialize FAISS index
        if self.device == 'cuda':
            try:
                res = faiss.StandardGpuResources()
                gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
                gpu_index.add(embeddings)
                logging.info("Successfully created GPU index")
                return gpu_index
            except (AttributeError, RuntimeError) as e:
                # AttributeError covers CPU-only faiss builds without GPU support
                logging.error(f"GPU index creation failed: {e}")
                logging.info("Falling back to CPU index")
        index.add(embeddings)  # Add embeddings to CPU index
        logging.info("Successfully created CPU index")
        return index

    def build_bm25_index(self):
        logging.info("Building BM25 index...")
        tokenized_corpus = [p['passage'].split() for p in self.cleaned_passages]
        bm25 = BM25Okapi(tokenized_corpus)
        logging.info("Successfully built BM25 index")
        return bm25

    def predict_question_type(self, query):
        inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
        with torch.no_grad():  # Inference only; no gradients needed
            outputs = self.classifier(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()
        labels = {0: 'inference_query', 1: 'comparison_query', 2: 'null_query',
                  3: 'temporal_query', 4: 'fact_query'}
        return labels.get(prediction, 'unknown_query')

    def retrieve_passages(self, query, k=100, threshold=0.7):
        try:
            # FAISS retrieval
            k = min(k, len(self.cleaned_passages))  # FAISS returns -1 indices when k exceeds the corpus size
            query_embedding = self.model.encode([query], convert_to_tensor=True, device=self.device)
            D, I = self.index.search(np.array(query_embedding.cpu()).astype('float32'), k)
            # BM25 retrieval
            tokenized_query = query.split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            # Combine FAISS and BM25 results
            hybrid_scores = self.combine_faiss_bm25_scores(D[0], bm25_scores, I[0])
            # Filter passages based on hybrid score
            passages = [self.cleaned_passages[i] for i, score in zip(I[0], hybrid_scores) if score > threshold]
            logging.info(f"Retrieved {len(passages)} passages using hybrid search for query.")
            return passages
        except Exception as e:
            logging.error(f"Error in retrieving passages: {e}")
            return []

    def combine_faiss_bm25_scores(self, faiss_scores, bm25_scores, passage_indices):
        # Normalize and combine FAISS and BM25 scores
        bm25_scores = np.array(bm25_scores)[passage_indices]
        faiss_scores = np.array(faiss_scores)
        # Convert FAISS distances into similarities by inverting the scale
        faiss_similarities = 1 / (faiss_scores + 1e-6)  # Avoid division by zero
        # Normalize scores (scale between 0 and 1)
        bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores) + 1e-6)
        faiss_similarities = (faiss_similarities - np.min(faiss_similarities)) / (np.max(faiss_similarities) - np.min(faiss_similarities) + 1e-6)
        # Weighted combination (weights can be adjusted)
        combined_scores = 0.7 * faiss_similarities + 0.3 * bm25_scores
        combined_scores = np.squeeze(combined_scores)  # Ensure a one-dimensional array
        return combined_scores

    def filter_passages(self, query, passages):
        try:
            query_embedding = self.model.encode(query, convert_to_tensor=True)
            passage_embeddings = self.model.encode([p['passage'] for p in passages], convert_to_tensor=True)
            similarities = util.pytorch_cos_sim(query_embedding, passage_embeddings)
            top_k = min(10, len(passages))
            top_indices = similarities.topk(k=top_k)[1].tolist()[0]
            # Keep at most one passage per article title to diversify the evidence
            selected_passages = []
            used_titles = set()
            for i in top_indices:
                if passages[i]['title'] not in used_titles:
                    selected_passages.append(passages[i])
                    used_titles.add(passages[i]['title'])
            return selected_passages
        except Exception as e:
            logging.error(f"Error in filtering passages: {e}")
            return []

    def generate_answer(self, query, passages):
        try:
            context = " ".join([p['passage'] for p in passages[:5]])
            answer = self.qa_model(question=query, context=context)
            logging.info(f"Generated answer: {answer['answer']}")
            # Return the answer text along with the model's confidence score
            return answer['answer'], answer['score']
        except Exception as e:
            logging.error(f"Error in generating answer: {e}")
            return "Insufficient information.", 0.0

    def post_process_answer(self, answer, confidence=0.2):
        answer = re.sub(r'^.*\?', '', answer).strip()  # Drop any echoed question text
        if answer:
            answer = answer[0].upper() + answer[1:]  # Capitalize without lowercasing the rest
        # Truncate overly long answers to their first complete sentence
        if len(answer) > 100:
            truncated = re.match(r'^(.*?[.!?])\s', answer)
            if truncated:
                answer = truncated.group(1)
        if confidence < 0.2:
            logging.warning(f"Answer confidence too low: {confidence}")
            return "I'm unsure about this answer."
        return answer

    def process_query(self, query):
        question_type = self.predict_question_type(query)
        retrieved_passages = self.retrieve_passages(query, k=100, threshold=0.7)
        if not retrieved_passages:
            return {"query": query, "answer": "No relevant information found",
                    "question_type": question_type, "evidence_list": []}
        filtered_passages = self.filter_passages(query, retrieved_passages)
        raw_answer, confidence = self.generate_answer(query, filtered_passages)
        evidence_count = min(len(filtered_passages), 4)
        evidence_list = [
            {
                "title": p['title'],
                "author": p['author'],
                "url": p['url'],
                "source": p['source'],
                "category": p['category'],
                "published_at": p['published_at'],
                "fact": self.extract_fact(p['passage'], query)
            } for p in filtered_passages[:evidence_count]
        ]
        final_answer = self.post_process_answer(raw_answer, confidence)
        return {
            "query": query,
            "answer": final_answer,
            "question_type": question_type,
            "evidence_list": evidence_list
        }

    def extract_fact(self, passage, query):
        # Extract the sentence from the passage with the largest keyword overlap with the query
        sentences = sent_tokenize(passage)
        query_keywords = set(query.lower().split())
        best_sentence = max(sentences, key=lambda s: len(set(s.lower().split()) & query_keywords), default="")
        return best_sentence if best_sentence else (sentences[0] if sentences else "")
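

# Minimal usage sketch. The corpus path and query below are illustrative placeholders,
# not part of the class above; the corpus is assumed to be a JSON list of article dicts
# with "title", "body", "published_at", "category", "url", and "source" fields, as
# consumed by create_passage().
if __name__ == "__main__":
    rag = Hogragger(corpus_path="corpus.json")
    result = rag.process_query("Which company announced the acquisition described in the article?")
    print(json.dumps(result, indent=2))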