Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload chain.py
Browse files- rag_chain/chain.py +4 -5
rag_chain/chain.py
CHANGED
@@ -121,17 +121,16 @@ def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[
|
|
121 |
k=15)
|
122 |
practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()
|
123 |
|
124 |
-
#
|
125 |
-
# Using only the filtered sparse retriever
|
126 |
practitioners_ensemble_retriever = EnsembleRetriever(
|
127 |
retrievers=[practitioners_db_dense_retriever,
|
128 |
-
practitioners_db_sparse_retriever], weights=[0.
|
129 |
)
|
130 |
|
131 |
# Compression retriever for practitioners db
|
132 |
-
# TODO
|
133 |
practitioners_db_compression_retriever = compression_retriever_setup(
|
134 |
-
|
135 |
embeddings_model="text-embedding-ada-002",
|
136 |
similarity_threshold=0.74
|
137 |
)
|
|
|
121 |
k=15)
|
122 |
practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()
|
123 |
|
124 |
+
# Ensemble retriever for hyprid search (dense retriever seems to work better but the dense retriever is good for acronyms like RMT)
|
|
|
125 |
practitioners_ensemble_retriever = EnsembleRetriever(
|
126 |
retrievers=[practitioners_db_dense_retriever,
|
127 |
+
practitioners_db_sparse_retriever], weights=[0.2, 0.8]
|
128 |
)
|
129 |
|
130 |
# Compression retriever for practitioners db
|
131 |
+
# TODO
|
132 |
practitioners_db_compression_retriever = compression_retriever_setup(
|
133 |
+
practitioners_ensemble_retriever,
|
134 |
embeddings_model="text-embedding-ada-002",
|
135 |
similarity_threshold=0.74
|
136 |
)
|