Upload 3 files
- rag_chain/chain.py +96 -81
- rag_chain/prompt_template.py +6 -6
- rag_chain/retrievers_setup.py +124 -144
rag_chain/chain.py
CHANGED
@@ -13,57 +13,44 @@ from langchain_core.runnables import RunnableLambda
 from langchain_openai.chat_models import ChatOpenAI
 
 from .prompt_template import generate_prompt_template
-from .retrievers_setup import (
-    DenseRetrieverClient, SparseRetrieverClient, compression_retriever_setup)
+from .retrievers_setup import (
+    DenseRetrieverClient,
+    SparseRetrieverClient,
+    compression_retriever_setup,
+    multi_query_retriever_setup,
+)
 
 # Helpers
 
 
 def reorder_documents(docs: list[Document]) -> list[Document]:
-    """Reorder documents to mitigate performance degradation with long contexts.
-
-    Args:
-        docs (list): List of Langchain documents
-
-    Returns:
-        list: Reordered list of Langchain documents
-    """
-    reorder = LongContextReorder()
-    return reorder.transform_documents(docs)
+    """Reorder documents to mitigate performance degradation with long contexts."""
+    return LongContextReorder().transform_documents(docs)
 
 
 def randomize_documents(documents: list[Document]) -> list[Document]:
-    """Randomize ..."""
+    """Randomize documents to vary model recommendations."""
     random.shuffle(documents)
     return documents
 
 
-def format_practitioners_docs(docs: list[Document]) -> str:
-    """Format the practitioners_db Documents to markdown.
-    Args:
-        docs (list[Documents]): List of Langchain documents
-    Returns:
-        docs (str):
-    """
-    return f"\n{'-' * 3}\n".join(
-        [f"- Practitioner #{i+1}:\n\n\t" +
-         d.page_content for i, d in enumerate(docs)]
-    )
-
-
-def format_tall_tree_docs(docs: list[Document]) -> str:
-    """Format the tall_tree_db Documents to markdown.
-    Args:
-        docs (list[Documents]): List of Langchain documents
-    Returns:
-        docs (str):
-    """
-    ...
+class DocumentFormatter:
+    def __init__(self, prefix: str):
+        self.prefix = prefix
+
+    def __call__(self, docs: list[Document]) -> str:
+        """Format the Documents to markdown.
+        Args:
+            docs (list[Documents]): List of Langchain documents
+        Returns:
+            docs (str):
+        """
+        return f"\n---\n".join(
+            [
+                f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
+                for i, d in enumerate(docs)
+            ]
+        )
 
 
 @cache
@@ -74,8 +61,7 @@ def create_langsmith_client():
     os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
     langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
     if not langsmith_api_key:
-        raise EnvironmentError(
-            "Missing environment variable: LANGCHAIN_API_KEY")
+        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
     return langsmith.Client()
 
 
@@ -83,7 +69,9 @@ def create_langsmith_client():
 
 
 @cache
-def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[ChatOpenAI, ConversationBufferWindowMemory]:
+def get_rag_chain(
+    model_name: str = "gpt-4", temperature: float = 0.2
+) -> tuple[ChatOpenAI, ConversationBufferWindowMemory]:
     """Set up runnable and chat memory
 
     Args:
@@ -94,78 +82,105 @@ def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple[
         Runnable, Memory: Chain and Memory
     """
 
+    RETRIEVER_PARAMETERS = {
+        "embeddings_model": "text-embedding-3-small",
+        "k_dense_practitioners_db": 8,
+        "k_sparse_practitioners_db": 15,
+        "weights_ensemble_practitioners_db": [0.2, 0.8],
+        "k_compression_practitioners_db": 18,
+        "k_dense_talltree": 6,
+        "k_compression_talltree": 6,
+    }
+
     # Set up Langsmith to trace the chain
     langsmith_tracing = create_langsmith_client()
 
     # LLM and prompt template
-    llm = ChatOpenAI(model_name=model_name,
-                     temperature=temperature)
+    llm = ChatOpenAI(
+        model_name=model_name,
+        temperature=temperature,
+    )
 
     prompt = generate_prompt_template()
 
     # Set retrievers pointing to the practitioners' dataset
-    dense_retriever_client = DenseRetrieverClient(
-        embeddings_model="...",
-        collection_name="practitioners_db")
+    dense_retriever_client = DenseRetrieverClient(
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        collection_name="practitioners_db",
+        search_type="similarity",
+        k=RETRIEVER_PARAMETERS["k_dense_practitioners_db"],
+    )  # k x 4 using multiquery retriever
 
     # Qdrant db as a retriever
-    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever(
-        k=10)
+    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever()
+
+    # Multiquery retriever using the dense retriever
+    practitioners_db_dense_multiquery_retriever = multi_query_retriever_setup(
+        practitioners_db_dense_retriever
+    )
 
-    # Sparse vector retriever
+    # Sparse vector retriever
     sparse_retriever_client = SparseRetrieverClient(
-        collection_name="practitioners_db_sparse_collection",
-        vector_name="sparse_vector",
+        collection_name="practitioners_db_sparse_collection",
+        vector_name="sparse_vector",
         splade_model_id="naver/splade-cocondenser-ensembledistil",
-        k=...)
+        k=RETRIEVER_PARAMETERS["k_sparse_practitioners_db"],
+    )
+
     practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()
 
     # Ensemble retriever for hybrid search (the dense retriever seems to work better, but the sparse retriever is good for acronyms like RMT)
     practitioners_ensemble_retriever = EnsembleRetriever(
-        retrievers=[practitioners_db_dense_retriever,
-                    practitioners_db_sparse_retriever],
-        weights=[...],
+        retrievers=[
+            practitioners_db_dense_multiquery_retriever,
+            practitioners_db_sparse_retriever,
+        ],
+        weights=RETRIEVER_PARAMETERS["weights_ensemble_practitioners_db"],
     )
 
     # Compression retriever for practitioners db
     practitioners_db_compression_retriever = compression_retriever_setup(
         practitioners_ensemble_retriever,
-        embeddings_model="...",
-        similarity_threshold=...,
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        k=RETRIEVER_PARAMETERS["k_compression_practitioners_db"],
    )
 
     # Set retrievers pointing to the tall_tree_db
-    dense_retriever_client = DenseRetrieverClient(
-        embeddings_model="...",
-        collection_name="tall_tree_db")
-    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever(
-        k=...)
+    dense_retriever_client = DenseRetrieverClient(
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        collection_name="tall_tree_db",
+        search_type="similarity",
+        k=RETRIEVER_PARAMETERS["k_dense_talltree"],
+    )
+
+    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever()
+
     # Compression retriever for tall_tree_db
     tall_tree_db_compression_retriever = compression_retriever_setup(
         tall_tree_db_dense_retriever,
-        embeddings_model="...",
-        similarity_threshold=...,
+        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
+        k=RETRIEVER_PARAMETERS["k_compression_talltree"],
    )
 
     # Set conversation history window memory. It only uses the last k interactions.
-    memory = ConversationBufferWindowMemory(memory_key="history",
-                                            return_messages=True,
-                                            k=6)
+    memory = ConversationBufferWindowMemory(
+        memory_key="history",
+        return_messages=True,
+        k=6,
+    )
 
     # Set up runnable using LCEL
-    setup_and_retrieval = {
-        "practitioners_db": itemgetter("message") | practitioners_db_compression_retriever | format_practitioners_docs,
-        "tall_tree_db": itemgetter("message") | tall_tree_db_compression_retriever | format_tall_tree_docs,
-        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
-        "message": itemgetter("message"),
-    }
-    chain = (
-        setup_and_retrieval
-        | prompt
-        | llm
-        | StrOutputParser()
-    )
+    setup_and_retrieval = {
+        "practitioners_db": itemgetter("message")
+        | practitioners_db_compression_retriever
+        | DocumentFormatter("Practitioner #"),
+        "tall_tree_db": itemgetter("message")
+        | tall_tree_db_compression_retriever
+        | DocumentFormatter("No."),
+        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
+        "message": itemgetter("message"),
+    }
+
+    chain = setup_and_retrieval | prompt | llm | StrOutputParser()
 
     return chain, memory
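For context, a minimal sketch of how the refactored chain and memory could be driven from the Space's app code. The surrounding loop is hypothetical and not part of this commit; only `get_rag_chain`, the `"message"` input key, and the returned memory object come from the diff:

```python
# Hypothetical caller for the refactored chain (illustrative only).
from rag_chain.chain import get_rag_chain

chain, memory = get_rag_chain(model_name="gpt-4", temperature=0.2)

def chat(message: str) -> str:
    # setup_and_retrieval feeds "message" to both compression retrievers
    # and pulls "history" out of the conversation window memory.
    answer = chain.invoke({"message": message})
    # Persist the turn so the next invocation sees it in "history".
    memory.save_context({"input": message}, {"output": answer})
    return answer

print(chat("Do you have an RMT available in Victoria?"))
```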
rag_chain/prompt_template.py
CHANGED
@@ -1,7 +1,8 @@
-from langchain.prompts import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    MessagesPlaceholder)
+from langchain.prompts import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    MessagesPlaceholder,
+)
 
 
 def generate_prompt_template():
@@ -83,8 +84,7 @@ You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Can
     """
 
     # Template for system message with markdown formatting
-    system_message = SystemMessagePromptTemplate.from_template(
-        system_template)
+    system_message = SystemMessagePromptTemplate.from_template(system_template)
 
     prompt = ChatPromptTemplate.from_messages(
         [
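For reference, a sketch of the composition pattern this file follows. The full system template and message list sit outside this hunk, so the `MessagesPlaceholder` variable name and the human-message line below are assumptions inferred from the `history` and `message` keys in `chain.py`:

```python
# Sketch only; the real system_template is much longer and includes the
# retrieved {practitioners_db}/{tall_tree_db} context variables.
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    MessagesPlaceholder,
)

system_template = "You are a helpful Virtual Assistant at Tall Tree Health..."
system_message = SystemMessagePromptTemplate.from_template(system_template)

prompt = ChatPromptTemplate.from_messages(
    [
        system_message,
        MessagesPlaceholder(variable_name="history"),  # assumed variable name
        ("human", "{message}"),  # assumed; matches chain.py's "message" key
    ]
)
```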
rag_chain/retrievers_setup.py
CHANGED
@@ -14,139 +14,144 @@ from langchain_openai.embeddings import OpenAIEmbeddings
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 
 
-class DenseRetrieverClient:
-    """Initialize the dense retriever using OpenAI text embeddings and Qdrant vector database.
-
-    Attributes:
-        embeddings_model (str): The embeddings model to use. Right now only OpenAI text embeddings.
-        collection_name (str): Qdrant collection name.
-        client (QdrantClient): Qdrant client.
-        qdrant_collection (Qdrant): Qdrant collection.
-    """
-
-    def __init__(self, embeddings_model: str = "text-embedding-ada-002", collection_name: str = "practitioners_db"):
+class ValidateQdrantClient:
+    """Base class for retriever clients to ensure environment variables are set."""
+
+    def __init__(self):
         self.validate_environment_variables()
+
+    def validate_environment_variables(self):
+        """Check if the Qdrant environment variables are set."""
+        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
+        missing_vars = [var for var in required_vars if not os.getenv(var)]
+        if missing_vars:
+            raise EnvironmentError(
+                f"Missing environment variable(s): {', '.join(missing_vars)}"
+            )
+
+
+class DenseRetrieverClient(ValidateQdrantClient):
+    """Initialize the dense retriever using OpenAI text embeddings and Qdrant vector database."""
+
+    TEXT_EMBEDDING_MODELS = [
+        "text-embedding-ada-002",
+        "text-embedding-3-small",
+        "text-embedding-3-large",
+    ]
+
+    def __init__(
+        self,
+        embeddings_model="text-embedding-3-small",
+        collection_name="practitioners_db",
+        search_type="similarity",
+        k=4,
+    ):
+        super().__init__()
+        if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
+            raise ValueError(
+                f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
+            )
         self.embeddings_model = embeddings_model
         self.collection_name = collection_name
+        self.search_type = search_type
+        self.k = k
         self.client = qdrant_client.QdrantClient(
             url=os.getenv("QDRANT_URL"),
             api_key=os.getenv("QDRANT_API_KEY"),
             prefer_grpc=True,
         )
-        self.qdrant_collection = self.load_qdrant_collection()
-
-    def validate_environment_variables(self):
-        """Check if the Qdrant environment variables are set."""
-        required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
-        for var in required_vars:
-            if not os.getenv(var):
-                raise EnvironmentError(f"Missing environment variable: {var}")
+        self._qdrant_collection = None
 
     def set_qdrant_collection(self, embeddings):
         """Prepare the Qdrant collection for the embeddings model."""
-        return Qdrant(client=self.client,
-                      collection_name=self.collection_name,
-                      embeddings=embeddings)
+        return Qdrant(
+            client=self.client,
+            collection_name=self.collection_name,
+            embeddings=embeddings,
+        )
 
+    @property
     @cache
-    def load_qdrant_collection(self):
+    def qdrant_collection(self):
         """Load Qdrant collection for a given embeddings model."""
-        ...
-
-    def get_dense_retriever(self, search_type: str = "similarity", k: int = 4):
-        """Set up retrievers (Qdrant vectorstore as retriever).
-
-        Args:
-            search_type (str, optional): similarity or mmr. Defaults to "similarity".
-            k (int, optional): Number of documents retrieved. Defaults to 4.
-
-        Returns:
-            Retriever: Vectorstore as a retriever
-        """
-        dense_retriever = self.qdrant_collection.as_retriever(search_type=search_type,
-                                                              search_kwargs={
                                                                  "k": k}
-                                                              )
-        return dense_retriever
+        if self._qdrant_collection is None:
+            self._qdrant_collection = self.set_qdrant_collection(
+                OpenAIEmbeddings(model=self.embeddings_model)
+            )
+        return self._qdrant_collection
+
+    def get_dense_retriever(self):
+        """Set up retrievers (Qdrant vectorstore as retriever)."""
+        return self.qdrant_collection.as_retriever(
+            search_type=self.search_type, search_kwargs={"k": self.k}
+        )
 
 
-class SparseRetrieverClient:
-    """Initialize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database.
-
-    Attributes:
-        ...
-    """
-
-    def __init__(self, collection_name: str, vector_name: str, splade_model_id: str = "naver/splade-cocondenser-ensembledistil", k: int = 15):
-        self.validate_environment_variables()
-        self.client = qdrant_client.QdrantClient(os.getenv(
-            "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
+class SparseRetrieverClient(ValidateQdrantClient):
+    """Initialize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database."""
+
+    def __init__(
+        self,
+        collection_name,
+        vector_name,
+        splade_model_id="naver/splade-cocondenser-ensembledistil",
+        k=15,
+    ):
+
+        # Validate Qdrant client
+        super().__init__()
+        self.client = qdrant_client.QdrantClient(
+            url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
+        )  # TODO: prefer_grpc=True is not working
         self.model_id = splade_model_id
-        self.tokenizer, self.model = self.set_tokenizer_config()
+        self._tokenizer = None
+        self._model = None
         self.collection_name = collection_name
         self.vector_name = vector_name
         self.k = k
 
-    def validate_environment_variables(self):
-        """Check if the Qdrant environment variables are set."""
-        ...
+    @property
+    @cache
+    def tokenizer(self):
+        """Initialize the tokenizer."""
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        return self._tokenizer
 
+    @property
     @cache
-    def set_tokenizer_config(self):
-        """Initialize the tokenizer and the SPLADE neural retrieval model."""
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        model = AutoModelForMaskedLM.from_pretrained(self.model_id)
-        return tokenizer, model
+    def model(self):
+        """Initialize the SPLADE neural retrieval model."""
+        if self._model is None:
+            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
+        return self._model
 
     def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
-        """Encode the input text into a sparse vector.
-
-        Args:
-            text (str): Text to encode.
-
-        Returns:
-            tuple[list[int], list[float]]: Indices and values of the sparse vector.
-        """
-        tokens = self.tokenizer(text, return_tensors="pt",
-                                max_length=512, padding="max_length", truncation=True)
+        """Encode the input text into a sparse vector."""
+        tokens = self.tokenizer(
+            text,
+            return_tensors="pt",
+            max_length=512,
+            padding="max_length",
+            truncation=True,
+        )
 
         with torch.no_grad():
-            output = self.model(**tokens)
-
-        logits, attention_mask = output.logits, tokens.attention_mask
+            logits = self.model(**tokens).logits
 
         relu_log = torch.log1p(torch.relu(logits))
-        weighted_log = relu_log * attention_mask.unsqueeze(-1)
-
-        max_val, _ = torch.max(weighted_log, dim=1)
-        vec = max_val.squeeze()
+        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
 
-        indices = vec.nonzero().numpy().flatten()
-        values = vec.detach().numpy()[indices]
+        max_val = torch.max(weighted_log, dim=1).values.squeeze()
+        indices = torch.nonzero(max_val, as_tuple=False).squeeze().cpu().numpy()
+        values = max_val[indices].cpu().numpy()
 
         return indices.tolist(), values.tolist()
 
-    def get_sparse_retriever(self):
+    def get_sparse_retriever(self) -> QdrantSparseVectorRetriever:
+        """Return a Qdrant sparse vector retriever."""
 
-        sparse_retriever = QdrantSparseVectorRetriever(
+        return QdrantSparseVectorRetriever(
             client=self.client,
             collection_name=self.collection_name,
             sparse_vector_name=self.vector_name,
@@ -154,63 +159,38 @@ class SparseRetrieverClient:
             k=self.k,
         )
 
-        return sparse_retriever
-
-
-def compression_retriever_setup(base_retriever, embeddings_model: str = "text-embedding-ada-002", similarity_threshold: float = 0.76) -> ContextualCompressionRetriever:
-    """
-    Creates a ContextualCompressionRetriever with a base retriever and a similarity threshold.
-
-    The ContextualCompressionRetriever uses an EmbeddingsFilter with OpenAIEmbeddings to filter out documents
-    with a similarity score below the given threshold.
-
-    Args:
-        ...
-
-    Returns:
-        ...
-    """
-
-    # Set up compression retriever (filter out documents with low similarity)
-    relevant_filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model),
-                                       similarity_threshold=similarity_threshold)
-
-    compression_retriever = ContextualCompressionRetriever(
-        base_compressor=relevant_filter, base_retriever=base_retriever
-    )
-
-    return compression_retriever
-
-
-def multi_query_retriever_setup(retriever) -> MultiQueryRetriever:
-    """ Configure a multi-query retriever using a base retriever and the LLM.
-
-    Args:
-        retriever: ...
-
-    Returns:
-        ...
-    """
-
-    QUERY_PROMPT = PromptTemplate(
+
+def compression_retriever_setup(
+    base_retriever, embeddings_model="text-embedding-3-small", k=20
+):
+    """Creates a ContextualCompressionRetriever with an EmbeddingsFilter."""
+    filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model), k=k)
+
+    return ContextualCompressionRetriever(
+        base_compressor=filter, base_retriever=base_retriever
+    )
+
+
+def multi_query_retriever_setup(retriever):
+    """Configure a multi-query retriever using a base retriever."""
+
+    prompt = PromptTemplate(
         input_variables=["question"],
         template="""
-        Your task is to generate 3 different versions of the provided
-        ...
+
+        Your task is to generate 3 different grammatically correct versions of the provided text,
+        incorporating the user's location preference in each version. Format these versions as paragraphs and present them as items in a Markdown formatted numbered list ("1. "). There should be no additional new lines or spaces between each version. Do not enclose your response in quotation marks. Do not modify unfamiliar acronyms and keep your responses clear and concise.
+
+        **Notes**: The text provided are user questions to Tall Tree Health Centre's AI virtual assistant. `Location preference:` is the location of the Tall Tree Health clinic that the user prefers.
+
+        Text to be modified:
+        ```
         {question}
-        """,
+        ```""",
    )
 
-    llm = ChatOpenAI(model=..., temperature=0)
-    multi_query_retriever = MultiQueryRetriever.from_llm(
-        retriever=retriever, llm=llm, prompt=QUERY_PROMPT, include_original=True,
-    )
-
-    return multi_query_retriever
+    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
+
+    return MultiQueryRetriever.from_llm(
+        retriever=retriever, llm=llm, prompt=prompt, include_original=True
+    )