Hyma7 committed on
Commit e28d050 · verified · 1 Parent(s): d47b442

Update app.py

Files changed (1):
  1. app.py +61 -16
app.py CHANGED
@@ -1,9 +1,16 @@
-import os
-import urllib.request
-import zipfile
 import streamlit as st
+from sentence_transformers import SentenceTransformer, util
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+import numpy as np
+from sklearn.metrics import ndcg_score
 
+# Helper function to load the dataset
 def download_and_extract_dataset():
+    import urllib.request
+    import zipfile
+    import os
+
     dataset_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
     dataset_zip_path = "nq.zip"
     data_path = "./datasets/nq"
@@ -23,29 +30,67 @@ def download_and_extract_dataset():
 
     return data_path
 
-# Function to load the dataset
+# Function to load corpus, queries, and qrels
 def load_dataset():
     from beir.datasets.data_loader import GenericDataLoader
-
     data_path = download_and_extract_dataset()
-
-    # Load dataset using GenericDataLoader
-    st.write("Loading the dataset...")
     corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
-    st.write(f"Corpus Size: {len(corpus)}")
-    st.write(f"Queries Size: {len(queries)}")
-    st.write(f"Qrels Size: {len(qrels)}")
-
     return corpus, queries, qrels
 
-# Streamlit main execution
+# Stage 1: Candidate retrieval using Sentence Transformer
+def candidate_retrieval(query, corpus, top_k=10):
+    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    corpus_ids = list(corpus.keys())
+    corpus_embeddings = model.encode([corpus[doc_id]['text'] for doc_id in corpus_ids], convert_to_tensor=True)
+
+    query_embedding = model.encode(query, convert_to_tensor=True)
+    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
+
+    retrieved_docs = [corpus_ids[hit['corpus_id']] for hit in hits]
+    return retrieved_docs
+
+# Stage 2: Reranking using cross-encoder
+def rerank(retrieved_docs, query, corpus, top_k=5):
+    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
+    model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+    scores = []
+    for doc_id in retrieved_docs:
+        text = corpus[doc_id]['text']
+        inputs = tokenizer(query, text, return_tensors="pt", truncation=True, padding=True)
+        outputs = model(**inputs)
+        scores.append(outputs.logits.item())
+
+    reranked_indices = np.argsort(scores)[::-1][:top_k]
+    reranked_docs = [retrieved_docs[idx] for idx in reranked_indices]
+    return reranked_docs
+
+# Streamlit main function
 def main():
     st.title("Multi-Stage Retrieval Pipeline")
 
-    # Load the dataset
+    st.write("Loading the dataset...")
     corpus, queries, qrels = load_dataset()
-
-    st.write("Dataset loaded successfully!")
+    st.write(f"Corpus Size: {len(corpus)}")
+
+    # User input for asking a question
+    user_query = st.text_input("Ask a question:")
+
+    if user_query:
+        st.write(f"Your query: {user_query}")
+
+        st.write("Running Candidate Retrieval...")
+        retrieved_docs = candidate_retrieval(user_query, corpus, top_k=10)
+
+        st.write("Running Reranking...")
+        reranked_docs = rerank(retrieved_docs, user_query, corpus, top_k=5)
+
+        st.write("Top Reranked Documents:")
+        for doc_id in reranked_docs:
+            st.write(f"Document ID: {doc_id}")
+            st.write(f"Document Text: {corpus[doc_id]['text'][:500]}...")  # Show the first 500 characters of the document
+
+        st.write("Query executed successfully!")
 
 if __name__ == "__main__":
     main()
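
A note on the new pipeline: candidate_retrieval() re-encodes every corpus document on each query, which is costly for a corpus the size of NQ and repeats on every Streamlit rerun. Below is a minimal sketch, not part of this commit, of how the bi-encoder and corpus embeddings could be cached with Streamlit's cache_resource; the helper names (get_bi_encoder, get_corpus_embeddings, candidate_retrieval_cached) are illustrative assumptions.

import streamlit as st
from sentence_transformers import SentenceTransformer, util

@st.cache_resource
def get_bi_encoder():
    # Hypothetical helper: load the bi-encoder once per Streamlit process and reuse it across reruns.
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

@st.cache_resource
def get_corpus_embeddings(_corpus):
    # Hypothetical helper: encode the corpus a single time; the leading underscore
    # tells Streamlit not to hash the large corpus dict when looking up the cache.
    model = get_bi_encoder()
    corpus_ids = list(_corpus.keys())
    embeddings = model.encode(
        [_corpus[doc_id]['text'] for doc_id in corpus_ids],
        convert_to_tensor=True,
    )
    return corpus_ids, embeddings

def candidate_retrieval_cached(query, corpus, top_k=10):
    # Same retrieval logic as the committed candidate_retrieval(), but only the
    # query is encoded per call; corpus embeddings come from the cache.
    model = get_bi_encoder()
    corpus_ids, corpus_embeddings = get_corpus_embeddings(corpus)
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return [corpus_ids[hit['corpus_id']] for hit in hits]

With this arrangement, a rerun triggered by a new question in st.text_input only encodes the query, while the corpus embeddings persist for the lifetime of the Streamlit process.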