"""dochatter: a Streamlit chat UI doing history-aware RAG over local Chroma stores."""
import streamlit as st
import os
import asyncio
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_together import Together
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Initialize the LLM (Together-hosted Mixtral). The API key is read from the
# ``pilotikval`` environment variable; if it is unset this raises KeyError.
# NOTE(review): max_tokens=22048 may be a typo for 2048 -- confirm intent.
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=12,
    max_tokens=22048,
    together_api_key=os.environ['pilotikval'],
)

# Chat-history wrappers handed out by get_session_history, keyed by session id.
store = {}

model_name = "BAAI/bge-base-en"
# Normalized embeddings make the vector-store similarity equivalent to cosine.
encode_kwargs = {'normalize_embeddings': True}


@st.cache_resource
def _load_embedding_function():
    """Build the BGE embedding model once and reuse it across Streamlit reruns."""
    return HuggingFaceBgeEmbeddings(
        model_name=model_name,
        encode_kwargs=encode_kwargs,
    )


embedding_function = _load_embedding_function()

# Sidebar label -> (Chroma persist_directory, collection_name).
# A single table drives both the selectbox options and the lookup, so a label
# can never silently miss its store. Dict order is the order shown in the UI.
RETRIEVER_SOURCES = {
    'General Medicine': ("./oxfordmedbookdir/", "oxfordmed"),
    'RespiratoryFishman': ("./respfishmandbcud/", "fishmannotescud"),
    'RespiratoryMurray': ("./respmurray/", "respmurraynotes"),
    'MedMRCP2': ("./medmrcp2store/", "medmrcp2notes"),
    'OldMedicine': ("./mrcpchromadb/", "mrcppassmednotes"),
}
# Used only if the selected label is somehow not in RETRIEVER_SOURCES.
DEFAULT_RETRIEVER_SOURCE = ("./mrcpchromadb/", "mrcppassmednotes")


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    The history is a StreamlitChatMessageHistory whose messages are stored in
    ``st.session_state`` under the key ``session_id``.
    """
    if session_id not in store:
        store[session_id] = StreamlitChatMessageHistory(key=session_id)
    return store[session_id]


def app():
    """Render the chat UI and answer questions with history-aware RAG.

    The sidebar selects which Chroma store to retrieve from. Each user prompt
    is first rewritten into a standalone question using the chat history. It
    is then answered from the top-5 retrieved chunks.
    """
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            tuple(RETRIEVER_SOURCES),
        )

    # Resolve the selected label to its on-disk Chroma collection.
    persist_directory, collection_name = RETRIEVER_SOURCES.get(
        option, DEFAULT_RETRIEVER_SOURCE
    )
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
        collection_name=collection_name,
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})

    # Step 1: rewrite the latest question so it stands alone without history.
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    # Step 2: answer from the retrieved documents, stuffed into {context}.
    system_prompt = (
        "You are helping a doctor. Be as detailed and thorough as possible. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    # Wrap the chain so chat history is read from / written to get_session_history.
    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

    # Display transcript (separate from the LLM-side history above).
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Hello Doc!")
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    prompts2 = st.chat_input("Say something")
    # Only call the model when a prompt was actually submitted this run.
    # Otherwise a rerun could invoke the chain with input=None.
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                final_response = conversational_rag_chain.invoke(
                    {"input": prompts2},
                    config={"configurable": {"session_id": "current_session"}},
                )
            st.write(final_response['answer'])
        st.session_state.messages.append(
            {"role": "assistant", "content": final_response['answer']}
        )


if __name__ == '__main__':
    app()