"""Multi-stage text retrieval pipeline (BEIR 'nq'): FAISS candidate
retrieval -> cross-encoder reranking -> NDCG@10 evaluation, driven by a
small Streamlit UI."""

import os

import faiss
import streamlit as st
import torch
from beir import util
from beir.datasets.data_loader import GenericDataLoader
# EvaluateRetrieval lives in beir.retrieval.evaluation, not the top-level package.
from beir.retrieval.evaluation import EvaluateRetrieval
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def load_dataset():
    """Download (if missing) and load the BEIR 'nq' test split.

    Returns:
        (corpus, queries, qrels) — corpus: {pid: {"title", "text"}},
        queries: {qid: text}, qrels: {qid: {pid: relevance}}.
    """
    dataset_name = "nq"
    data_dir = os.path.join("datasets", dataset_name)
    if not os.path.isdir(data_dir):
        url = (
            "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/"
            f"datasets/{dataset_name}.zip"
        )
        # download_and_unzip returns the path of the *extracted* folder;
        # GenericDataLoader needs that folder, not the zip file.
        data_dir = util.download_and_unzip(url, "datasets")
    corpus, queries, qrels = GenericDataLoader(data_dir).load(split="test")
    return corpus, queries, qrels


def candidate_retrieval(corpus, queries, top_k=10):
    """Stage 1: dense retrieval of top_k candidates per query via FAISS.

    Args:
        corpus: {pid: {"text": ...}} passage collection.
        queries: {qid: query text}.
        top_k: number of candidates to retrieve per query (default 10,
            matching the original behavior).

    Returns:
        (retrieved_indices, corpus_ids) — retrieved_indices[i] holds row
        indices into corpus_ids for the i-th query (queries iteration order).
    """
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    corpus_ids = list(corpus.keys())
    corpus_texts = [corpus[pid]["text"] for pid in corpus_ids]
    corpus_embeddings = embed_model.encode(corpus_texts, convert_to_numpy=True)

    # Exact L2 index; fine for the candidate-generation stage.
    index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)

    query_texts = list(queries.values())
    query_embeddings = embed_model.encode(query_texts, convert_to_numpy=True)
    _, retrieved_indices = index.search(query_embeddings, top_k)
    return retrieved_indices, corpus_ids


def rerank_passages(retrieved_indices, corpus, queries, corpus_ids):
    """Stage 2: rerank each query's candidates with a cross-encoder.

    Args:
        retrieved_indices: per-query candidate row indices from stage 1.
        corpus: passage collection.
        queries: {qid: query text}.
        corpus_ids: row-index -> passage-id mapping from stage 1.

    Returns:
        List (one entry per query, in queries iteration order) of passage
        IDs sorted by descending cross-encoder relevance score.
    """
    model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    cross_encoder = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cross_encoder.eval()  # inference only — disable dropout etc.

    reranked_ids = []
    for i, query in enumerate(queries.values()):
        candidate_pids = [corpus_ids[idx] for idx in retrieved_indices[i]]
        pairs = [(query, corpus[pid]["text"]) for pid in candidate_pids]
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # no gradients needed for scoring
            scores = cross_encoder(**inputs).logits.squeeze(-1).tolist()
        # Keep passage IDs (not raw text) so evaluation needs no text matching.
        ranked = [
            pid
            for _, pid in sorted(zip(scores, candidate_pids), key=lambda p: p[0], reverse=True)
        ]
        reranked_ids.append(ranked)
    return reranked_ids


def evaluate(qrels, retrieved_indices, reranked_ids, queries, corpus_ids):
    """Compute NDCG@10 for both pipeline stages.

    Scores are rank-descending (top hit gets the largest score) so that
    NDCG actually reflects each stage's ordering.

    Returns:
        (ndcg_stage1, ndcg_stage2) floats.
    """
    evaluator = EvaluateRetrieval()
    query_ids = list(queries.keys())

    def _rank_scores(ranked_pids):
        # Higher score = better rank; len..1 preserves the given ordering.
        n = len(ranked_pids)
        return {pid: n - rank for rank, pid in enumerate(ranked_pids)}

    results_stage1 = {
        qid: _rank_scores([corpus_ids[idx] for idx in retrieved_indices[i]])
        for i, qid in enumerate(query_ids)
    }
    results_stage2 = {
        qid: _rank_scores(reranked_ids[i]) for i, qid in enumerate(query_ids)
    }

    # EvaluateRetrieval.evaluate returns (ndcg, map, recall, precision) dicts.
    ndcg1, _, _, _ = evaluator.evaluate(qrels, results_stage1, [10])
    ndcg2, _, _, _ = evaluator.evaluate(qrels, results_stage2, [10])
    return ndcg1["NDCG@10"], ndcg2["NDCG@10"]


def main():
    """Streamlit UI. State is kept in st.session_state because every
    button press reruns the whole script, discarding locals."""
    st.title("Multi-Stage Text Retrieval Pipeline")
    state = st.session_state

    if st.button("Load Dataset"):
        state.corpus, state.queries, state.qrels = load_dataset()
        st.success("Dataset loaded successfully!")

    if st.button("Run Candidate Retrieval"):
        if "corpus" not in state:
            st.error("Load the dataset first.")
        else:
            state.retrieved_indices, state.corpus_ids = candidate_retrieval(
                state.corpus, state.queries
            )
            st.success("Candidate retrieval completed!")
            st.write("Retrieved indices:", state.retrieved_indices)

    if st.button("Run Reranking"):
        if "retrieved_indices" not in state:
            st.error("Run candidate retrieval first.")
        else:
            state.reranked_ids = rerank_passages(
                state.retrieved_indices, state.corpus, state.queries, state.corpus_ids
            )
            st.success("Reranking completed!")
            st.write("Reranked passage IDs:", state.reranked_ids)

    if st.button("Evaluate"):
        if "reranked_ids" not in state:
            st.error("Run reranking first.")
        else:
            ndcg_stage1, ndcg_stage2 = evaluate(
                state.qrels,
                state.retrieved_indices,
                state.reranked_ids,
                state.queries,
                state.corpus_ids,
            )
            st.write(f"NDCG@10 for Stage 1 (Candidate Retrieval): {ndcg_stage1}")
            st.write(f"NDCG@10 for Stage 2 (Reranking): {ndcg_stage2}")


if __name__ == "__main__":
    main()