John Graham Reynolds committed
Commit · f8977f5 · 1 Parent(s): a1495e2
only cache chat model, vector store retriever, and embedding model for retriever
chain.py CHANGED
@@ -33,31 +33,41 @@ class ChainBuilder:
     def extract_chat_history(chat_messages_array):
         return chat_messages_array[:-1]
 
-    # ** working logic for querying glossary embeddings
-    # Same embedding model we used to create embeddings of terms
-    # make sure we cache this so that it doesnt redownload each time, hindering Space start time if sleeping
-    # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
-    # does this cache to the given folder though? It does appear to populate the folder as expected after being run
-    @st.cache_resource # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
     def load_embedding_model(self):
-        ...
+        model_name = self.retriever_config.get("embedding_model")
+
+        # make sure we cache this so that it doesnt redownload each time, hindering Space start time if sleeping
+        # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
+        # cannot directly use @st.cache_resource on a method (function within a class) that has a self argument
+        # does this cache to the given folder though? It does appear to populate the folder as expected after being run
+        @st.cache_resource # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
+        def load_and_cache_embedding_model(model_name):
+            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/") # this cache isnt working because were in the Docker container
             # update this to read from a presaved cache of bge-large
-        ...
+
+        return load_and_cache_embedding_model(model_name)
 
     def get_retriever(self):
+        endpoint=self.databricks_resources.get("vector_search_endpoint_name")
+        index_name=self.retriever_config.get("vector_search_index")
         embeddings = self.load_embedding_model()
-        ...
+        search_kwargs=self.retriever_config.get("parameters")
+
+        # you cannot directly use @st.cache_resource on a method (function within a class) that has a self argument.
+        # This is because Streamlit's caching mechanism relies on hashing the function's code and input parameters, and the self argument represents the instance of the class, which is not hashable by default.
+        @st.cache_resource # cache the Databricks vector store retriever
+        def get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs):
+            vector_search_as_retriever = DatabricksVectorSearch(
+                endpoint=endpoint,
+                index_name=index_name,
+                embedding=embeddings,
+                text_column="name",
+                columns=["name", "description"],
+            ).as_retriever(search_kwargs=search_kwargs)
+
+            return vector_search_as_retriever
+
+        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)
 
     # # *** TODO Evaluate this block as it relates to "RAG Studio Review App" ***
     # # Enable the RAG Studio Review App to properly display retrieved chunks and evaluation suite to measure the retriever
@@ -70,7 +80,6 @@ class ChainBuilder:
     # )
 
     # Method to format the terms and definitions returned by the retriever into the prompt
-    # TODO double check the contents here
     def format_context(self, retrieved_terms):
         chunk_template = self.retriever_config.get("chunk_template")
         chunk_contents = [
@@ -125,16 +134,20 @@ class ChainBuilder:
         )
         return query_rewrite_prompt
 
-    @st.cache_resource
     def get_model(self):
-        ...
+        endpoint = self.databricks_resources.get("llm_endpoint_name")
+        extra_params=self.llm_config.get("llm_parameters")
+
+        @st.cache_resource # cache the DBRX Instruct model we are loading for repeated use in our chain for chat completion
+        def get_and_cache_model(endpoint, extra_params):
+            model = ChatDatabricks(
+                endpoint=endpoint,
+                extra_params=extra_params,
+            )
+            return model
+
+        return get_and_cache_model(endpoint, extra_params)
 
-    @st.cache_resource
     def build_chain(self):
         model = self.get_model()
         prompt = self.get_prompt()
@@ -169,7 +182,6 @@ class ChainBuilder:
             | model # prompt passed to model
             | StrOutputParser()
         )
-
        return chain
 
 # ## Tell MLflow logging where to find your chain.
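The change repeated across load_embedding_model, get_retriever, and get_model is the usual Streamlit workaround for caching inside a class: @st.cache_resource cannot hash the self argument of a method, so each method defines an inner function that takes only hashable arguments, decorates that inner function, and calls it. The sketch below shows the pattern in isolation, assuming only Streamlit; ResourceHolder, load_model, and load_and_cache are illustrative names and a fake "expensive load", not code from this repository.

import streamlit as st


class ResourceHolder:
    """Minimal sketch of the caching pattern applied in this commit (illustrative only)."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    def load_model(self):
        # @st.cache_resource cannot hash `self`, so the expensive work is wrapped
        # in an inner function that receives only hashable arguments.
        @st.cache_resource
        def load_and_cache(model_name: str):
            # stand-in for an expensive load (e.g. downloading an embedding model)
            return {"model_name": model_name, "status": "loaded once per server process"}

        return load_and_cache(self.model_name)

Streamlit keys the cache on the decorated function and its arguments, so redefining the inner function on every call still resolves to the same cached resource across reruns. Streamlit's caching docs also describe prefixing an unhashable parameter with an underscore (e.g. _self) to exclude it from the cache key, which is an alternative to the inner-function wrapper.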