Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- app.py +70 -0
- data_preparation.py +7 -0
- evaluation.py +5 -0
- requirements.txt +6 -0
- reranking.py +16 -0
- retrieval.py +13 -0
app.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
from data_preparation import load_dataset
from retrieval import load_embedding_model, retrieve_top_k
from reranking import load_ranking_model, rerank
from evaluation import evaluate_ndcg

# --- Streamlit UI for a two-stage (retrieve -> rerank) QA pipeline ---
st.title("Multi-Stage Text Retrieval Pipeline for QA")

# Query input
query = st.text_input("Enter a question:", "What is the capital of France?")

# Stage 1: bi-encoder used for candidate retrieval
embedding_model = st.selectbox(
    "Select Embedding Model for Candidate Retrieval",
    ["sentence-transformers/all-MiniLM-L6-v2", "nvidia/nv-embedqa-e5-v5"],
)

# Stage 2: cross-encoder used for re-ranking the candidates
ranking_model = st.selectbox(
    "Select Ranking Model for Re-Ranking",
    ["cross-encoder/ms-marco-MiniLM-L-12-v2", "nvidia/nv-rerankqa-mistral-4b-v3"],
)

# Run the whole pipeline on button click
if st.button("Run Retrieval"):
    st.write("Loading dataset...")
    corpus, queries, qrels = load_dataset("nq")

    st.write(f"Loading embedding model: {embedding_model}...")
    embed_model = load_embedding_model(embedding_model)

    st.write("Retrieving top-k passages...")
    top_k_passages = retrieve_top_k(embed_model, query, corpus, k=10)

    st.write("Top-k passages before reranking:")
    for i, (passage, score) in enumerate(top_k_passages):
        st.write(f"{i+1}. Passage: {passage}, Score: {score:.4f}")

    st.write(f"Loading ranking model: {ranking_model}...")
    rank_model, rank_tokenizer = load_ranking_model(ranking_model)

    st.write("Reranking passages...")
    ranked_passages = rerank(rank_model, rank_tokenizer, query, top_k_passages)

    st.write("Top-k passages after reranking:")
    for i, (passage, score) in enumerate(ranked_passages):
        st.write(f"{i+1}. Passage: {passage}, Score: {score:.4f}")

    # Evaluate NDCG@10 only when the typed question actually exists in the
    # dataset.  The previous code always used the qrels of the *first*
    # dataset query, so the reported score had nothing to do with the input.
    st.write("Evaluating NDCG@10...")
    matched_ids = [qid for qid, qtext in queries.items()
                   if qtext.strip().lower() == query.strip().lower()]
    if matched_ids and matched_ids[0] in qrels:
        ndcg_score = evaluate_ndcg(ranked_passages, qrels[matched_ids[0]])
        st.write(f"NDCG@10: {ndcg_score:.4f}")
    else:
        st.write("Query not found in the dataset qrels; skipping NDCG evaluation.")

# Sidebar with usage instructions
st.sidebar.title("Instructions")
st.sidebar.write("""
1. Enter a question in the text input.
2. Select the embedding model for candidate retrieval.
3. Select the ranking model for reranking the retrieved passages.
4. Click 'Run Retrieval' to start the pipeline and display the results.
""")
data_preparation.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from beir import util
from beir.datasets.data_loader import GenericDataLoader

def load_dataset(dataset_name="nq", split="test"):
    """Download a BEIR dataset (if not already present) and load one split.

    Args:
        dataset_name: BEIR dataset identifier, e.g. "nq".
        split: which split to load ("test", "dev", or "train", depending on
            what the dataset ships with); previously hard-coded to "test".

    Returns:
        (corpus, queries, qrels) as produced by BEIR's GenericDataLoader.
    """
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    # NOTE(review): dataset_name doubles as the output directory here, so data
    # lands under "<dataset_name>/<dataset_name>/"; kept as-is for backward
    # compatibility with existing on-disk layouts.
    data_path = util.download_and_unzip(url, dataset_name)
    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
    return corpus, queries, qrels
|
evaluation.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.metrics import ndcg_score

def evaluate_ndcg(top_k_passages, qrels, k=10):
    """Compute NDCG@k for an already-ranked candidate list.

    Args:
        top_k_passages: list of (passage, score) tuples ordered best-first;
            the list order is taken as the predicted ranking.
        qrels: container of relevant items for the query; membership in it
            marks a candidate as relevant (binary relevance).
        k: rank cutoff (default 10).

    Returns:
        float NDCG@k in [0, 1]; 0.0 when the list is empty or contains no
        relevant candidate.
    """
    import math

    # Binary gains in predicted (list) order.  The old code handed sklearn a
    # constant predicted-score vector ([1, 1, ...]), which ties every
    # document and therefore ignores the actual ranking being evaluated.
    relevance = [1 if passage in qrels else 0 for passage, _ in top_k_passages[:k]]
    dcg = sum(rel / math.log2(pos + 2) for pos, rel in enumerate(relevance))
    ideal = sorted(relevance, reverse=True)
    idcg = sum(rel / math.log2(pos + 2) for pos, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
beir
|
3 |
+
sentence-transformers
|
4 |
+
transformers
|
5 |
+
torch
|
6 |
+
scikit-learn
|
reranking.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def load_ranking_model(model_name):
    """Fetch a sequence-classification (cross-encoder) model plus its tokenizer.

    Returns the pair (model, tokenizer) for the given Hugging Face model id.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    ranker = AutoModelForSequenceClassification.from_pretrained(model_name)
    return ranker, tokenizer
|
8 |
+
|
9 |
+
def rerank(model, tokenizer, query, top_k_passages):
    """Re-score retrieved candidates with a cross-encoder and sort by score.

    Args:
        model: a sequence-classification model producing one relevance logit.
        tokenizer: the tokenizer paired with `model`.
        query: the question text.
        top_k_passages: list of (passage, score) candidates from retrieval.

    Returns:
        List of (passage, score) tuples sorted by descending cross-encoder
        score; empty list for empty input.
    """
    if not top_k_passages:
        return []
    passages = [passage for passage, _ in top_k_passages]
    # Feed (query, passage) as a text *pair* so the tokenizer inserts the
    # model's own special tokens and token-type ids.  The old hand-built
    # "query [SEP] passage" string skipped those and broke for models whose
    # separator token is not literally "[SEP]".
    inputs = tokenizer(
        [query] * len(passages),
        passages,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    model.eval()  # disable dropout for deterministic scoring
    with torch.no_grad():
        logits = model(**inputs).logits
    scores = logits.squeeze(-1)

    order = sorted(range(len(passages)), key=lambda i: scores[i].item(), reverse=True)
    return [(passages[i], scores[i].item()) for i in order]
|
retrieval.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer, util
|
2 |
+
|
3 |
+
def load_embedding_model(model_name):
    """Instantiate a SentenceTransformer bi-encoder for the given model id."""
    embedder = SentenceTransformer(model_name)
    return embedder
|
5 |
+
|
6 |
+
def retrieve_top_k(model, query, corpus, k=10):
    """Embed query and corpus, return the k most similar (passage, score) pairs.

    Args:
        model: a SentenceTransformer-style encoder with .encode().
        query: the question text.
        corpus: mapping doc_id -> record with a "text" field.
        k: number of candidates to return (default 10).

    Returns:
        List of (passage_text, similarity_score) tuples, best first; empty
        list for an empty corpus.
    """
    if not corpus:
        return []
    # Fix the doc-id order once.  The old code iterated `corpus` for the
    # texts but rebuilt list(corpus.keys()) inside the result loop — O(n)
    # per hit — and relied on the two iterations agreeing.
    doc_ids = list(corpus)
    texts = [corpus[doc_id]["text"] for doc_id in doc_ids]

    query_embedding = model.encode(query, convert_to_tensor=True)
    # NOTE(review): the entire corpus is re-encoded on every call; cache
    # corpus_embeddings at the caller for anything beyond a demo.
    corpus_embeddings = model.encode(texts, convert_to_tensor=True)

    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=k)[0]
    return [(corpus[doc_ids[hit["corpus_id"]]]["text"], hit["score"]) for hit in hits]
|