Spaces:
Runtime error
Runtime error
import logging | |
from langchain.chains import LLMChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.llms import HuggingFaceHub | |
from langchain.prompts.chat import ( | |
PromptTemplate, | |
ChatPromptTemplate, | |
MessagesPlaceholder, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory | |
from openai.error import AuthenticationError | |
import streamlit as st | |
def setup_memory():
    """Create the windowed conversation memory used by the chat chain.

    Messages are kept in a ``StreamlitChatMessageHistory`` stored under the
    key ``"basic_chat_app"``. The memory exposes only the last ``k=3``
    exchanges to the prompt, as message objects, under the variable name
    ``"chat_history"``.

    Returns:
        ConversationBufferWindowMemory: a freshly constructed memory object.
    """
    history_store = StreamlitChatMessageHistory(key="basic_chat_app")
    window_memory = ConversationBufferWindowMemory(
        chat_memory=history_store,
        memory_key="chat_history",
        return_messages=True,
        k=3,
    )
    logging.info("setting up new chat memory")
    return window_memory
def use_existing_chain(model, provider, model_kwargs):
    """Decide whether the chain cached in ``st.session_state`` can be reused.

    Args:
        model: Model identifier selected by the user.
        provider: Provider name selected by the user.
        model_kwargs: Generation settings selected by the user.

    Returns:
        bool: ``True`` only when a ``current_chain`` is cached and its
        ``model``, ``provider`` and ``model_kwargs`` all equal the requested
        values. Mistral models always get ``False``, so a fresh chain is built.
    """
    # TODO: consider whether prompt needs to be checked here. Currently the
    # prompt only depends on whether the model is Mistral, and Mistral chains
    # are never reused, so comparing the settings below is sufficient.
    if "mistral" in model or "current_chain" not in st.session_state:
        return False

    cached = st.session_state.current_chain
    same_settings = (
        cached.model == model
        and cached.provider == provider
        and cached.model_kwargs == model_kwargs
    )
    return bool(same_settings)
class CurrentChain:
    """An ``LLMChain`` bundled with the settings that were used to build it.

    ``model``, ``provider`` and ``model_kwargs`` are stored on the instance so
    that a cached chain can later be compared against the user's current
    selection.

    Args:
        model: OpenAI model name, or Hugging Face Hub repo id.
        provider: ``"OpenAI"`` or ``"HuggingFace"``.
        prompt: Prompt template handed to the ``LLMChain``.
        memory: Memory object (or ``None``) attached to the ``LLMChain``.
        model_kwargs: Generation settings. For OpenAI only
            ``model_kwargs['temperature']`` is used; for Hugging Face the
            whole dict is forwarded to ``HuggingFaceHub``.

    Attributes:
        conversation: The constructed ``LLMChain``.

    Raises:
        ValueError: If ``provider`` is not a supported provider.
    """

    def __init__(self, model, provider, prompt, memory, model_kwargs):
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs
        # Log this chain's own parameters rather than script-level globals,
        # so the class also works outside the Streamlit entry script.
        logging.info("setting up new chain with params %s, %s, %s",
                     model, provider, model_kwargs)
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs['temperature'])
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model, model_kwargs=model_kwargs)
        else:
            # Fail with a clear message instead of an UnboundLocalError on `llm`.
            raise ValueError(f"Unsupported provider: {provider!r}")
        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )
def format_mistral_prompt(message, history):
    """Render a conversation into a single Mistral-Instruct prompt string.

    Every earlier turn becomes ``[INST] <user> [/INST] <assistant></s> ``.
    The turns are prefixed with ``<s>`` and followed by the new message,
    wrapped as ``[INST] <message> [/INST]`` for the model to answer.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        str: The formatted prompt.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
if __name__ == "__main__":
    # Streamlit entry point. Anything that must persist between reruns of
    # this script is kept in ``st.session_state``.
    logging.basicConfig(level=logging.INFO)
    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " only the past three turns of conversation "
                 " are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")
    st.sidebar.title("Choose options")

    #### USER INPUT ######
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)", # runs
                 # "google/flan-t5-xxl (HuggingFace)", # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )
    temp = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    # "gpt-3.5-turbo (OpenAI)" -> model "gpt-3.5-turbo", provider "OpenAI"
    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]
    # TODO: maybe expose more of these to the user
    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}
    is_mistral = "mistral" in model

    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # for openai
    if "history" not in st.session_state:
        # (user, assistant) pairs: rendered on screen and used as Mistral context
        st.session_state.history = []

    if is_mistral:
        # The whole conversation is rendered into the input by
        # format_mistral_prompt, so the template just passes it through.
        prompt = PromptTemplate(input_variables=["input"],
                                template="{input}")
        # Mistral takes its context from ``history``. Attaching the shared
        # window memory would store every fully formatted prompt in it and
        # leak Mistral-formatted turns into a later OpenAI conversation.
        memory = None
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ],
            verbose=True
        )
        memory = st.session_state.session_memory

    if use_existing_chain(model, provider, model_kwargs):
        chain = st.session_state.current_chain
    else:
        try:
            chain = CurrentChain(model,
                                 provider,
                                 prompt,
                                 memory,
                                 model_kwargs)
        except ValueError:
            # Building the LLM wrapper can fail validation (e.g. a missing
            # API key); show the warning instead of crashing the app.
            logging.exception("could not set up chain")
            st.warning("Supply a valid API key", icon="⚠️")
            st.stop()
        st.session_state.current_chain = chain
    conversation = chain.conversation

    if st.button("Clear history"):
        st.session_state.session_memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
        logging.info(text)
        try:
            if is_mistral:
                full_prompt = format_mistral_prompt(text, st.session_state.history)
                result = conversation.predict(input=full_prompt)
            else:
                result = conversation.predict(input=text)
        except (AuthenticationError, ValueError):
            logging.exception("prediction failed")
            st.warning("Supply a valid API key", icon="⚠️")
        else:
            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)