"""Basic Streamlit chatbot backed by an OpenAI or Hugging Face Hub model via LangChain."""
import logging
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts.chat import (
    PromptTemplate,
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from openai.error import AuthenticationError
import streamlit as st


def setup_memory():
    """Create a fresh windowed chat memory stored in Streamlit session state.

    Messages are kept in a ``StreamlitChatMessageHistory`` under the session
    key ``"basic_chat_app"``. Only the last ``k=3`` exchanges are exposed to
    the prompt, as message objects, under the ``"chat_history"`` variable.

    Returns:
        ConversationBufferWindowMemory: The newly created memory.
    """
    msgs = StreamlitChatMessageHistory(key="basic_chat_app")
    memory = ConversationBufferWindowMemory(
        k=3, memory_key="chat_history", chat_memory=msgs, return_messages=True
    )
    logging.info("setting up new chat memory")
    return memory


def use_existing_chain(model, provider, model_kwargs):
    """Return True if the chain cached in session state can be reused.

    A cached chain is reused only when its model, provider and model_kwargs
    all equal the requested ones. Mistral models never reuse a cached chain.

    Args:
        model: Model name / repo id requested by the user.
        provider: Provider name, e.g. ``"OpenAI"`` or ``"HuggingFace"``.
        model_kwargs: Generation settings requested by the user.

    Returns:
        bool: True to reuse ``st.session_state.current_chain``, else False.
    """
    # TODO: consider whether prompt needs to be checked here. In this script
    # the prompt is chosen purely from ``model``, so matching ``model`` also
    # matches the prompt.
    if "mistral" in model:
        return False
    if "current_chain" not in st.session_state:
        return False
    current_chain = st.session_state.current_chain
    return (current_chain.model == model
            and current_chain.provider == provider
            and current_chain.model_kwargs == model_kwargs)


class CurrentChain():
    """An ``LLMChain`` together with the parameters it was built from.

    The parameters are stored so ``use_existing_chain`` can decide whether a
    chain cached in ``st.session_state`` still matches the user's choices.
    """

    def __init__(self, model, provider, prompt, memory, model_kwargs):
        """Build the LLM for ``provider`` and wrap it in an ``LLMChain``.

        Args:
            model: OpenAI model name, or Hugging Face Hub repo id.
            provider: ``"OpenAI"`` or ``"HuggingFace"``.
            prompt: Prompt template used by the chain.
            memory: LangChain memory attached to the chain.
            model_kwargs: Generation settings. OpenAI uses only
                ``"temperature"``; Hugging Face receives the whole dict.

        Raises:
            ValueError: If ``provider`` is not supported. (The LangChain LLM
                constructors presumably also raise ValueError on missing
                credentials.)
        """
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs
        # Log this chain's own arguments, not the script's globals: those only
        # exist when the file is run as __main__.
        logging.info("setting up new chain with params %s, %s, %s",
                     model, provider, model_kwargs)
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs['temperature'])
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model, model_kwargs=model_kwargs)
        else:
            # Without this, ``llm`` would be unbound below (UnboundLocalError).
            raise ValueError(f"Unsupported provider: {provider!r}")
        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory
        )


def format_mistral_prompt(message, history):
    """Render a conversation in Mistral-Instruct's ``[INST]`` format.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        str: Each past turn as ``"[INST] user [/INST] response "``, followed by
        ``"[INST] message [/INST]"``.
    """
    parts = [f"[INST] {user_prompt} [/INST] {bot_response} "
             for user_prompt, bot_response in history]
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " only the past three turns of conversation "
                 " are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")

    st.sidebar.title("Choose options")

    #### USER INPUT ######
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)",  # runs
                 # "google/flan-t5-xxl (HuggingFace)",  # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )
    # NOTE(review): the Hugging Face inference API may reject temperature 0
    # for some models -- confirm whether min_value should be > 0 there.
    temp = st.sidebar.slider(
        label="Temperature",
        min_value=float(0),
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    # Split "model-name (Provider)" into its two parts.
    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]
    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}  # TODO: maybe expose more of these to the user

    # Windowed LangChain memory, fed to the OpenAI prompt as chat_history.
    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # for openai
    # Full (user, assistant) transcript: inlined into Mistral prompts and used
    # to redraw the conversation on every rerun.
    if "history" not in st.session_state:
        st.session_state.history = []  # for mistral

    if "mistral" in model:
        # Mistral receives an already formatted [INST] prompt, passed verbatim.
        prompt = PromptTemplate(input_variables=["input"], template="{input}")
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ],
            verbose=True
        )

    if use_existing_chain(model, provider, model_kwargs):
        chain = st.session_state.current_chain
    else:
        # NOTE(review): the same session memory is attached for Mistral too, so
        # it presumably records the fully formatted [INST] prompts, which would
        # appear as past human turns after switching to an OpenAI model.
        try:
            chain = CurrentChain(model, provider, prompt,
                                 st.session_state.session_memory, model_kwargs)
        except ValueError:
            # The LLM constructors presumably validate the API key, so a missing
            # key fails here rather than at predict time; show the warning
            # instead of crashing the whole app.
            logging.exception("could not set up chain")
            st.warning("Supply a valid API key", icon="⚠️")
            st.stop()
        st.session_state.current_chain = chain
    conversation = chain.conversation

    if st.button("Clear history"):
        conversation.memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

    # Streamlit reruns the script on every interaction, so redraw the transcript.
    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
        logging.info(text)
        if "mistral" in model:
            # Mistral sees the entire history, inlined into the prompt.
            model_input = format_mistral_prompt(text, st.session_state.history)
        else:
            # Other models get history through the chain's windowed memory.
            model_input = text
        try:
            result = conversation.predict(input=model_input)
        except (AuthenticationError, ValueError):
            # Keep the underlying cause in the logs; not every ValueError is
            # an API-key problem.
            logging.exception("model call failed")
            st.warning("Supply a valid API key", icon="⚠️")
        else:
            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)