import os

import mlflow
import streamlit as st
from operator import itemgetter

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
from langchain_community.chat_models import ChatDatabricks
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableBranch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
# ## Enable MLflow Tracing
# mlflow.langchain.autolog()
class ChainBuilder:

    def __init__(self):
        # Load the chain's configuration from yaml
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")
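    # For reference, a minimal sketch of the `chain_config.yaml` this class expects. The keys below are
    # the ones read elsewhere in this file; the example values are hypothetical placeholders, not the
    # actual configuration:
    #
    #   databricks_resources:
    #     llm_endpoint_name: my-llm-serving-endpoint              # hypothetical endpoint name
    #     vector_search_endpoint_name: my-vector-search-endpoint  # hypothetical endpoint name
    #   llm_config:
    #     llm_prompt_template: "You are a helpful assistant. Use this context: {context}"
    #     llm_parameters:
    #       temperature: 0.0
    #       max_tokens: 500
    #   retriever_config:
    #     embedding_model: BAAI/bge-large-en-v1.5
    #     vector_search_index: catalog.schema.glossary_index      # hypothetical index name
    #     chunk_template: "Term: {name}\nDefinition: {description}\n\n"
    #     parameters:
    #       k: 5
    #     schema:
    #       primary_key: id
    #       chunked_terms: name
    #       description: description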
    # Return the string contents of the most recent message from the user
    def extract_user_query_string(self, chat_messages_array):
        return chat_messages_array[-1]["content"]

    # Return the chat history, which is everything before the last question
    def extract_chat_history(self, chat_messages_array):
        return chat_messages_array[:-1]
    # ** working logic for querying glossary embeddings
    # Same embedding model we used to create embeddings of the glossary terms
    # Make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
    # TODO: try adding a Streamlit caching decorator to ensure the embeddings class gets cached after the model is fully downloaded
    # Does this cache to the given folder though? It does appear to populate the folder as expected after being run
    # Will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
    def load_embedding_model(self):
        embeddings = HuggingFaceEmbeddings(model_name=self.retriever_config.get("embedding_model"), cache_folder="./langchain_cache/")  # this cache isn't working because we're in the Docker container
        # TODO: update this to read from a pre-saved cache of bge-large
        return embeddings
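    # A minimal sketch of the Streamlit caching idea discussed above, assuming `st.cache_resource`
    # (the current API for caching unserializable objects like models). A module-level helper is
    # easier to cache than a bound method, since Streamlit hashes the function arguments; the name
    # `_load_cached_embeddings` is hypothetical:
    #
    #   @st.cache_resource
    #   def _load_cached_embeddings(model_name: str) -> HuggingFaceEmbeddings:
    #       # The downloaded weights then persist for the lifetime of the Space process
    #       return HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")
    #
    # `load_embedding_model` could then simply return
    # `_load_cached_embeddings(self.retriever_config.get("embedding_model"))`.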
    def get_retriever(self):
        embeddings = self.load_embedding_model()
        # Instantiate the vector store for similarity search in our chain
        # TODO: need to make this a function and decorate it with @st.experimental_memo as above?
        # We are only calling this initially when the Space starts and builds the chain.
        # Can we expedite this process for users when opening up this Space?
        # @st.cache_data  # TODO: add this in
        vector_search_as_retriever = DatabricksVectorSearch(
            endpoint=self.databricks_resources.get("vector_search_endpoint_name"),
            index_name=self.retriever_config.get("vector_search_index"),
            embedding=embeddings,
            text_column="name",
            columns=["name", "description"],
        ).as_retriever(search_kwargs=self.retriever_config.get("parameters"))
        return vector_search_as_retriever
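    # A hedged note on the caching TODO above: `st.cache_data` pickles return values, so it is unlikely
    # to work for a retriever object; `st.cache_resource` is the Streamlit cache intended for
    # unserializable resources such as clients and models. A minimal sketch, assuming a module-level
    # helper (hypothetical name; the leading underscore on `_embeddings` tells Streamlit not to hash it):
    #
    #   @st.cache_resource
    #   def _get_cached_retriever(_embeddings, endpoint: str, index_name: str, search_kwargs: dict):
    #       return DatabricksVectorSearch(
    #           endpoint=endpoint,
    #           index_name=index_name,
    #           embedding=_embeddings,
    #           text_column="name",
    #           columns=["name", "description"],
    #       ).as_retriever(search_kwargs=search_kwargs)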
    # # *** TODO: Evaluate this block as it relates to the "RAG Studio Review App" ***
    # # Enable the RAG Studio Review App to properly display retrieved chunks and the evaluation suite to measure the retriever
    # mlflow.models.set_retriever_schema(
    #     primary_key=self.vector_search_schema.get("primary_key"),
    #     text_column=self.vector_search_schema.get("chunked_terms"),
    #     # doc_uri=self.vector_search_schema.get("definition")
    #     other_columns=[self.vector_search_schema.get("definition")],
    #     # The Review App uses `doc_uri` to display chunks from the same document in a single view
    # )
    # Method to format the terms and definitions returned by the retriever into the prompt
    # TODO: double check the contents here
    def format_context(self, retrieved_terms):
        chunk_template = self.retriever_config.get("chunk_template")
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)
    def get_prompt(self):
        # Prompt template for generation
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.llm_config.get("llm_prompt_template")),
                # *** Note: This chain does not compress the history, so very long conversations can overflow the context window.
                # TODO: at some point we need to trim this history down to a fixed number of recent messages
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # User's most recent question
                ("user", "{question}"),
            ]
        )
        return prompt
    # Format the conversation history to fit into the prompt template above.
    # **** TODO: after only a few exchanges this will likely overflow the context window; see the truncation sketch below
    def format_chat_history_for_prompt(self, chat_messages_array):
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history
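    # A minimal sketch of the history truncation flagged in the TODO above, assuming we simply keep the
    # most recent messages. `truncate_chat_history` and `max_history_messages` are hypothetical names,
    # not read from chain_config.yaml:
    #
    #   def truncate_chat_history(self, chat_messages_array, max_history_messages=10):
    #       history = self.extract_chat_history(chat_messages_array)
    #       # Keep only the last `max_history_messages` entries so the prompt stays within the context window
    #       return history[-max_history_messages:]
    #
    # A token-based budget (e.g. counting with the serving model's tokenizer) would be more precise,
    # but a fixed message count is the simplest guard.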
    def get_query_rewrite_prompt(self):
        # Prompt template for query rewriting from chat history. This translates a follow-up such as
        # "how does it work?" asked after "what is spark?" into "how does spark work?"
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
        that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
        information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.
        Chat history: {chat_history}
        Question: {question}"""
        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt
    def get_model(self):
        # Foundation model for generation
        model = ChatDatabricks(
            endpoint=self.databricks_resources.get("llm_endpoint_name"),
            extra_params=self.llm_config.get("llm_parameters"),
        )
        return model
    def build_chain(self):
        model = self.get_model()
        prompt = self.get_prompt()
        format_context = self.format_context  # pass the bound method itself; RunnableLambda calls it with the retrieved terms
        vector_search_as_retriever = self.get_retriever()
        query_rewrite_prompt = self.get_query_rewrite_prompt()

        # RAG chain
        chain = (
            {
                # Set 'question' to the result of grabbing the ["messages"] component of the dict we invoke() or stream(), then passing it to extract_user_query_string()
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough()  # passes elements unchanged through to the next link in the chain
            | {
                "context": RunnableBranch(  # Only rewrite the question if there is a chat history - RunnableBranch() is essentially an LCEL if statement
                    (
                        # https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.branch.RunnableBranch.html
                        lambda x: len(x["chat_history"]) > 0,
                        query_rewrite_prompt | model | StrOutputParser(),  # rewrite the question using the chat history
                    ),
                    itemgetter("question"),  # else, just ask the question as-is
                )
                | vector_search_as_retriever  # set 'context' to the result of passing either the base question or the rewritten question to the retriever for semantic search
                | RunnableLambda(format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            | prompt  # 'context', 'formatted_chat_history', and 'question' are passed to the prompt
            | model  # the filled prompt is passed to the model
            | StrOutputParser()
        )
        return chain
# ## Tell MLflow logging where to find your chain.
# mlflow.models.set_model(model=chain)
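# A minimal usage sketch, assuming Databricks credentials and the endpoints named in chain_config.yaml
# are available in the environment. The shape of the input dict follows the itemgetter("messages") /
# role-content access used in build_chain(); the question text is a hypothetical example:
if __name__ == "__main__":
    builder = ChainBuilder()
    chain = builder.build_chain()
    # The chain expects OpenAI-style messages; the last entry is the current question
    example_input = {
        "messages": [
            {"role": "user", "content": "What is a vector search index?"},
        ]
    }
    print(chain.invoke(example_input))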