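"""Hogragger: a hybrid retrieval-augmented question-answering pipeline.

Dense retrieval (SentenceTransformer embeddings + FAISS) is fused with lexical
BM25 scores, candidates are re-ranked by cosine similarity, and answers are
extracted with a question-answering model; a sequence classifier predicts the
question type.
"""
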
import json
import re
import nltk
from nltk.tokenize import sent_tokenize
import torch
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from rank_bm25 import BM25Okapi  # BM25 for hybrid search
import logging


nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by sent_tokenize on newer NLTK releases
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class Hogragger:
    def __init__(self, corpus_path, model_name='sentence-transformers/all-MiniLM-L12-v2', qa_model='deepset/roberta-large-squad2', classifier_model='deepset/roberta-large-squad2'):
        self.corpus = self.load_corpus(corpus_path)
        self.cleaned_passages = self.preprocess_corpus()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info(f"Using device: {self.device}")

        # Initialize embedding model and build FAISS index
        self.model = SentenceTransformer(model_name).to(self.device)
        self.index = self.build_faiss_index()

        # Initialize BM25 for lexical matching
        self.bm25 = self.build_bm25_index()

        # Initialize classifier for question-type prediction. Note: this
        # checkpoint is a QA model, so the sequence-classification head is
        # freshly initialized here (num_labels=5 to match the label map below)
        # and must be fine-tuned before its predictions are meaningful.
        self.tokenizer = AutoTokenizer.from_pretrained(classifier_model)
        self.classifier = AutoModelForSequenceClassification.from_pretrained(classifier_model, num_labels=5).to(self.device)

        # QA Model
        self.qa_model = pipeline('question-answering', model=qa_model, device=0 if self.device == 'cuda' else -1)

    def load_corpus(self, path):
        logging.info(f"Loading corpus from {path}")
        with open(path, "r") as f:
            corpus = json.load(f)
        logging.info(f"Loaded {len(corpus)} documents")
        return corpus

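    # Earlier chunking-based variant of preprocess_corpus, kept commented out
    # for reference: it split each article into ~300-word passages along
    # sentence boundaries.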
    # def preprocess_corpus(self):
    #     cleaned_passages = []
    #     for article in self.corpus:
    #         body = article.get('body', '')
    #         clean_body = re.sub(r'<.*?>', '', body)  # Clean HTML tags
    #         clean_body = re.sub(r'\s+', ' ', clean_body).strip()  # Clean extra spaces
    #         sentences = sent_tokenize(clean_body)

    #         chunk = ""
    #         for sentence in sentences:
    #             if len(chunk.split()) + len(sentence.split()) <= 300:
    #                 chunk += " " + sentence
    #             else:
    #                 cleaned_passages.append(self.create_passage(article, chunk))
    #                 chunk = sentence

    #         if chunk:
    #             cleaned_passages.append(self.create_passage(article, chunk))
    #     logging.info(f"Created {len(cleaned_passages)} passages")
    #     return cleaned_passages
    def preprocess_corpus(self):
        cleaned_passages = []
        for article in self.corpus:
            body = article.get('body', '')
            clean_body = re.sub(r'<.*?>', '', body)  # Clean HTML tags
            clean_body = re.sub(r'\s+', ' ', clean_body).strip()  # Clean extra spaces

            # Simply take the full cleaned text as a passage without chunking or sentence splitting
            cleaned_passages.append(self.create_passage(article, clean_body))

        logging.info(f"Created {len(cleaned_passages)} passages")
        return cleaned_passages

    def create_passage(self, article, chunk):
        """Creates a passage dictionary from an article and chunk of text."""
        return {
            "title": article['title'],
            "author": article.get('author', 'Unknown'),
            "published_at": article['published_at'],
            "category": article['category'],
            "url": article['url'],
            "source": article['source'],
            "passage": chunk.strip()
        }

    def build_faiss_index(self):
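        """Embed all passages and index them with FAISS, preferring the GPU."""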
        logging.info("Building FAISS index...")
        embeddings = self.model.encode([p['passage'] for p in self.cleaned_passages], convert_to_tensor=True, device=self.device)
        embeddings = np.array(embeddings.cpu()).astype('float32')
        logging.info(f"Shape of embeddings: {embeddings.shape}")

        index = faiss.IndexFlatL2(embeddings.shape[1])  # Initialize FAISS index

        if self.device == 'cuda':
            try:
                res = faiss.StandardGpuResources()
                gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
                gpu_index.add(embeddings)
                logging.info("Successfully created GPU index")
                return gpu_index
            except RuntimeError as e:
                logging.error(f"GPU index creation failed: {e}")
                logging.info("Falling back to CPU index")

        index.add(embeddings)  # Add embeddings to CPU index
        logging.info("Successfully created CPU index")
        return index

    def build_bm25_index(self):
        logging.info("Building BM25 index...")
        tokenized_corpus = [p['passage'].split() for p in self.cleaned_passages]
        bm25 = BM25Okapi(tokenized_corpus)
        logging.info("Successfully built BM25 index")
        return bm25

    def predict_question_type(self, query):
        inputs = self.tokenizer(query, return_tensors='pt', truncation=True).to(self.device)
        with torch.no_grad():  # inference only; no gradients needed
            outputs = self.classifier(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()

        labels = {0: 'inference_query', 1: 'comparison_query', 2: 'null_query', 3: 'temporal_query', 4: 'fact_query'}
        return labels.get(prediction, 'unknown_query')

    def retrieve_passages(self, query, k=100, threshold=0.7):
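        """Hybrid retrieval: FAISS dense search fused with BM25 lexical scores."""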
        try:
            # FAISS retrieval
            query_embedding = self.model.encode([query], convert_to_tensor=True, device=self.device)
            D, I = self.index.search(np.array(query_embedding.cpu()), k)

            # BM25 retrieval
            tokenized_query = query.split()
            bm25_scores = self.bm25.get_scores(tokenized_query)

            # Combine FAISS and BM25 results; pass the flat index row (I[0]),
            # not the 2-D array, so downstream indexing stays one-dimensional
            hybrid_scores = self.combine_faiss_bm25_scores(D[0], bm25_scores, I[0])

            # Keep passages whose hybrid score clears the threshold
            passages = [self.cleaned_passages[i] for i, score in zip(I[0], hybrid_scores) if score > threshold]

            logging.info(f"Retrieved {len(passages)} passages using hybrid search for query.")
            return passages
        except Exception as e:
            logging.error(f"Error in retrieving passages: {e}")
            return []

    def combine_faiss_bm25_scores(self, faiss_scores, bm25_scores, passage_indices):
        # Normalize and combine FAISS and BM25 scores for the retrieved passages
        bm25_scores = np.array(bm25_scores)[passage_indices]
        faiss_scores = np.array(faiss_scores)

        # Convert FAISS L2 distances into similarities (smaller distance -> higher score)
        faiss_similarities = 1 / (faiss_scores + 1e-6)  # Avoid division by zero

        # Min-max normalize both score sets to [0, 1]
        bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores) + 1e-6)
        faiss_similarities = (faiss_similarities - np.min(faiss_similarities)) / (np.max(faiss_similarities) - np.min(faiss_similarities) + 1e-6)

        # Weighted combination (weights are tunable)
        return 0.7 * faiss_similarities + 0.3 * bm25_scores

    def filter_passages(self, query, passages):
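        """Re-rank retrieved passages by cosine similarity and deduplicate by title."""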
        try:
            query_embedding = self.model.encode(query, convert_to_tensor=True)
            passage_embeddings = self.model.encode([p['passage'] for p in passages], convert_to_tensor=True)

            similarities = util.pytorch_cos_sim(query_embedding, passage_embeddings)
            top_k = min(10, len(passages))
            top_indices = similarities.topk(k=top_k)[1].tolist()[0]

            selected_passages = []
            used_titles = set()
            for i in top_indices:
                if passages[i]['title'] not in used_titles:
                    selected_passages.append(passages[i])
                    used_titles.add(passages[i]['title'])

            return selected_passages
        except Exception as e:
            logging.error(f"Error in filtering passages: {e}")
            return []

    def generate_answer(self, query, passages):
        try:
            context = " ".join([p['passage'] for p in passages[:5]])
            answer = self.qa_model(question=query, context=context)
            logging.info(f"Generated answer: {answer['answer']} (score: {answer['score']:.3f})")
            return answer['answer'], answer['score']
        except Exception as e:
            logging.error(f"Error in generating answer: {e}")
            return "Insufficient information.", 0.0

    def post_process_answer(self, answer, confidence=1.0):
        # Strip any echoed question text that precedes the answer
        answer = re.sub(r'^.*\?', '', answer).strip()
        # Capitalize only the first character; str.capitalize() would lowercase
        # the rest of the string and mangle proper nouns
        answer = answer[:1].upper() + answer[1:]

        # Truncate overly long answers at the first sentence boundary
        if len(answer) > 100:
            truncated = re.match(r'^(.*?[.!?])\s', answer)
            if truncated:
                answer = truncated.group(1)

        if confidence < 0.2:
            logging.warning(f"Answer confidence too low: {confidence}")
            return "I'm unsure about this answer."

        return answer

    def process_query(self, query):
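        """End-to-end pipeline: classify the query, retrieve and filter passages, answer, and package evidence."""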
        question_type = self.predict_question_type(query)
        retrieved_passages = self.retrieve_passages(query, k=100, threshold=0.7)
        if not retrieved_passages:
            return {"query": query, "answer": "No relevant information found", "question_type": question_type, "evidence_list": []}

        filtered_passages = self.filter_passages(query, retrieved_passages)
        raw_answer, answer_score = self.generate_answer(query, filtered_passages)

        evidence_count = min(len(filtered_passages), 4)
        evidence_list = [
            {
                "title": p['title'],
                "author": p['author'],
                "url": p['url'],
                "source": p['source'],
                "category": p['category'],
                "published_at": p['published_at'],
                "fact": self.extract_fact(p['passage'], query)
            } for p in filtered_passages[:evidence_count]
        ]
        final_answer = self.post_process_answer(raw_answer, confidence=answer_score)

        return {
            "query": query,
            "answer": final_answer,
            "question_type": question_type,
            "evidence_list": evidence_list
        }

    def extract_fact(self, passage, query):
        # Extracting most relevant sentence from passage
        sentences = sent_tokenize(passage)
        query_keywords = set(query.lower().split())

        best_sentence = max(sentences, key=lambda s: len(set(s.lower().split()) & query_keywords), default="")
        return best_sentence if best_sentence else (sentences[0] if sentences else "")
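

if __name__ == "__main__":
    # Minimal usage sketch, assuming a local corpus file named `corpus.json`
    # whose articles carry the fields read in create_passage (title, author,
    # published_at, category, url, source, body). The path and the query are
    # illustrative placeholders, not part of the original module.
    rag = Hogragger("corpus.json")
    result = rag.process_query("When was the acquisition announced?")
    print(json.dumps(result, indent=2))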