Hyma7 committed
Commit d47b442 · verified · 1 Parent(s): f0c4204

Update app.py

Files changed (1)
  1. app.py +35 -87
app.py CHANGED
@@ -1,103 +1,51 @@
 import os
-import numpy as np
-import faiss
+import urllib.request
+import zipfile
 import streamlit as st
-from sentence_transformers import SentenceTransformer
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from beir import util
-from beir.datasets.data_loader import GenericDataLoader
-#from beir import EvaluateRetrieval
 
-
-# Function to load the dataset
-def load_dataset():
-    dataset_name = "nq"
-    data_path = f"datasets/{dataset_name}.zip"
-    if not os.path.exists(data_path):
-        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
-        util.download_and_unzip(url, "datasets/")
-
-    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
-    return corpus, queries, qrels
-
-# Function for candidate retrieval
-def candidate_retrieval(corpus, queries):
-    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-    corpus_ids = list(corpus.keys())
-    corpus_texts = [corpus[pid]["text"] for pid in corpus_ids]
-    corpus_embeddings = embed_model.encode(corpus_texts, convert_to_numpy=True)
-
-    index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
-    index.add(corpus_embeddings)
-
-    query_texts = [queries[qid] for qid in queries.keys()]
-    query_embeddings = embed_model.encode(query_texts, convert_to_numpy=True)
-
-    _, retrieved_indices = index.search(query_embeddings, 10)
-    return retrieved_indices, corpus_ids
-
-# Function for reranking
-def rerank_passages(retrieved_indices, corpus, queries):
-    cross_encoder_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
-    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
-
-    reranked_passages = []
-    for i, query in enumerate(queries.values()):
-        query_passage_pairs = [(query, corpus[corpus_ids[idx]]["text"]) for idx in retrieved_indices[i]]
-        inputs = tokenizer(query_passage_pairs, padding=True, truncation=True, return_tensors="pt")
-        scores = cross_encoder_model(**inputs).logits.squeeze(-1)
-
-        top_reranked_passages = [passage for _, passage in sorted(zip(scores, query_passage_pairs), key=lambda x: x[0], reverse=True)]
-        reranked_passages.append(top_reranked_passages)
-
-    return reranked_passages
-
-# Function for evaluation
-""""
-def evaluate(qrels, retrieved_indices, reranked_passages, queries):
-    evaluator = EvaluateRetrieval()
-
-    results_stage1 = {}
-    for i, query_id in enumerate(queries.keys()):
-        results_stage1[query_id] = {corpus_ids[idx]: 1 for idx in retrieved_indices[i]}
-
-    ndcg_score_stage1 = evaluator.evaluate(qrels, results_stage1, [10])['NDCG@10']
-
-    results_stage2 = {}
-    for i, query_id in enumerate(queries.keys()):
-        results_stage2[query_id] = {}
-        for passage in reranked_passages[i]:
-            for pid, doc in corpus.items():
-                if doc["text"] == passage[1]:
-                    results_stage2[query_id][pid] = 1
-                    break
-
-    ndcg_score_stage2 = evaluator.evaluate(qrels, results_stage2, [10])['NDCG@10']
-    return ndcg_score_stage1, ndcg_score_stage2
-"""
-# Streamlit app
+def download_and_extract_dataset():
+    dataset_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
+    dataset_zip_path = "nq.zip"
+    data_path = "./datasets/nq"
+
+    # Download the dataset if not already downloaded
+    if not os.path.exists(dataset_zip_path):
+        st.write("Downloading the dataset... This may take a few minutes.")
+        urllib.request.urlretrieve(dataset_url, dataset_zip_path)
+        st.write("Download complete!")
+
+    # Unzip the dataset if not already unzipped
+    if not os.path.exists(data_path):
+        st.write("Unzipping the dataset...")
+        with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
+            zip_ref.extractall("./datasets")
+        st.write("Dataset unzipped!")
+
+    return data_path
+
+# Function to load the dataset
+def load_dataset():
+    from beir.datasets.data_loader import GenericDataLoader
+
+    data_path = download_and_extract_dataset()
+
+    # Load dataset using GenericDataLoader
+    st.write("Loading the dataset...")
+    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+    st.write(f"Corpus Size: {len(corpus)}")
+    st.write(f"Queries Size: {len(queries)}")
+    st.write(f"Qrels Size: {len(qrels)}")
+
+    return corpus, queries, qrels
+
+# Streamlit main execution
 def main():
-    st.title("Multi-Stage Text Retrieval Pipeline")
-
-    if st.button("Load Dataset"):
-        corpus, queries, qrels = load_dataset()
-        st.success("Dataset loaded successfully!")
-
-    if st.button("Run Candidate Retrieval"):
-        retrieved_indices, corpus_ids = candidate_retrieval(corpus, queries)
-        st.success("Candidate retrieval completed!")
-        st.write("Retrieved indices:", retrieved_indices)
-
-    if st.button("Run Reranking"):
-        reranked_passages = rerank_passages(retrieved_indices, corpus, queries)
-        st.success("Reranking completed!")
-        st.write("Reranked passages:", reranked_passages)
-
-    """if st.button("Evaluate"):
-        ndcg_score_stage1, ndcg_score_stage2 = evaluate(qrels, retrieved_indices, reranked_passages, queries)
-        st.write(f"NDCG@10 for Stage 1 (Candidate Retrieval): {ndcg_score_stage1}")
-        st.write(f"NDCG@10 for Stage 2 (Reranking): {ndcg_score_stage2}")
-    """
+    st.title("Multi-Stage Retrieval Pipeline")
+
+    # Load the dataset
+    corpus, queries, qrels = load_dataset()
+
+    st.write("Dataset loaded successfully!")
 
 if __name__ == "__main__":
     main()
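
The loading path introduced in this commit can also be exercised outside Streamlit. Below is a minimal standalone sketch, assuming beir is installed (pip install beir); the helper name fetch_nq is hypothetical and simply mirrors download_and_extract_dataset() from the new app.py:

# Minimal standalone sketch of the same loading path, without the Streamlit UI.
# Assumptions: beir is installed (pip install beir); fetch_nq is a hypothetical
# helper mirroring download_and_extract_dataset() in app.py.
import os
import urllib.request
import zipfile

from beir.datasets.data_loader import GenericDataLoader

DATASET_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"


def fetch_nq(data_dir: str = "./datasets") -> str:
    """Download and extract the BEIR NQ dataset if it is not already present."""
    os.makedirs(data_dir, exist_ok=True)
    zip_path = os.path.join(data_dir, "nq.zip")
    extracted_path = os.path.join(data_dir, "nq")

    # Skip work that has already been done, as app.py does.
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(DATASET_URL, zip_path)
    if not os.path.exists(extracted_path):
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(data_dir)
    return extracted_path


if __name__ == "__main__":
    corpus, queries, qrels = GenericDataLoader(fetch_nq()).load(split="test")
    print(f"Corpus: {len(corpus)}, Queries: {len(queries)}, Qrels: {len(qrels)}")

Keeping both checks idempotent (skip the download when the zip exists, skip extraction when the folder exists) matches the behavior of the updated app.py, so repeated runs do not re-fetch the dataset.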