import os
import threading
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
from itertools import tee

DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")
if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")
MODEL_AVATAR_URL = "./VU.jpeg"

# MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
    "Tell me about maximum out-of-pocket costs in healthcare.",
    "Write a haiku about Nashville, Tennessee.",
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a SQL statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
]
TITLE = "Vanderbilt AI Assistant"
DESCRIPTION = """Welcome to the first-generation Vanderbilt AI assistant! \n
This AI assistant is built atop the Databricks DBRX large language model
and is augmented with additional organization-specific knowledge. Specifically, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **Data Lake**, **EDW** (Enterprise Data Warehouse), **HCERA** (Health Care and Education Reconciliation Act), and **thousands more!** The model has **no access to PHI**.
Try querying the model with any of the example prompts below for a simple introduction to both Vanderbilt-specific and general knowledge queries. The purpose of this
model is to give VUMC employees access to an intelligent assistant that improves and expedites VUMC work. \n
Feedback and ideas are very welcome! Please send any feedback, ideas, or issues to **john.graham.reynolds@vumc.org**.
We hope to gradually improve this AI assistant into a large-scale, all-inclusive tool that complements the work of all VUMC staff; your comments are invaluable in this process."""
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()

st.set_page_config(layout="wide")

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains")  # add a Vanderbilt-related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")

# use this to format later
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
if "messages" not in st.session_state: | |
st.session_state["messages"] = [] | |
def clear_chat_history(): | |
st.session_state["messages"] = [] | |
st.button('Clear Chat', on_click=clear_chat_history) | |
def last_role_is_user(): | |
return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user" | |
def get_system_prompt(): | |
return "" | |
# ** working logic for querying glossary embeddings
# Same embedding model we used to create the embeddings of the glossary terms.
# Cache this so the model is not re-downloaded each time, which would hinder Space start time after sleeping.
# The st.cache_resource decorator keeps the embeddings object cached after the model download;
# it does appear to populate the given cache_folder as expected after being run.
# See https://docs.streamlit.io/develop/concepts/architecture/caching
@st.cache_resource
def load_embedding_model():
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
    return embeddings

embeddings = load_embedding_model()
# instantiate the vector store for similarity search in our chain
# TODO: make this a function and decorate it with @st.cache_resource (as above) so the store is cached;
# we currently build it only once when the Space starts, but caching would expedite reruns for users.
vector_store = DatabricksVectorSearch(
    endpoint=VS_ENDPOINT_NAME,
    index_name=VS_INDEX_NAME,
    embedding=embeddings,
    text_column="name",
    columns=["name", "description"],
)
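# ** hedged sketch (an assumption, not part of the current app): once we introduce our chain,
# ** the vector store can back a standard LangChain retriever, e.g.:
# retriever = vector_store.as_retriever(search_kwargs={"k": 5})
# docs = retriever.invoke("What does EDW stand for?")  # returns a list of Documents, like similarity_search below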
def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    error = None
    warning = None
    # for chunk in stream:
    #     if chunk["error"] is not None:
    #         error = chunk["error"]
    #     if chunk["warning"] is not None:
    #         warning = chunk["warning"]
    return warning, error
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chat_api_call(history):
    # *** original code for instantiating the DBRX model through the OpenAI client ***
    # skip this and introduce our chain eventually
    # extra_body = {}
    # if SAFETY_FILTER:
    #     extra_body["enable_safety_filter"] = SAFETY_FILTER
    # chat_completion = client.chat.completions.create(
    #     messages=[
    #         {"role": m["role"], "content": m["content"]}
    #         for m in history
    #     ],
    #     model="databricks-dbrx-instruct",
    #     stream=True,
    #     max_tokens=MAX_TOKENS,
    #     temperature=0.7,
    #     extra_body=extra_body,
    # )
    # ** TODO update this next to take and do similarity search on user input!
    st.write(history)
    search_result = vector_store.similarity_search(query=st.session_state["messages"][-1]["content"], k=5)
    chat_completion = search_result  # TODO update this after we implement our chain
    return chat_completion
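# Note: vector_store.similarity_search() returns a list of LangChain Document objects
# (each with .page_content and .metadata); chat_completion() below iterates over these in
# place of the token chunks an OpenAI-style streaming response would provide.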
def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error
def chat_completion(messages):
    history_dbrx_format = [
        {"role": "system", "content": get_system_prompt()}
    ]
    history_dbrx_format = history_dbrx_format + messages
    # if (len(history_dbrx_format) - 1) // 2 >= MAX_CHAT_TURNS:
    #     yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
    #     return

    chat_completion = None
    error = None
    # *** original code for querying DBRX through the OpenAI client for chat completion ***
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    chat_completion = chat_api_call(history_dbrx_format)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        # if chunk.choices[0].delta.content is not None:
        if chunk.page_content is not None:
            chunk_counter += 1
            # partial_message += chunk.choices[0].delta.content
            partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user"):
                st.markdown(user_input)
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
main = st.container()
with main:
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                # if message["error"] is not None:
                #     st.error(message["error"], icon="🚨")
                # if message["warning"] is not None:
                #     st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=1000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users
with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)