Hyma7 committed
Commit 95aac88 · verified · 1 parent: dff895f

Update app.py

Files changed (1): app.py (+36, -37)
app.py CHANGED
@@ -1,46 +1,45 @@
 import streamlit as st
-import numpy as np
+import pandas as pd
 from sentence_transformers import SentenceTransformer
-from transformers import pipeline
+from sentence_transformers import CrossEncoder
+import numpy as np
 
-# Sample passages
-passages = [
-    "The sky is blue.",
-    "The grass is green.",
-    "The sun is bright.",
-    "Rain falls from the sky.",
-    "Flowers bloom in spring."
-]
+# Load the dataset
+def load_dataset():
+    # Load the Databricks Dolly 15K dataset
+    return pd.read_csv('dolly_15k.csv')
 
 # Load models
 embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-ranking_model = pipeline("text-classification", model='cross-encoder/ms-marco-MiniLM-L-12-v2')
+ranking_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
-def get_relevant_passages(question, passages):
-    keywords = question.lower().split()
-    relevant_passages = [p for p in passages if any(keyword in p.lower() for keyword in keywords)]
-    return relevant_passages if relevant_passages else passages  # Return all if no match
-
-def main():
-    st.title("Multi-Stage Text Retrieval Pipeline for QA")
-    question = st.text_input("Enter a question:")
-
-    if question:
-        relevant_passages = get_relevant_passages(question, passages)
-        st.write("Relevant passages:")
-        for p in relevant_passages:
-            st.write(f"- {p}")
-
-        # Embedding and ranking
-        if st.button("Retrieve Answers"):
-            passage_embeddings = embedding_model.encode(relevant_passages)
-            question_embedding = embedding_model.encode(question)
-            scores = np.dot(passage_embeddings, question_embedding.T)
-            ranked_indices = np.argsort(scores)[::-1]
-
-            st.write("Ranked passages:")
-            for idx in ranked_indices:
-                st.write(f"- {relevant_passages[idx]} (Score: {scores[idx]:.2f})")
-
-if __name__ == "__main__":
-    main()
+# Streamlit UI
+st.title("Multi-Stage Text Retrieval Pipeline for QA")
+
+question = st.text_input("Enter a question:")
+if question:
+    dataset = load_dataset()
+
+    # Generate embeddings for the question and the dataset passages
+    passages = dataset['response'].tolist()  # Adjust this according to your dataset's structure
+    question_embedding = embedding_model.encode(question)
+    passage_embeddings = embedding_model.encode(passages)
+
+    # Retrieve top-k passages based on embeddings
+    top_k = 5
+    similarities = np.inner(question_embedding, passage_embeddings)
+    top_k_indices = np.argsort(similarities)[-top_k:][::-1]
+
+    relevant_passages = [passages[i] for i in top_k_indices]
+
+    st.subheader("Relevant passages:")
+    for passage in relevant_passages:
+        st.write(passage)
+
+    # Re-ranking the passages
+    ranked_scores = ranking_model.predict([[question, passage] for passage in relevant_passages])
+    ranked_passages = sorted(zip(relevant_passages, ranked_scores), key=lambda x: x[1], reverse=True)
+
+    st.subheader("Ranked passages:")
+    for passage, score in ranked_passages:
+        st.write(f"{passage} (Score: {score:.2f})")
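
The updated app.py expects a local dolly_15k.csv with a response column. If that file still has to be produced, a minimal sketch for exporting it from the databricks/databricks-dolly-15k dataset on the Hugging Face Hub could look like the following (the use of the datasets library and the choice to keep the instruction column are assumptions, not part of this commit):

from datasets import load_dataset as hf_load_dataset

# Assumption: build the CSV the app reads from the public Dolly 15K dataset.
dolly = hf_load_dataset("databricks/databricks-dolly-15k", split="train")
dolly.to_pandas()[["instruction", "response"]].to_csv("dolly_15k.csv", index=False)

Only the response column is actually read by the app; the instruction column is kept here purely for reference.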
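
A note on the retrieval stage: the np.inner scores only correspond to cosine similarity when the embeddings are unit length. all-MiniLM-L6-v2 normally produces normalized embeddings, but if a different embedding model were swapped in, passing encode's normalize_embeddings flag is a simple safeguard (a suggested variant, not something this commit does):

# Assumption: normalize embeddings so the inner-product scores stay cosine-comparable.
question_embedding = embedding_model.encode(question, normalize_embeddings=True)
passage_embeddings = embedding_model.encode(passages, normalize_embeddings=True)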