Upload 6 files

- .gitattributes +1 -0
- analysis.py +84 -0
- app.py +125 -0
- docs_data.pkl +3 -0
- faiss_index/index.faiss +3 -0
- faiss_index/index.pkl +3 -0
- requirements.txt +9 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+faiss_index/index.faiss filter=lfs diff=lfs merge=lfs -text
analysis.py
ADDED
@@ -0,0 +1,84 @@
from typing import List
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import matplotlib.pyplot as plt

def calculate_word_overlaps(documents: List[str], query: str):
    """
    Calculate the average word overlap between documents and the query.
    """
    query_words = set(query.lower().split())
    word_overlaps = []

    for doc in documents:
        doc_words = set(doc.lower().split())
        overlap = len(query_words.intersection(doc_words))
        word_overlaps.append(overlap)

    if len(word_overlaps) > 0:
        average_word_overlap = np.mean(word_overlaps)
    else:
        average_word_overlap = 0.0

    return average_word_overlap

def calculate_duplication_rate(documents: List[str]):
    """
    Calculate the duplication rate among a list of documents.
    """
    total_words_set = set()
    total_words = 0

    for doc in documents:
        doc_words = doc.lower().split()
        total_words_set.update(doc_words)
        total_words += len(doc_words)

    if total_words > 0:
        duplication_rate = (total_words - len(total_words_set)) / total_words
    else:
        duplication_rate = 0.0

    return duplication_rate


def cosine_similarity_score(documents: List[str], query: str):
    """
    Calculate cosine similarity between the query and each document.
    """
    tfidf_vectorizer = TfidfVectorizer()
    tfidf_matrix = tfidf_vectorizer.fit_transform([query] + documents)
    cosine_similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
    return cosine_similarities[0]

def jaccard_similarity_score(documents: List[str], query: str):
    """
    Calculate Jaccard similarity between the query and each document.
    """
    query_words = set(query.lower().split())
    jaccard_similarities = []

    for doc in documents:
        doc_words = set(doc.lower().split())
        intersection_size = len(query_words.intersection(doc_words))
        union_size = len(query_words.union(doc_words))
        jaccard_similarity = intersection_size / union_size if union_size > 0 else 0
        jaccard_similarities.append(jaccard_similarity)

    return jaccard_similarities

def display_similarity_results(cosine_scores, jaccard_scores, title):
    # Use a fresh figure per chart so bars do not pile up on the shared
    # pyplot state across the two plots or across Streamlit reruns.
    st.subheader(f"{title} - Cosine Similarity to Query")
    fig, ax = plt.subplots()
    ax.bar(range(len(cosine_scores)), cosine_scores)
    ax.set_xlabel("Documents")
    ax.set_ylabel("Cosine Similarity")
    st.pyplot(fig)
    plt.close(fig)

    st.subheader(f"{title} - Jaccard Similarity to Query")
    fig, ax = plt.subplots()
    ax.bar(range(len(jaccard_scores)), jaccard_scores, color='orange')
    ax.set_xlabel("Documents")
    ax.set_ylabel("Jaccard Similarity")
    st.pyplot(fig)
    plt.close(fig)
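
For reference, a minimal standalone sketch of how these helpers behave outside the Streamlit app; the sample documents and query below are made up for illustration and are not part of the uploaded files:

# Illustrative only: exercise the analysis helpers from a plain Python script.
from analysis import (
    calculate_word_overlaps,
    calculate_duplication_rate,
    cosine_similarity_score,
    jaccard_similarity_score,
)

sample_docs = [
    "A horizontal conflict is a disagreement between channel members at the same level.",
    "Vertical conflict arises between different levels of the same distribution channel.",
]
sample_query = "what is a horizontal conflict"

print(calculate_word_overlaps(sample_docs, sample_query))   # mean shared-word count per document
print(calculate_duplication_rate(sample_docs))              # fraction of repeated words across the documents
print(cosine_similarity_score(sample_docs, sample_query))   # TF-IDF cosine score per document
print(jaccard_similarity_score(sample_docs, sample_query))  # Jaccard score per document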
app.py
ADDED
@@ -0,0 +1,125 @@
import streamlit as st
import pickle
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter


from analysis import calculate_word_overlaps, calculate_duplication_rate, cosine_similarity_score, jaccard_similarity_score, display_similarity_results


# Load the pre-chunked documents and the pre-built FAISS index.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)

metadata_list = []
unique_metadata_list = []
seen = set()

embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-large")
vectorstore = FAISS.load_local("faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity")

# Compression pipeline: split retrieved docs, drop redundant chunks, keep relevant ones.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)

bm25_retriever = BM25Retriever.from_texts(docs)

st.title("Document Retrieval App")

vectorstore_k = st.number_input("Set k value for Dense Retriever:", value=5, min_value=1, step=1)
bm25_k = st.number_input("Set k value for Sparse Retriever:", value=2, min_value=1, step=1)

retriever.search_kwargs["k"] = vectorstore_k
bm25_retriever.k = bm25_k

compressed_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)

query = st.text_input("Enter a query:", "what is a horizontal conflict")

if st.button("Retrieve Documents"):

    compressed_ensemble_retriever = EnsembleRetriever(retrievers=[compressed_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
    ensemble_retriever = EnsembleRetriever(retrievers=[retriever, bm25_retriever], weights=[0.5, 0.5])

    with st.expander("Retrieved Documents"):
        col1, col2 = st.columns(2)

        with col1:
            st.header("Without Compression")
            normal_results = ensemble_retriever.get_relevant_documents(query)
            for doc in normal_results:
                st.write(doc.page_content)
                st.write("---")

        with col2:
            st.header("With Compression")
            compressed_results = compressed_ensemble_retriever.get_relevant_documents(query)
            for doc in compressed_results:
                st.write(doc.page_content)
                st.write("---")

                if hasattr(doc, 'metadata'):
                    metadata = doc.metadata
                    metadata_list.append(metadata)

            # Deduplicate the collected metadata before displaying it.
            for metadata in metadata_list:
                metadata_tuple = tuple(metadata.items())
                if metadata_tuple not in seen:
                    unique_metadata_list.append(metadata)
                    seen.add(metadata_tuple)

            st.write(unique_metadata_list)

    with st.expander("Analysis"):
        st.write("Analysis of Retrieval Results")

        total_words_normal = sum(len(doc.page_content.split()) for doc in normal_results)
        total_words_compressed = sum(len(doc.page_content.split()) for doc in compressed_results)
        reduction_percentage = ((total_words_normal - total_words_compressed) / total_words_normal) * 100

        col1, col2 = st.columns(2)


        st.write(f"Total words in documents (Normal): {total_words_normal}")
        st.write(f"Total words in documents (Compressed): {total_words_compressed}")
        st.write(f"Reduction Percentage: {reduction_percentage:.2f}%")

        average_word_overlap_normal = calculate_word_overlaps([doc.page_content for doc in normal_results], query)
        average_word_overlap_compressed = calculate_word_overlaps([doc.page_content for doc in compressed_results], query)

        duplication_rate_normal = calculate_duplication_rate([doc.page_content for doc in normal_results])
        duplication_rate_compressed = calculate_duplication_rate([doc.page_content for doc in compressed_results])

        cosine_scores_normal = cosine_similarity_score([doc.page_content for doc in normal_results], query)
        jaccard_scores_normal = jaccard_similarity_score([doc.page_content for doc in normal_results], query)

        cosine_scores_compressed = cosine_similarity_score([doc.page_content for doc in compressed_results], query)
        jaccard_scores_compressed = jaccard_similarity_score([doc.page_content for doc in compressed_results], query)

        with col1:
            st.subheader("Normal")

            st.write(f"Average Word Overlap: {average_word_overlap_normal:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_normal:.2%}")

            st.write("Results without Compression:")
            display_similarity_results(cosine_scores_normal, jaccard_scores_normal, "")

        with col2:
            st.subheader("Compressed")

            st.write(f"Average Word Overlap: {average_word_overlap_compressed:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_compressed:.2%}")

            st.write("Results with Compression:")
            display_similarity_results(cosine_scores_compressed, jaccard_scores_compressed, "")
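
app.py assumes two artifacts that were built offline and uploaded here as LFS objects: docs_data.pkl, apparently a pickled list of text chunks (BM25Retriever.from_texts(docs) requires plain strings), and a FAISS index under faiss_index/ built with the same gte-large embeddings. A hedged sketch of how they could have been produced; the source chunks are not in this commit, so the docs placeholder below is hypothetical:

# Hypothetical offline build step for docs_data.pkl and faiss_index/.
import pickle
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

docs = ["...chunked document text..."]  # placeholder; the real chunks are not part of this commit

# Persist the raw chunks for the BM25 (sparse) retriever.
with open("docs_data.pkl", "wb") as f:
    pickle.dump(docs, f)

# Embed the same chunks and save the dense index that FAISS.load_local reads.
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-large")
FAISS.from_texts(docs, embeddings).save_local("faiss_index")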
docs_data.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58e85d1679c567117937df49e7294968f0851ec52c3e6858f82cea45f0034a9f
size 7965757
faiss_index/index.faiss
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dcc64cacbcff655b1220247180d2944d22175f8670f28a0d04520e49d2eaa793
size 19587117
faiss_index/index.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c8ea8d70b85579b394443d0caddb4fd12601533b15a259f9d7a722a86537841
size 8545256
requirements.txt
ADDED
@@ -0,0 +1,9 @@
langchain
openai
streamlit_option_menu
pypdf
rank_bm25
faiss-cpu
tiktoken
scikit-learn
sentence_transformers