John Graham Reynolds committed on
Commit 29cf982 · 1 Parent(s): 8df66b4

add chain for reformatting inputs and augmenting the question with relevant context

Files changed (1)
  1. chain.py +176 -0
chain.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ import mlflow
+ import streamlit as st
+ from operator import itemgetter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_databricks.vectorstores import DatabricksVectorSearch
+ from langchain_community.chat_models import ChatDatabricks
+ from langchain_core.runnables import RunnableLambda
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.runnables import RunnablePassthrough, RunnableBranch
+ from langchain_core.messages import HumanMessage, AIMessage
+
+ # ## Enable MLflow Tracing
+ # mlflow.langchain.autolog()
+
+ class ChainBuilder:
+
+     def __init__(self):
+         # Load the chain's configuration from yaml
+         self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
+         self.databricks_resources = self.model_config.get("databricks_resources")
+         self.llm_config = self.model_config.get("llm_config")
+         self.retriever_config = self.model_config.get("retriever_config")
+         self.vector_search_schema = self.retriever_config.get("schema")
+
+     # Return the string contents of the most recent message from the user
+     @staticmethod
+     def extract_user_query_string(chat_messages_array):
+         return chat_messages_array[-1]["content"]
+
+     # Return the chat history, which is everything before the last question
+     @staticmethod
+     def extract_chat_history(chat_messages_array):
+         return chat_messages_array[:-1]
+
+     # ** working logic for querying glossary embeddings
+     # Same embedding model we used to create embeddings of terms
+     # make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
+     # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
+     # does this cache to the given folder though? It does appear to populate the folder as expected after being run
+     @st.cache_resource  # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
+     def load_embedding_model(self):
+         embeddings = HuggingFaceEmbeddings(model_name=self.retriever_config.get("embedding_model"), cache_folder="./langchain_cache/")  # this cache isn't working because we're in the Docker container
+         # update this to read from a presaved cache of bge-large
+         return embeddings
+
+     def get_retriever(self):
+         embeddings = self.load_embedding_model()
+         # instantiate the vector store for similarity search in our chain
+         # need to make this a function and decorate it with @st.experimental_memo as above?
+         # We are only calling this initially when the Space starts and builds the chain. Can we expedite this process for users when opening up this Space?
+         # @st.cache_data # TODO add this in
+         vector_search_as_retriever = DatabricksVectorSearch(
+             endpoint=self.databricks_resources.get("vector_search_endpoint_name"),
+             index_name=self.retriever_config.get("vector_search_index"),
+             embedding=embeddings,
+             text_column="name",
+             columns=["name", "description"],
+         ).as_retriever(search_kwargs=self.retriever_config.get("parameters"))
+         return vector_search_as_retriever
+
+     # # *** TODO Evaluate this block as it relates to "RAG Studio Review App" ***
+     # # Enable the RAG Studio Review App to properly display retrieved chunks and the evaluation suite to measure the retriever
+     # mlflow.models.set_retriever_schema(
+     #     primary_key=self.vector_search_schema.get("primary_key"),
+     #     text_column=self.vector_search_schema.get("chunked_terms"),
+     #     # doc_uri=self.vector_search_schema.get("definition")
+     #     other_columns=[self.vector_search_schema.get("definition")],
+     #     # Review App uses `doc_uri` to display chunks from the same document in a single view
+     # )
+
+     # Method to format the terms and definitions returned by the retriever into the prompt
+     # TODO double check the contents here
+     def format_context(self, retrieved_terms):
+         chunk_template = self.retriever_config.get("chunk_template")
+         chunk_contents = [
+             chunk_template.format(
+                 name=term.page_content,
+                 description=term.metadata[self.vector_search_schema.get("description")],
+             )
+             for term in retrieved_terms
+         ]
+         return "".join(chunk_contents)
+
+     def get_prompt(self):
+         # Prompt template for generation
+         prompt = ChatPromptTemplate.from_messages(
+             [
+                 ("system", self.llm_config.get("llm_prompt_template")),
+                 # *** Note: This chain does not compress the history, so very long conversations can overflow the context window. TODO
+                 # We need to at some point chop this history down to a fixed number of recent messages
+                 MessagesPlaceholder(variable_name="formatted_chat_history"),
+                 # User's most current question
+                 ("user", "{question}"),
+             ]
+         )
+         return prompt
+
+     # Format the conversation history to fit into the prompt template above.
+     # **** TODO after only a few exchanges this will likely overflow the context window
+     def format_chat_history_for_prompt(self, chat_messages_array):
+         history = self.extract_chat_history(chat_messages_array)
+         formatted_chat_history = []
+         if len(history) > 0:
+             for chat_message in history:
+                 if chat_message["role"] == "user":
+                     formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
+                 elif chat_message["role"] == "assistant":
+                     formatted_chat_history.append(AIMessage(content=chat_message["content"]))
+         return formatted_chat_history
+
+     def get_query_rewrite_prompt(self):
+         # Prompt template for query rewriting from chat history. This will translate a query such as "how does it work?" after a question like "what is spark?" into "how does spark work?"
+         query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
+         that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
+         information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.
+
+         Chat history: {chat_history}
+
+         Question: {question}"""
+
+         query_rewrite_prompt = PromptTemplate(
+             template=query_rewrite_template,
+             input_variables=["chat_history", "question"],
+         )
+         return query_rewrite_prompt
+
+     @st.cache_resource
+     def get_model(self):
+         # Foundation Model for generation
+         model = ChatDatabricks(
+             endpoint=self.databricks_resources.get("llm_endpoint_name"),
+             extra_params=self.llm_config.get("llm_parameters"),
+         )
+         return model
+
+     @st.cache_resource
+     def build_chain(self):
+         model = self.get_model()
+         prompt = self.get_prompt()
+         format_context = self.format_context  # pass the bound method itself so RunnableLambda can call it on the retrieved documents
+         vector_search_as_retriever = self.get_retriever()
+         query_rewrite_prompt = self.get_query_rewrite_prompt()
+
+         # RAG Chain
+         chain = (
+             {
+                 # set 'question' to the result of: grabbing the ["messages"] component of the dict we 'invoke()' or 'stream()', then passing it to extract_user_query_string()
+                 "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
+                 "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
+                 "formatted_chat_history": itemgetter("messages")
+                 | RunnableLambda(self.format_chat_history_for_prompt),
+             }
+             | RunnablePassthrough()  # passes elements unchanged through to the next link in the chain
+             | {
+                 "context": RunnableBranch(  # Only rewrite the question if there is a chat history - RunnableBranch() is essentially an LCEL if statement
+                     (
+                         lambda x: len(x["chat_history"]) > 0,  # https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.branch.RunnableBranch.html
+                         query_rewrite_prompt | model | StrOutputParser(),  # rewrite the question with context
+                     ),
+                     itemgetter("question"),  # else, just ask the question
+                 )
+                 | vector_search_as_retriever  # set 'context' to the result of passing either the base question or the reformatted question to the retriever for semantic search
+                 | RunnableLambda(format_context),
+                 "formatted_chat_history": itemgetter("formatted_chat_history"),
+                 "question": itemgetter("question"),
+             }
+             | prompt  # 'context', 'formatted_chat_history', and 'question' passed to the prompt
+             | model  # prompt passed to the model
+             | StrOutputParser()
+         )
+
+         return chain
+
+ # ## Tell MLflow logging where to find your chain.
+ # mlflow.models.set_model(model=chain)
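
For reference, a minimal sketch of how this chain might be built and invoked from the Space's app code. The caller file name, the example messages, and the assistant reply text are assumptions for illustration only; they are not part of this commit. The chain expects the same `{"messages": [...]}` payload that `extract_user_query_string` and `extract_chat_history` consume.

    # hypothetical caller, e.g. streamlit_app.py -- assumes chain_config.yaml and the
    # Databricks endpoints referenced above are reachable from the environment
    from chain import ChainBuilder

    # build the chain once; build_chain() is wrapped in st.cache_resource so Streamlit
    # can reuse it across reruns of the Space
    chain = ChainBuilder().build_chain()

    # everything before the last message is treated as chat history and used to
    # rewrite the final question before retrieval
    example_input = {
        "messages": [
            {"role": "user", "content": "what is spark?"},
            {"role": "assistant", "content": "Apache Spark is a distributed data processing engine."},
            {"role": "user", "content": "how does it work?"},
        ]
    }

    answer = chain.invoke(example_input)  # or chain.stream(example_input) for token streaming
    print(answer)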