Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Delete rag_chain
Browse files- rag_chain/__init__.py +0 -1
- rag_chain/chain.py +0 -186
- rag_chain/prompt_template.py +0 -95
- rag_chain/retrievers_setup.py +0 -196
rag_chain/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
|
|
|
|
rag_chain/chain.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import random
|
3 |
-
from functools import cache
|
4 |
-
from operator import itemgetter
|
5 |
-
|
6 |
-
import langsmith
|
7 |
-
from langchain.memory import ConversationBufferWindowMemory
|
8 |
-
from langchain.retrievers import EnsembleRetriever
|
9 |
-
from langchain_community.document_transformers import LongContextReorder
|
10 |
-
from langchain_core.documents import Document
|
11 |
-
from langchain_core.output_parsers import StrOutputParser
|
12 |
-
from langchain_core.runnables import RunnableLambda
|
13 |
-
from langchain_openai.chat_models import ChatOpenAI
|
14 |
-
|
15 |
-
from .prompt_template import generate_prompt_template
|
16 |
-
from .retrievers_setup import (
|
17 |
-
DenseRetrieverClient,
|
18 |
-
SparseRetrieverClient,
|
19 |
-
compression_retriever_setup,
|
20 |
-
multi_query_retriever_setup,
|
21 |
-
)
|
22 |
-
|
23 |
-
# Helpers
|
24 |
-
|
25 |
-
|
26 |
-
def reorder_documents(docs: list[Document]) -> list[Document]:
|
27 |
-
"""Reorder documents to mitigate performance degradation with long contexts."""
|
28 |
-
return LongContextReorder().transform_documents(docs)
|
29 |
-
|
30 |
-
|
31 |
-
def randomize_documents(documents: list[Document]) -> list[Document]:
|
32 |
-
"""Randomize documents to vary model recommendations."""
|
33 |
-
random.shuffle(documents)
|
34 |
-
return documents
|
35 |
-
|
36 |
-
|
37 |
-
class DocumentFormatter:
|
38 |
-
def __init__(self, prefix: str):
|
39 |
-
self.prefix = prefix
|
40 |
-
|
41 |
-
def __call__(self, docs: list[Document]) -> str:
|
42 |
-
"""Format the Documents to markdown.
|
43 |
-
Args:
|
44 |
-
docs (list[Documents]): List of Langchain documents
|
45 |
-
Returns:
|
46 |
-
docs (str):
|
47 |
-
"""
|
48 |
-
return f"\n---\n".join(
|
49 |
-
[
|
50 |
-
f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
|
51 |
-
for i, d in enumerate(docs)
|
52 |
-
]
|
53 |
-
)
|
54 |
-
|
55 |
-
|
56 |
-
def create_langsmith_client():
|
57 |
-
"""Create a Langsmith client."""
|
58 |
-
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
59 |
-
os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
|
60 |
-
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
61 |
-
langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
|
62 |
-
if not langsmith_api_key:
|
63 |
-
raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
|
64 |
-
return langsmith.Client()
|
65 |
-
|
66 |
-
|
67 |
-
# Set up Runnable and Memory
|
68 |
-
|
69 |
-
|
70 |
-
@cache
|
71 |
-
def get_rag_chain(
|
72 |
-
model_name: str = "gpt-4", temperature: float = 0.2
|
73 |
-
) -> tuple[ChatOpenAI, ConversationBufferWindowMemory]:
|
74 |
-
"""Set up runnable and chat memory
|
75 |
-
|
76 |
-
Args:
|
77 |
-
model_name (str, optional): LLM model. Defaults to "gpt-4" 30012024.
|
78 |
-
temperature (float, optional): Model temperature. Defaults to 0.2.
|
79 |
-
|
80 |
-
Returns:
|
81 |
-
Runnable, Memory: Chain and Memory
|
82 |
-
"""
|
83 |
-
|
84 |
-
RETRIEVER_PARAMETERS = {
|
85 |
-
"embeddings_model": "text-embedding-3-small",
|
86 |
-
"k_dense_practitioners_db": 8,
|
87 |
-
"k_sparse_practitioners_db": 15,
|
88 |
-
"weights_ensemble_practitioners_db": [0.2, 0.8],
|
89 |
-
"k_compression_practitioners_db": 12,
|
90 |
-
"k_dense_talltree": 6,
|
91 |
-
"k_compression_talltree": 6,
|
92 |
-
}
|
93 |
-
|
94 |
-
# Set up Langsmith to trace the chain
|
95 |
-
langsmith_tracing = create_langsmith_client()
|
96 |
-
|
97 |
-
# LLM and prompt template
|
98 |
-
llm = ChatOpenAI(
|
99 |
-
model_name=model_name,
|
100 |
-
temperature=temperature,
|
101 |
-
)
|
102 |
-
|
103 |
-
prompt = generate_prompt_template()
|
104 |
-
|
105 |
-
# Set retrievers pointing to the practitioners's dataset
|
106 |
-
dense_retriever_client = DenseRetrieverClient(
|
107 |
-
embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
|
108 |
-
collection_name="practitioners_db",
|
109 |
-
search_type="similarity",
|
110 |
-
k=RETRIEVER_PARAMETERS["k_dense_practitioners_db"],
|
111 |
-
) # k x 4 using multiquery retriever
|
112 |
-
|
113 |
-
# Qdrant db as a retriever
|
114 |
-
practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever()
|
115 |
-
|
116 |
-
# Multiquery retriever using the dense retriever
|
117 |
-
# This retriever can be passed or not to the EnsembleRetriever. It uses GPT-3.5-turbo.
|
118 |
-
practitioners_db_dense_multiquery_retriever = multi_query_retriever_setup(
|
119 |
-
practitioners_db_dense_retriever
|
120 |
-
)
|
121 |
-
|
122 |
-
# Sparse vector retriever
|
123 |
-
sparse_retriever_client = SparseRetrieverClient(
|
124 |
-
collection_name="practitioners_db_sparse_collection",
|
125 |
-
vector_name="sparse_vector",
|
126 |
-
splade_model_id="naver/splade-cocondenser-ensembledistil",
|
127 |
-
k=RETRIEVER_PARAMETERS["k_sparse_practitioners_db"],
|
128 |
-
)
|
129 |
-
|
130 |
-
practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()
|
131 |
-
|
132 |
-
# Ensemble retriever for hyprid search (dense retriever seems to work better but the dense retriever is good for acronyms like RMT)
|
133 |
-
practitioners_ensemble_retriever = EnsembleRetriever(
|
134 |
-
retrievers=[
|
135 |
-
practitioners_db_dense_retriever,
|
136 |
-
practitioners_db_sparse_retriever,
|
137 |
-
],
|
138 |
-
weights=RETRIEVER_PARAMETERS["weights_ensemble_practitioners_db"],
|
139 |
-
)
|
140 |
-
|
141 |
-
# Compression retriever for practitioners db
|
142 |
-
practitioners_db_compression_retriever = compression_retriever_setup(
|
143 |
-
practitioners_ensemble_retriever,
|
144 |
-
embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
|
145 |
-
k=RETRIEVER_PARAMETERS["k_compression_practitioners_db"],
|
146 |
-
)
|
147 |
-
|
148 |
-
# Set retrievers pointing to the tall_tree_db
|
149 |
-
dense_retriever_client = DenseRetrieverClient(
|
150 |
-
embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
|
151 |
-
collection_name="tall_tree_db",
|
152 |
-
search_type="similarity",
|
153 |
-
k=RETRIEVER_PARAMETERS["k_dense_talltree"],
|
154 |
-
)
|
155 |
-
|
156 |
-
tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever()
|
157 |
-
|
158 |
-
# Compression retriever for tall_tree_db
|
159 |
-
tall_tree_db_compression_retriever = compression_retriever_setup(
|
160 |
-
tall_tree_db_dense_retriever,
|
161 |
-
embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
|
162 |
-
k=RETRIEVER_PARAMETERS["k_compression_talltree"],
|
163 |
-
)
|
164 |
-
|
165 |
-
# Set conversation history window memory. It only uses the last k interactions.
|
166 |
-
memory = ConversationBufferWindowMemory(
|
167 |
-
memory_key="history",
|
168 |
-
return_messages=True,
|
169 |
-
k=6,
|
170 |
-
)
|
171 |
-
|
172 |
-
# Set up runnable using LCEL
|
173 |
-
setup_and_retrieval = {
|
174 |
-
"practitioners_db": itemgetter("message")
|
175 |
-
| practitioners_db_compression_retriever
|
176 |
-
| DocumentFormatter("Practitioner #"),
|
177 |
-
"tall_tree_db": itemgetter("message")
|
178 |
-
| tall_tree_db_dense_retriever
|
179 |
-
| DocumentFormatter("No."),
|
180 |
-
"history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
|
181 |
-
"message": itemgetter("message"),
|
182 |
-
}
|
183 |
-
|
184 |
-
chain = setup_and_retrieval | prompt | llm | StrOutputParser()
|
185 |
-
|
186 |
-
return chain, memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_chain/prompt_template.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
from langchain.prompts import (
|
2 |
-
ChatPromptTemplate,
|
3 |
-
MessagesPlaceholder,
|
4 |
-
SystemMessagePromptTemplate,
|
5 |
-
)
|
6 |
-
|
7 |
-
|
8 |
-
def generate_prompt_template():
|
9 |
-
|
10 |
-
system_template = """# Role
|
11 |
-
|
12 |
-
---
|
13 |
-
|
14 |
-
Your name is Ella (Empathetic, Logical, Liaison, Accessible). You are a helpful Virtual Assistant at Tall Tree Health in British Columbia, Canada. Based on the patient's symptoms/needs, connect them with the appropriate practitioner or service offered by Tall Tree. Respond to `Patient Queries` using the `Practitioners Database` and `Tall Tree Health Centre Information` provided in the `Context`. Follow the `Response Guidelines` listed below:
|
15 |
-
|
16 |
-
---
|
17 |
-
|
18 |
-
# Response Guidelines
|
19 |
-
|
20 |
-
1. **Interaction**: Engage in a warm, empathetic, and professional manner. Keep responses brief and focused on the patient's query. Always conclude positively with a reassuring statement. Use markdown formatting.
|
21 |
-
|
22 |
-
2. **Symptoms/needs and Location Preference**: Only if not specified, ask for symptoms/needs and location preference (Cordova Bay, James Bay, and Vancouver) before recommending a practitioner or service.
|
23 |
-
|
24 |
-
3. **Avoid Making Assumptions**: Stick to the given `Context`. If you're unable to assist, offer the user the contact details for the closest `Tall Tree Health` clinic.
|
25 |
-
|
26 |
-
4. Do not give medical advice or act as a health professional. Avoid discussing healthcare costs.
|
27 |
-
|
28 |
-
5. **Symptoms/needs and Service Verification**: Match the patient's symptoms/needs with the `Focus Area` field in the `Practitioners Database`. If no match is found, advise the patient accordingly without recommending a practitioner, as Tall Tree is not a primary healthcare provider.
|
29 |
-
|
30 |
-
6. **Recommending Practitioners**: Based on the patient's symptoms/needs and location, randomly recommend up to 3 practitioners (only with active status) from the `Practitioners Database`. Focus on `Discipline`, `Focus Areas`, `Location`, `Treatment Method`,`Status`, and `Populations in Focus`. Once you recommend a practitioner, provide the contact info for the corresponding `Tall Tree Health` location for additional assistance.
|
31 |
-
|
32 |
-
7. **Practitioner's Contact Information**: Provide contact information in the following structured format. Do not print their `Focus Areas`:
|
33 |
-
|
34 |
-
- `First Name` and `Last Name`
|
35 |
-
- `Discipline`
|
36 |
-
- `Booking Link` (print only if available)
|
37 |
-
|
38 |
-
## Tall Tree Health Service Routing Guidelines
|
39 |
-
|
40 |
-
8. **Mental Health Urgent Queries**: For urgent situations such as self-harm, suicidal thoughts, violence, hallucinations, or dissociation direct the patient to call the [9-8-8](tel:9-8-8) suicide crisis helpline, reach out to the Vancouver Island Crisis Line at [1-888-494-3888](tel:1-888-494-3888), or head to the nearest emergency room. Tall Tree isn't equipped for mental health emergencies. Do not recommend a practitioner or service.
|
41 |
-
|
42 |
-
9. **Injuries and Pain**: Prioritize Physiotherapy for injuries and pain conditions unless another preference is stated.
|
43 |
-
|
44 |
-
10. **Concussion Protocol**: Direct to the `Concussion Treatment Program` for the appropriate location for a comprehensive assessment with a physiotherapist. Do not recommend a practitioner.
|
45 |
-
|
46 |
-
11. **Psychologist in Vancouver**: If a Psychologist is requested in the Vancouver location, provide only the contact and booking link for our mental health team in Cordova Bay - Upstairs location. Do not recommend an alternative practitioner.
|
47 |
-
|
48 |
-
12. **Sleep issues**: Recommend only the Sleep Program intake and provide the phone number to book an appointment. Do not recommend a practitioner.
|
49 |
-
|
50 |
-
13. **Longevity Program**: For longevity queries, provide the Longevity Program phone number. Do not recommend a practitioner.
|
51 |
-
|
52 |
-
14. **DEXA Testing or body composition**: Inform that this service is exclusive to the Cordova Bay clinic and provide the clinic phone number and booking link. Do not recommend a practitioner.
|
53 |
-
|
54 |
-
15. **For VO2 Max Testing**: Determine the patient's location preference for Vancouver or Victoria and provide the booking link for the appropriate location. If Victoria, we only do it at our Cordova Bay location.
|
55 |
-
|
56 |
-
---
|
57 |
-
|
58 |
-
# Patient Query
|
59 |
-
|
60 |
-
```
|
61 |
-
{message}
|
62 |
-
```
|
63 |
-
---
|
64 |
-
|
65 |
-
# Context
|
66 |
-
|
67 |
-
---
|
68 |
-
1. **Practitioners Database**:
|
69 |
-
|
70 |
-
```
|
71 |
-
{practitioners_db}
|
72 |
-
```
|
73 |
-
---
|
74 |
-
|
75 |
-
2. **Tall Tree Health Centre Information**:
|
76 |
-
|
77 |
-
```
|
78 |
-
{tall_tree_db}
|
79 |
-
```
|
80 |
-
---
|
81 |
-
|
82 |
-
"""
|
83 |
-
|
84 |
-
# Template for system message with markdown formatting
|
85 |
-
system_message = SystemMessagePromptTemplate.from_template(system_template)
|
86 |
-
|
87 |
-
prompt = ChatPromptTemplate.from_messages(
|
88 |
-
[
|
89 |
-
system_message,
|
90 |
-
MessagesPlaceholder(variable_name="history"),
|
91 |
-
("human", "{message}"),
|
92 |
-
]
|
93 |
-
)
|
94 |
-
|
95 |
-
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_chain/retrievers_setup.py
DELETED
@@ -1,196 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from functools import cache
|
3 |
-
|
4 |
-
import qdrant_client
|
5 |
-
import torch
|
6 |
-
from langchain.prompts import PromptTemplate
|
7 |
-
from langchain.retrievers import ContextualCompressionRetriever
|
8 |
-
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
9 |
-
from langchain.retrievers.multi_query import MultiQueryRetriever
|
10 |
-
from langchain_community.retrievers import QdrantSparseVectorRetriever
|
11 |
-
from langchain_community.vectorstores import Qdrant
|
12 |
-
from langchain_openai import ChatOpenAI
|
13 |
-
from langchain_openai.embeddings import OpenAIEmbeddings
|
14 |
-
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
15 |
-
|
16 |
-
|
17 |
-
class ValidateQdrantClient:
|
18 |
-
"""Base class for retriever clients to ensure environment variables are set."""
|
19 |
-
|
20 |
-
def __init__(self):
|
21 |
-
self.validate_environment_variables()
|
22 |
-
|
23 |
-
def validate_environment_variables(self):
|
24 |
-
"""Check if the Qdrant environment variables are set."""
|
25 |
-
required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
|
26 |
-
missing_vars = [var for var in required_vars if not os.getenv(var)]
|
27 |
-
if missing_vars:
|
28 |
-
raise EnvironmentError(
|
29 |
-
f"Missing environment variable(s): {', '.join(missing_vars)}"
|
30 |
-
)
|
31 |
-
|
32 |
-
|
33 |
-
class DenseRetrieverClient(ValidateQdrantClient):
|
34 |
-
"""Initialize the dense retriever using OpenAI text embeddings and Qdrant vector database."""
|
35 |
-
|
36 |
-
TEXT_EMBEDDING_MODELS = [
|
37 |
-
"text-embedding-ada-002",
|
38 |
-
"text-embedding-3-small",
|
39 |
-
"text-embedding-3-large",
|
40 |
-
]
|
41 |
-
|
42 |
-
def __init__(
|
43 |
-
self,
|
44 |
-
embeddings_model="text-embedding-3-small",
|
45 |
-
collection_name="practitioners_db",
|
46 |
-
search_type="similarity",
|
47 |
-
k=4,
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
|
51 |
-
raise ValueError(
|
52 |
-
f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
|
53 |
-
)
|
54 |
-
self.embeddings_model = embeddings_model
|
55 |
-
self.collection_name = collection_name
|
56 |
-
self.search_type = search_type
|
57 |
-
self.k = k
|
58 |
-
self.client = qdrant_client.QdrantClient(
|
59 |
-
url=os.getenv("QDRANT_URL"),
|
60 |
-
api_key=os.getenv("QDRANT_API_KEY"),
|
61 |
-
prefer_grpc=True,
|
62 |
-
)
|
63 |
-
self._qdrant_collection = None
|
64 |
-
|
65 |
-
def set_qdrant_collection(self, embeddings):
|
66 |
-
"""Prepare the Qdrant collection for the embeddings model."""
|
67 |
-
return Qdrant(
|
68 |
-
client=self.client,
|
69 |
-
collection_name=self.collection_name,
|
70 |
-
embeddings=embeddings,
|
71 |
-
)
|
72 |
-
|
73 |
-
@property
|
74 |
-
@cache
|
75 |
-
def qdrant_collection(self):
|
76 |
-
"""Load Qdrant collection for a given embeddings model."""
|
77 |
-
if self._qdrant_collection is None:
|
78 |
-
self._qdrant_collection = self.set_qdrant_collection(
|
79 |
-
OpenAIEmbeddings(model=self.embeddings_model)
|
80 |
-
)
|
81 |
-
return self._qdrant_collection
|
82 |
-
|
83 |
-
def get_dense_retriever(self):
|
84 |
-
"""Set up retrievers (Qdrant vectorstore as retriever)."""
|
85 |
-
return self.qdrant_collection.as_retriever(
|
86 |
-
search_type=self.search_type, search_kwargs={"k": self.k}
|
87 |
-
)
|
88 |
-
|
89 |
-
|
90 |
-
class SparseRetrieverClient(ValidateQdrantClient):
|
91 |
-
"""Initialize the sparse retriever using the SPLADE neural retrieval model and Qdrant vector database."""
|
92 |
-
|
93 |
-
def __init__(
|
94 |
-
self,
|
95 |
-
collection_name,
|
96 |
-
vector_name,
|
97 |
-
splade_model_id="naver/splade-cocondenser-ensembledistil",
|
98 |
-
k=15,
|
99 |
-
):
|
100 |
-
|
101 |
-
# Validate Qdrant client
|
102 |
-
super().__init__()
|
103 |
-
self.client = qdrant_client.QdrantClient(
|
104 |
-
url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
|
105 |
-
) # TODO: prefer_grpc=True is not working
|
106 |
-
self.model_id = splade_model_id
|
107 |
-
self._tokenizer = None
|
108 |
-
self._model = None
|
109 |
-
self.collection_name = collection_name
|
110 |
-
self.vector_name = vector_name
|
111 |
-
self.k = k
|
112 |
-
|
113 |
-
@property
|
114 |
-
@cache
|
115 |
-
def tokenizer(self):
|
116 |
-
"""Initialize the tokenizer."""
|
117 |
-
if self._tokenizer is None:
|
118 |
-
self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
119 |
-
return self._tokenizer
|
120 |
-
|
121 |
-
@property
|
122 |
-
@cache
|
123 |
-
def model(self):
|
124 |
-
"""Initialize the SPLADE neural retrieval model."""
|
125 |
-
if self._model is None:
|
126 |
-
self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
|
127 |
-
return self._model
|
128 |
-
|
129 |
-
def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
|
130 |
-
"""Encode the input text into a sparse vector."""
|
131 |
-
tokens = self.tokenizer(
|
132 |
-
text,
|
133 |
-
return_tensors="pt",
|
134 |
-
max_length=512,
|
135 |
-
padding="max_length",
|
136 |
-
truncation=True,
|
137 |
-
)
|
138 |
-
|
139 |
-
with torch.no_grad():
|
140 |
-
logits = self.model(**tokens).logits
|
141 |
-
|
142 |
-
relu_log = torch.log1p(torch.relu(logits))
|
143 |
-
weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
|
144 |
-
|
145 |
-
max_val = torch.max(weighted_log, dim=1).values.squeeze()
|
146 |
-
indices = torch.nonzero(max_val, as_tuple=False).squeeze().cpu().numpy()
|
147 |
-
values = max_val[indices].cpu().numpy()
|
148 |
-
|
149 |
-
return indices.tolist(), values.tolist()
|
150 |
-
|
151 |
-
def get_sparse_retriever(self) -> QdrantSparseVectorRetriever:
|
152 |
-
"""Return a Qdrant vector sparse retriever."""
|
153 |
-
|
154 |
-
return QdrantSparseVectorRetriever(
|
155 |
-
client=self.client,
|
156 |
-
collection_name=self.collection_name,
|
157 |
-
sparse_vector_name=self.vector_name,
|
158 |
-
sparse_encoder=self.sparse_encoder,
|
159 |
-
k=self.k,
|
160 |
-
)
|
161 |
-
|
162 |
-
|
163 |
-
def compression_retriever_setup(
|
164 |
-
base_retriever, embeddings_model="text-embedding-3-small", k=20
|
165 |
-
):
|
166 |
-
"""Creates a ContextualCompressionRetriever with an EmbeddingsFilter."""
|
167 |
-
filter = EmbeddingsFilter(embeddings=OpenAIEmbeddings(model=embeddings_model), k=k)
|
168 |
-
|
169 |
-
return ContextualCompressionRetriever(
|
170 |
-
base_compressor=filter, base_retriever=base_retriever
|
171 |
-
)
|
172 |
-
|
173 |
-
|
174 |
-
def multi_query_retriever_setup(retriever):
|
175 |
-
"""Configure a multi-query retriever using a base retriever."""
|
176 |
-
|
177 |
-
prompt = PromptTemplate(
|
178 |
-
input_variables=["question"],
|
179 |
-
template="""
|
180 |
-
|
181 |
-
Your task is to generate 3 different grammatically correct versions of the provided text,
|
182 |
-
incorporating the user's location preference in each version. Format these versions as paragraphs and present them as items in a Markdown formatted numbered list ("1. "). There should be no additional new lines or spaces between each version. Do not enclose your response in quotation marks. Do not modify unfamiliar acronyms and keep your responses clear and concise.
|
183 |
-
|
184 |
-
**Notes**: The text provided are user questions to Tall Tree Health Centre's AI virtual assistant. `Location preference:` is the location of the Tall Tree Health clinic that the user prefers.
|
185 |
-
|
186 |
-
Text to be modified:
|
187 |
-
```
|
188 |
-
{question}
|
189 |
-
```""",
|
190 |
-
)
|
191 |
-
|
192 |
-
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
193 |
-
|
194 |
-
return MultiQueryRetriever.from_llm(
|
195 |
-
retriever=retriever, llm=llm, prompt=prompt, include_original=True
|
196 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|