John Graham Reynolds committed on
Commit f8977f5 · 1 Parent(s): a1495e2

only cache chat model, vector store retriever, and embedding model for retriever

Files changed (1)
  1. chain.py +42 -30
chain.py CHANGED
@@ -33,31 +33,42 @@ class ChainBuilder:
     def extract_chat_history(chat_messages_array):
         return chat_messages_array[:-1]
 
-    # ** working logic for querying glossary embeddings
-    # Same embedding model we used to create embeddings of terms
-    # make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
-    # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
-    # does this cache to the given folder though? It does appear to populate the folder as expected after being run
-    @st.cache_resource  # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
     def load_embedding_model(self):
-        embeddings = HuggingFaceEmbeddings(model_name=self.retriever_config.get("embedding_model"), cache_folder="./langchain_cache/")  # this cache isn't working because we're in the Docker container
+        model_name = self.retriever_config.get("embedding_model")
+
+        # make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
+        # add this st caching decorator to ensure the embeddings model gets cached once it has downloaded in full
+        # we cannot directly use @st.cache_resource on a method (a function within a class) that has a self argument
+        # does this cache to the given folder though? It does appear to populate the folder as expected after being run
+        @st.cache_resource  # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
+        def load_and_cache_embedding_model(model_name):
+            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")  # this cache isn't working because we're in the Docker container
         # update this to read from a presaved cache of bge-large
-        return embeddings
+            return embeddings
+
+        return load_and_cache_embedding_model(model_name)
 
     def get_retriever(self):
+        endpoint = self.databricks_resources.get("vector_search_endpoint_name")
+        index_name = self.retriever_config.get("vector_search_index")
         embeddings = self.load_embedding_model()
-        # instantiate the vector store for similarity search in our chain
-        # need to make this a function and decorate it with @st.experimental_memo as above?
-        # We are only calling this initially when the Space starts and builds the chain. Can we expedite this process for users when opening up this Space?
-        # @st.cache_data  # TODO add this in
-        vector_search_as_retriever = DatabricksVectorSearch(
-            endpoint=self.databricks_resources.get("vector_search_endpoint_name"),
-            index_name=self.retriever_config.get("vector_search_index"),
-            embedding=embeddings,
-            text_column="name",
-            columns=["name", "description"],
-        ).as_retriever(search_kwargs=self.retriever_config.get("parameters"))
-        return vector_search_as_retriever
+        search_kwargs = self.retriever_config.get("parameters")
+
+        # we cannot directly use @st.cache_resource on a method (a function within a class) that has a self argument:
+        # Streamlit's caching hashes the function's code and input parameters, and self (the class instance) is not hashable by default
+        @st.cache_resource  # cache the Databricks vector store retriever
+        def get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs):
+            vector_search_as_retriever = DatabricksVectorSearch(
+                endpoint=endpoint,
+                index_name=index_name,
+                embedding=embeddings,
+                text_column="name",
+                columns=["name", "description"],
+            ).as_retriever(search_kwargs=search_kwargs)
+
+            return vector_search_as_retriever
+
+        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)
 
     # # *** TODO Evaluate this block as it relates to "RAG Studio Review App" ***
     # # Enable the RAG Studio Review App to properly display retrieved chunks and evaluation suite to measure the retriever
@@ -70,7 +81,6 @@ class ChainBuilder:
     #     )
 
     # Method to format the terms and definitions returned by the retriever into the prompt
-    # TODO double check the contents here
     def format_context(self, retrieved_terms):
         chunk_template = self.retriever_config.get("chunk_template")
         chunk_contents = [
@@ -125,16 +135,20 @@ class ChainBuilder:
         )
         return query_rewrite_prompt
 
-    @st.cache_resource
     def get_model(self):
-        # Foundation Model for generation
-        model = ChatDatabricks(
-            endpoint=self.databricks_resources.get("llm_endpoint_name"),
-            extra_params=self.llm_config.get("llm_parameters"),
-        )
-        return model
+        endpoint = self.databricks_resources.get("llm_endpoint_name")
+        extra_params = self.llm_config.get("llm_parameters")
+
+        @st.cache_resource  # cache the DBRX Instruct model we load for repeated use in our chain for chat completion
+        def get_and_cache_model(endpoint, extra_params):
+            model = ChatDatabricks(
+                endpoint=endpoint,
+                extra_params=extra_params,
+            )
+            return model
+
+        return get_and_cache_model(endpoint, extra_params)
 
-    @st.cache_resource
     def build_chain(self):
         model = self.get_model()
         prompt = self.get_prompt()
@@ -169,7 +183,6 @@ class ChainBuilder:
             | model  # prompt passed to model
             | StrOutputParser()
         )
-
        return chain
 
     # ## Tell MLflow logging where to find your chain.
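The pattern this commit applies in all three methods deserves a plain statement: Streamlit's @st.cache_resource cannot decorate a bound method, because self is hashed along with the other inputs and a class instance is not hashable by default. The workaround is to pull the needed values off self, define a nested function that takes only those values, and cache that. A minimal, self-contained sketch of the idea; ExpensiveClient, Builder, and their parameters are hypothetical stand-ins, not code from chain.py:

import streamlit as st

class ExpensiveClient:
    """Hypothetical stand-in for a resource that is slow to construct."""
    def __init__(self, endpoint: str):
        self.endpoint = endpoint

class Builder:
    def __init__(self, endpoint: str):
        self.endpoint = endpoint

    def get_client(self) -> ExpensiveClient:
        endpoint = self.endpoint  # pull hashable values off self first

        @st.cache_resource  # keyed on the nested function's code and arguments, never on self
        def get_and_cache_client(endpoint: str) -> ExpensiveClient:
            return ExpensiveClient(endpoint)

        return get_and_cache_client(endpoint)

Repeated calls with the same endpoint then return the one cached instance, which is exactly the behavior the commit wants for the chat model, retriever, and embedding model.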
 
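One loose end the new comments gesture at: st.cache_resource also hashes every argument of the cached function, and the HuggingFaceEmbeddings instance passed into get_and_cache_retriever may not hash cleanly either. Streamlit's documented escape hatch is to prefix a parameter name with an underscore, which excludes that argument from the cache key. A hedged variant of the commit's function; the underscore parameter and the langchain_community import path are my assumptions, not code from this commit:

import streamlit as st
from langchain_community.vectorstores import DatabricksVectorSearch  # import path assumed

@st.cache_resource  # _embeddings is skipped when Streamlit computes the cache key
def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
    return DatabricksVectorSearch(
        endpoint=endpoint,
        index_name=index_name,
        embedding=_embeddings,
        text_column="name",
        columns=["name", "description"],
    ).as_retriever(search_kwargs=search_kwargs)

The trade-off is that swapping the embeddings object alone will not invalidate the cache; here that seems acceptable because the embedding model is fixed by retriever_config.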
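The remaining TODO, "update this to read from a presaved cache of bge-large", could be closed by downloading the model once at image build time and loading from the local snapshot, so a sleeping Space does not re-download on wake. A sketch under stated assumptions: the BAAI/bge-large-en-v1.5 repo id and the reuse of ./langchain_cache/ are guesses, since chain.py actually reads the model name from retriever_config:

from huggingface_hub import snapshot_download
from langchain_community.embeddings import HuggingFaceEmbeddings  # import path assumed

# run once at Docker build time (e.g. in a RUN step) so the weights are baked into the image
local_dir = snapshot_download(
    repo_id="BAAI/bge-large-en-v1.5",  # assumed model id; chain.py reads it from retriever_config
    cache_dir="./langchain_cache/",    # reuse the folder the commit already writes to
)

# at runtime, point the loader at the local snapshot instead of the Hub
embeddings = HuggingFaceEmbeddings(model_name=local_dir)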