import os
import threading

import streamlit as st
from itertools import tee

from chain import ChainBuilder

DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# remove these secrets from the container
# VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")

MODEL_AVATAR_URL = "./VU.jpeg"

MAX_CHAT_TURNS = 10  # limit this for preliminary testing
MSG_MAX_TURNS_EXCEEDED = (
    f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns in a single history. "
    "Click the 'Clear Chat' button or refresh the page to start a new conversation."
)
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"

EXAMPLE_PROMPTS = [
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a SQL statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
    "Tell me about maximum out-of-pocket costs in healthcare.",
]

TITLE = "Vanderbilt AI Assistant"
DESCRIPTION = """Welcome to the first-generation Vanderbilt AI assistant! \n
**Overview and Usage**: This AI assistant is built atop the Databricks DBRX large language model and is augmented with additional organization-specific knowledge. In particular, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center terms like **EDW**, **HCERA**, **NRHA**, and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**; click any of these examples to issue the corresponding query to the AI.

**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback on one of the model's responses, click the **Give Feedback on Last Response** button just below the user input bar. This lets you provide either positive or negative feedback on the model's most recent response. A **Feedback Form** will appear above the model's title. Please be sure to select either 👍 or 👎 before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this feedback allows us to later improve this model for your usage through a training technique known as reinforcement learning from human feedback. \n
**Disclaimer**: The model has **no access to PHI**. \n
Please send any additional, larger feedback, ideas, or issues to: **john.graham.reynolds@vumc.org**. Happy chatting!"""

GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1  # test this number
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)

QUEUE_SIZE = 20  # maximize this value to add enough places in the global queue?
# QUEUE_SIZE_ENV = os.environ.get("QUEUE_SIZE")
# if QUEUE_SIZE_ENV is not None:
#     QUEUE_SIZE = int(QUEUE_SIZE_ENV)

# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()

st.set_page_config(layout="wide")
st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains")
# TODO add a Vanderbilt-related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")  # use this to format later

# inject the custom stylesheet into the page
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "feedback" not in st.session_state:
    st.session_state["feedback"] = [None]

def clear_chat_history():
    st.session_state["messages"] = []

st.button('Clear Chat', on_click=clear_chat_history)

# build our chain outside the working body so that it's only instantiated once - simply pass it the chat history for chat completion
chain = ChainBuilder().build_chain()

def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error

# re-enabling this retry decorator would require `from tenacity import retry, wait_random_exponential, stop_after_attempt`
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chain_call(history):
    input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
    chat_completion = chain.stream(input)
    return chat_completion

def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error

def chat_completion(messages):
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return

    chat_completion = None
    error = None
    # *** TODO add code for implementing a global queue with a bounded semaphore? (a hedged sketch follows this function)
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    # chat_completion = chain_call(history_dbrx_format)
    chat_completion = chain_call(messages)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return

    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
            # if chunk.choices[0].finish_reason == "length":
            #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}

# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
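# --- hedged sketch: global request queue via a bounded semaphore ---
# A minimal, untested sketch of the TODO above, not wired into the app.
# `queued_chain_call` is a hypothetical name; only `QUEUE_SIZE`, `chain_call`, and
# `threading` come from this file. If adopted, `chat_completion` would call
# `queued_chain_call(messages)` in place of `chain_call(messages)`.

@st.cache_resource
def get_global_semaphore():
    # one semaphore shared across all sessions; at most QUEUE_SIZE concurrent streams
    return threading.BoundedSemaphore(QUEUE_SIZE)

def queued_chain_call(history):
    # hold a queue slot for the full duration of the streamed response, so the
    # semaphore is released only once the generator has been exhausted
    with get_global_semaphore():
        yield from chain_call(history)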
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user", avatar="🧑‍💻"):
                st.markdown(user_input)
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})

def feedback():
    with st.form("feedback_form"):
        st.title("Feedback Form")
        st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
        rating = st.feedback()
        feedback = st.text_input("Please detail your feedback: ")
        # TODO implement a method for writing these responses to storage! (a hedged save_feedback sketch sits at the bottom of this file)
        submitted = st.form_submit_button("Submit Feedback")

main = st.container()
with main:
    if st.session_state["feedback"][-1] is not None:  # TODO clean this up in a fn?
        st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"], icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=5000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users
    gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
    if gave_feedback:  # TODO clean up the conditions here with a function
        st.session_state["feedback"].append("given")
    else:
        st.session_state["feedback"].append(None)

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)
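# --- hedged sketch: persisting feedback (TODO in feedback() above) ---
# A minimal sketch only; the JSONL path, record schema, and `save_feedback` name
# are assumptions, not an existing interface. A real deployment would more likely
# write to a Databricks table or volume. Imports sit here to keep the sketch
# self-contained.
import json
from datetime import datetime, timezone

def save_feedback(rating, notes, path="./feedback_log.jsonl"):
    # append one JSON record per submission; with its default "thumbs" options,
    # st.feedback returns 1 for 👍, 0 for 👎, and None if nothing was selected
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "rating": rating,
        "notes": notes,
        "last_response": st.session_state["messages"][-1]["content"]
        if st.session_state["messages"]
        else None,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")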