"""Build the LangChain RAG chain that answers questions over a Databricks Vector Search glossary index.

All settings (Databricks resources, LLM settings, retriever settings) are read
from ``chain_config.yaml`` through ``mlflow.models.ModelConfig``.
"""

import os
from operator import itemgetter

import mlflow
import streamlit as st
from langchain_community.chat_models import ChatDatabricks
# NOTE: keep this import order. The langchain_community class is imported first so
# that the langchain_databricks class below is the one bound to the name
# ``DatabricksVectorSearch``. get_retriever() calls the constructor with
# ``endpoint=`` / ``index_name=`` keywords, which matches the langchain_databricks
# signature; the community class expects an index object instead (see its API docs).
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_databricks.vectorstores import DatabricksVectorSearch  # noqa: F811
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings

# ## Enable MLflow Tracing
# mlflow.langchain.autolog()

# Folder that HuggingFaceEmbeddings downloads the model weights into.
# TODO: this folder is not persisted inside the Docker container, so the model is
# re-downloaded on a cold start. Point this at a pre-saved copy of the embedding
# model (e.g. bge-large) to speed up startup.
EMBEDDING_CACHE_FOLDER = "./langchain_cache/"


# The cached loaders below are module-level functions rather than methods on
# purpose. st.cache_resource hashes every argument whose name does not start with
# an underscore, including ``self``. Keying the cache on plain config values
# (strings / dicts) keeps hashing cheap and lets every ChainBuilder share one copy.
# https://docs.streamlit.io/develop/concepts/architecture/caching
@st.cache_resource
def _load_embedding_model(model_name, cache_folder=EMBEDDING_CACHE_FOLDER):
    """Return a HuggingFaceEmbeddings model, built once per (model_name, cache_folder)."""
    return HuggingFaceEmbeddings(model_name=model_name, cache_folder=cache_folder)


@st.cache_resource
def _load_chat_model(endpoint, extra_params):
    """Return a ChatDatabricks client, built once per (endpoint, extra_params)."""
    return ChatDatabricks(endpoint=endpoint, extra_params=extra_params)


class ChainBuilder:
    """Assemble the conversational RAG chain described by ``chain_config.yaml``.

    The chain built by :meth:`build_chain` is invoked (or streamed) with a dict of
    the form ``{"messages": [{"role": ..., "content": ...}, ...]}``. The last
    message is the user's current question; earlier messages are chat history.
    """

    # Maps chat-message roles to the LangChain message classes used in the prompt.
    # History messages with any other role are skipped.
    _ROLE_TO_MESSAGE = {"user": HumanMessage, "assistant": AIMessage}

    def __init__(self):
        # Load the chain's configuration from YAML.
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        # Provides "vector_search_endpoint_name" and "llm_endpoint_name".
        self.databricks_resources = self.model_config.get("databricks_resources")
        # Provides "llm_prompt_template" (system prompt) and "llm_parameters".
        self.llm_config = self.model_config.get("llm_config")
        # Provides "embedding_model", "vector_search_index", "parameters",
        # "chunk_template" and "schema".
        self.retriever_config = self.model_config.get("retriever_config")
        # Maps logical field names (e.g. "description") to index column names.
        self.vector_search_schema = self.retriever_config.get("schema")

    @staticmethod
    def extract_user_query_string(chat_messages_array):
        """Return the text content of the most recent message (the user's question)."""
        return chat_messages_array[-1]["content"]

    @staticmethod
    def extract_chat_history(chat_messages_array):
        """Return the chat history: every message before the last one."""
        return chat_messages_array[:-1]

    def load_embedding_model(self):
        """Return the embedding model named in the retriever config.

        This must be the same model that was used to embed the glossary terms
        stored in the index. The model is cached by Streamlit, so it is
        downloaded and loaded only once rather than on every call.
        """
        return _load_embedding_model(self.retriever_config.get("embedding_model"))

    def get_retriever(self):
        """Return a retriever that runs similarity search over the glossary index.

        The vector store returns the "name" column as page_content and includes
        "name" and "description" in each document's metadata. Search parameters
        come from ``retriever_config["parameters"]``.
        """
        # TODO: the vector store client is rebuilt on every call; consider caching
        # it as well to speed up app startup.
        #
        # TODO: evaluate this for the "RAG Studio Review App", which uses it to
        # display retrieved chunks. NOTE(review): the schema keys below
        # ("chunked_terms", "definition") differ from the "description" key used
        # in format_context - confirm against chain_config.yaml before enabling.
        # mlflow.models.set_retriever_schema(
        #     primary_key=self.vector_search_schema.get("primary_key"),
        #     text_column=self.vector_search_schema.get("chunked_terms"),
        #     other_columns=[self.vector_search_schema.get("definition")],
        # )
        vector_store = DatabricksVectorSearch(
            endpoint=self.databricks_resources.get("vector_search_endpoint_name"),
            index_name=self.retriever_config.get("vector_search_index"),
            embedding=self.load_embedding_model(),
            text_column="name",
            columns=["name", "description"],
        )
        return vector_store.as_retriever(search_kwargs=self.retriever_config.get("parameters"))

    def format_context(self, retrieved_terms):
        """Render retrieved term documents into a single context string.

        Each document is formatted with ``retriever_config["chunk_template"]``:
        ``name`` is the document's page_content (the term), and ``description``
        is read from its metadata under the column named by
        ``schema["description"]``. The formatted chunks are concatenated with no
        separator, so the template should supply its own line breaks.
        """
        chunk_template = self.retriever_config.get("chunk_template")
        description_column = self.vector_search_schema.get("description")
        return "".join(
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[description_column],
            )
            for term in retrieved_terms
        )

    def get_prompt(self):
        """Return the generation prompt: system prompt, chat history, then the question."""
        return ChatPromptTemplate.from_messages(
            [
                ("system", self.llm_config.get("llm_prompt_template")),
                # TODO: the history is not compressed or truncated, so very long
                # conversations can overflow the model's context window. Trim it
                # to a fixed number of recent messages.
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # The user's current question.
                ("user", "{question}"),
            ]
        )

    def format_chat_history_for_prompt(self, chat_messages_array):
        """Convert the chat history into LangChain messages for the prompt placeholder.

        "user" messages become HumanMessage and "assistant" messages become
        AIMessage. Messages with any other role are dropped. The final message
        (the current question) is excluded.
        """
        return [
            self._ROLE_TO_MESSAGE[message["role"]](content=message["content"])
            for message in self.extract_chat_history(chat_messages_array)
            if message["role"] in self._ROLE_TO_MESSAGE
        ]

    @staticmethod
    def get_query_rewrite_prompt():
        """Return the prompt that rewrites a follow-up question into a standalone query.

        For example, "how does it work?" asked after "what is spark?" should
        become "how does spark work?", which retrieves better results.
        """
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.

Chat history: {chat_history}

Question: {question}"""
        return PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )

    def get_model(self):
        """Return the (Streamlit-cached) foundation model used for rewriting and generation."""
        return _load_chat_model(
            self.databricks_resources.get("llm_endpoint_name"),
            self.llm_config.get("llm_parameters"),
        )

    def build_chain(self):
        """Assemble and return the full RAG chain.

        Flow: split ``messages`` into question and history; rewrite the question
        with the history when history exists; retrieve and format matching
        glossary terms; fill the prompt; call the model; parse the output to a
        string. The expensive components (embedding model, chat client) are
        cached by Streamlit, so this method is not decorated itself.
        """
        model = self.get_model()
        prompt = self.get_prompt()
        retriever = self.get_retriever()
        query_rewrite_prompt = self.get_query_rewrite_prompt()

        # Step 1: derive the question, the raw history (for query rewriting) and
        # the formatted history (for the prompt) from the input "messages".
        split_messages = {
            "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
            "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
            "formatted_chat_history": itemgetter("messages")
            | RunnableLambda(self.format_chat_history_for_prompt),
        }

        # Step 2: RunnableBranch acts as an if/else. Rewrite the question only when
        # there is history; otherwise search with the question as-is.
        search_query = RunnableBranch(
            (
                lambda x: len(x["chat_history"]) > 0,
                query_rewrite_prompt | model | StrOutputParser(),
            ),
            itemgetter("question"),
        )

        return (
            split_messages
            | RunnablePassthrough()
            | {
                # Step 3: retrieve glossary terms for the (possibly rewritten) query
                # and render them into the prompt's context.
                "context": search_query | retriever | RunnableLambda(self.format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            # Step 4: generate the answer from the filled prompt.
            | prompt
            | model
            | StrOutputParser()
        )


# ## Tell MLflow logging where to find your chain.
# mlflow.models.set_model(model=chain)