ethanrom committed
Commit fc850f2
1 Parent(s): 372d85c

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ faiss_index/index.faiss filter=lfs diff=lfs merge=lfs -text
analysis.py ADDED
@@ -0,0 +1,84 @@
+ from typing import List
+ import numpy as np
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import streamlit as st
+ import matplotlib.pyplot as plt
+
+ def calculate_word_overlaps(documents: List[str], query: str):
+     """
+     Calculate the average word overlap between the documents and the query.
+     """
+     query_words = set(query.lower().split())
+     word_overlaps = []
+
+     for doc in documents:
+         doc_words = set(doc.lower().split())
+         overlap = len(query_words.intersection(doc_words))
+         word_overlaps.append(overlap)
+
+     if len(word_overlaps) > 0:
+         average_word_overlap = np.mean(word_overlaps)
+     else:
+         average_word_overlap = 0.0
+
+     return average_word_overlap
+
+ def calculate_duplication_rate(documents: List[str]):
+     """
+     Calculate the duplication rate among a list of documents.
+     """
+     total_words_set = set()
+     total_words = 0
+
+     for doc in documents:
+         doc_words = doc.lower().split()
+         total_words_set.update(doc_words)
+         total_words += len(doc_words)
+
+     if total_words > 0:
+         duplication_rate = (total_words - len(total_words_set)) / total_words
+     else:
+         duplication_rate = 0.0
+
+     return duplication_rate
+
+
+ def cosine_similarity_score(documents: List[str], query: str):
+     """
+     Calculate cosine similarity between the query and each document.
+     """
+     tfidf_vectorizer = TfidfVectorizer()
+     tfidf_matrix = tfidf_vectorizer.fit_transform([query] + documents)
+     cosine_similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
+     return cosine_similarities[0]
+
+ def jaccard_similarity_score(documents: List[str], query: str):
+     """
+     Calculate Jaccard similarity between the query and each document.
+     """
+     query_words = set(query.lower().split())
+     jaccard_similarities = []
+
+     for doc in documents:
+         doc_words = set(doc.lower().split())
+         intersection_size = len(query_words.intersection(doc_words))
+         union_size = len(query_words.union(doc_words))
+         jaccard_similarity = intersection_size / union_size if union_size > 0 else 0
+         jaccard_similarities.append(jaccard_similarity)
+
+     return jaccard_similarities
+
+
+ def display_similarity_results(cosine_scores, jaccard_scores, title):
+     # Start a fresh figure so the two charts do not accumulate on the same axes.
+     plt.figure()
+     st.subheader(f"{title} - Cosine Similarity to Query")
+     plt.bar(range(len(cosine_scores)), cosine_scores)
+     plt.xlabel("Documents")
+     plt.ylabel("Cosine Similarity")
+     st.pyplot(plt)
+
+     plt.figure()
+     st.subheader(f"{title} - Jaccard Similarity to Query")
+     plt.bar(range(len(jaccard_scores)), jaccard_scores, color='orange')
+     plt.xlabel("Documents")
+     plt.ylabel("Jaccard Similarity")
+     st.pyplot(plt)
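For reference, the helpers above can be exercised outside Streamlit. A minimal sketch (not part of this commit; the sample documents and query are invented for illustration):

from analysis import (
    calculate_word_overlaps,
    calculate_duplication_rate,
    cosine_similarity_score,
    jaccard_similarity_score,
)

# Invented sample corpus and query, purely for illustration.
docs = [
    "A horizontal conflict arises between firms at the same level of a distribution channel.",
    "Vertical conflict occurs between members at different levels of the same channel.",
]
query = "what is a horizontal conflict"

print(calculate_word_overlaps(docs, query))    # mean number of query words shared with each doc
print(calculate_duplication_rate(docs))        # fraction of repeated words across the doc set
print(cosine_similarity_score(docs, query))    # TF-IDF cosine similarity of each doc to the query
print(jaccard_similarity_score(docs, query))   # Jaccard similarity of each doc to the query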
app.py ADDED
@@ -0,0 +1,125 @@
+ import streamlit as st
+ import pickle
+ import os
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
+ from langchain.document_transformers import EmbeddingsRedundantFilter
+ from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
+ from langchain.text_splitter import CharacterTextSplitter
+
+
+ from analysis import calculate_word_overlaps, calculate_duplication_rate, cosine_similarity_score, jaccard_similarity_score, display_similarity_results
+
+
+ # Pre-chunked document texts; also used to build the sparse (BM25) retriever.
+ with open("docs_data.pkl", "rb") as file:
+     docs = pickle.load(file)
+
+ metadata_list = []
+ unique_metadata_list = []
+ seen = set()
+
+ # Dense retriever backed by the FAISS index saved in faiss_index/.
+ embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-large")
+ vectorstore = FAISS.load_local("faiss_index", embeddings)
+ retriever = vectorstore.as_retriever(search_type="similarity")
+
+ # Contextual-compression pipeline: re-split retrieved docs, drop near-duplicate
+ # chunks, then keep only chunks sufficiently similar to the query.
+ splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
+ redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
+ relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
+ pipeline_compressor = DocumentCompressorPipeline(
+     transformers=[splitter, redundant_filter, relevant_filter]
+ )
+
+ bm25_retriever = BM25Retriever.from_texts(docs)
+
+ st.title("Document Retrieval App")
+
+ vectorstore_k = st.number_input("Set k value for Dense Retriever:", value=5, min_value=1, step=1)
+ bm25_k = st.number_input("Set k value for Sparse Retriever:", value=2, min_value=1, step=1)
+
+ retriever.search_kwargs["k"] = vectorstore_k
+ bm25_retriever.k = bm25_k
+
+ compressed_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
+ bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)
+
+ query = st.text_input("Enter a query:", "what is a horizontal conflict")
+
+ if st.button("Retrieve Documents"):
+
+     # Ensemble (dense + sparse) retrievers, with and without compression.
+     compressed_ensemble_retriever = EnsembleRetriever(retrievers=[compressed_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
+     ensemble_retriever = EnsembleRetriever(retrievers=[retriever, bm25_retriever], weights=[0.5, 0.5])
+
+     with st.expander("Retrieved Documents"):
+         col1, col2 = st.columns(2)
+
+         with col1:
+             st.header("Without Compression")
+             normal_results = ensemble_retriever.get_relevant_documents(query)
+             for doc in normal_results:
+                 st.write(doc.page_content)
+                 st.write("---")
+
+         with col2:
+             st.header("With Compression")
+             compressed_results = compressed_ensemble_retriever.get_relevant_documents(query)
+             for doc in compressed_results:
+                 st.write(doc.page_content)
+                 st.write("---")
+
+                 if hasattr(doc, 'metadata'):
+                     metadata = doc.metadata
+                     metadata_list.append(metadata)
+
+             # Show each distinct metadata record only once.
+             for metadata in metadata_list:
+                 metadata_tuple = tuple(metadata.items())
+                 if metadata_tuple not in seen:
+                     unique_metadata_list.append(metadata)
+                     seen.add(metadata_tuple)
+
+             st.write(unique_metadata_list)
+
+     with st.expander("Analysis"):
+         st.write("Analysis of Retrieval Results")
+
+         total_words_normal = sum(len(doc.page_content.split()) for doc in normal_results)
+         total_words_compressed = sum(len(doc.page_content.split()) for doc in compressed_results)
+         reduction_percentage = ((total_words_normal - total_words_compressed) / total_words_normal) * 100
+
+         col1, col2 = st.columns(2)
+
+         st.write(f"Total words in documents (Normal): {total_words_normal}")
+         st.write(f"Total words in documents (Compressed): {total_words_compressed}")
+         st.write(f"Reduction Percentage: {reduction_percentage:.2f}%")
+
+         average_word_overlap_normal = calculate_word_overlaps([doc.page_content for doc in normal_results], query)
+         average_word_overlap_compressed = calculate_word_overlaps([doc.page_content for doc in compressed_results], query)
+
+         duplication_rate_normal = calculate_duplication_rate([doc.page_content for doc in normal_results])
+         duplication_rate_compressed = calculate_duplication_rate([doc.page_content for doc in compressed_results])
+
+         cosine_scores_normal = cosine_similarity_score([doc.page_content for doc in normal_results], query)
+         jaccard_scores_normal = jaccard_similarity_score([doc.page_content for doc in normal_results], query)
+
+         cosine_scores_compressed = cosine_similarity_score([doc.page_content for doc in compressed_results], query)
+         jaccard_scores_compressed = jaccard_similarity_score([doc.page_content for doc in compressed_results], query)
+
+         with col1:
+             st.subheader("Normal")
+
+             st.write(f"Average Word Overlap: {average_word_overlap_normal:.2f}")
+             st.write(f"Duplication Rate: {duplication_rate_normal:.2%}")
+
+             st.write("Results without Compression:")
+             display_similarity_results(cosine_scores_normal, jaccard_scores_normal, "")
+
+         with col2:
+             st.subheader("Compressed")
+
+             st.write(f"Average Word Overlap: {average_word_overlap_compressed:.2f}")
+             st.write(f"Duplication Rate: {duplication_rate_compressed:.2%}")
+
+             st.write("Results with Compression:")
+             display_similarity_results(cosine_scores_compressed, jaccard_scores_compressed, "")
docs_data.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:58e85d1679c567117937df49e7294968f0851ec52c3e6858f82cea45f0034a9f
+ size 7965757
faiss_index/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcc64cacbcff655b1220247180d2944d22175f8670f28a0d04520e49d2eaa793
+ size 19587117
faiss_index/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c8ea8d70b85579b394443d0caddb4fd12601533b15a259f9d7a722a86537841
+ size 8545256
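The docs_data.pkl and faiss_index/ entries above are Git LFS pointers to binary artifacts; the indexing script that produced them is not included in this commit. A minimal sketch of how files with these names could be regenerated, assuming the same embedding model app.py loads and a hypothetical list of pre-chunked texts:

import pickle
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Hypothetical pre-chunked document strings; the real source corpus is not in this commit.
texts = [
    "First document chunk ...",
    "Second document chunk ...",
]

embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-large")  # same model as app.py

# Build and persist the dense index; save_local() writes index.faiss and index.pkl.
vectorstore = FAISS.from_texts(texts, embeddings)
vectorstore.save_local("faiss_index")

# Persist the raw chunks that app.py later feeds to BM25Retriever.from_texts().
with open("docs_data.pkl", "wb") as f:
    pickle.dump(texts, f)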
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ langchain
+ openai
+ streamlit_option_menu
+ pypdf
+ rank_bm25
+ faiss-cpu
+ tiktoken
+ scikit-learn
+ sentence_transformers