# llm-explorer / app.py
# Author: carolanderson
# Commit 7081223: "adjust decoding controls"
import logging
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts.chat import (
PromptTemplate,
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from openai.error import AuthenticationError
import streamlit as st
def setup_memory():
    """Create the windowed chat memory used by the OpenAI chat chain.

    Messages are stored through a ``StreamlitChatMessageHistory`` under the
    session-state key ``"basic_chat_app"``. The window size ``k=3`` keeps the
    prompt limited to the last three turns, and the history is exposed under
    the ``"chat_history"`` memory key with ``return_messages=True``.

    Returns:
        The configured ``ConversationBufferWindowMemory`` instance.
    """
    history = StreamlitChatMessageHistory(key="basic_chat_app")
    window_memory = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",
        chat_memory=history,
        return_messages=True,
    )
    logging.info("setting up new chat memory")
    return window_memory
def use_existing_chain(model, provider, model_kwargs):
    """Decide whether the chain cached in session state can be reused.

    Args:
        model: Model identifier selected by the user.
        provider: Provider name selected by the user.
        model_kwargs: Decoding settings for the current run.

    Returns:
        ``False`` for any model whose name contains ``"mistral"`` or when no
        chain has been cached yet. Otherwise, ``True`` only if the cached
        ``st.session_state.current_chain`` was built with the same model,
        provider and decoding kwargs.
    """
    # TODO: consider whether prompt needs to be checked here
    if "mistral" in model or "current_chain" not in st.session_state:
        return False
    cached = st.session_state.current_chain
    return bool(
        cached.model == model
        and cached.provider == provider
        and cached.model_kwargs == model_kwargs
    )
class CurrentChain():
    """An ``LLMChain`` together with the settings it was built from.

    ``model``, ``provider`` and ``model_kwargs`` are stored so that
    ``use_existing_chain`` can compare them with the current selections. The
    runnable chain is exposed as ``conversation``.

    Args:
        model: Model identifier (OpenAI model name or HuggingFace repo id).
        provider: ``"OpenAI"`` or ``"HuggingFace"``.
        prompt: Prompt template given to the ``LLMChain``.
        memory: Conversation memory given to the ``LLMChain``.
        model_kwargs: Decoding settings. For OpenAI only ``"temperature"`` is
            read; for HuggingFace the whole dict is forwarded as
            ``model_kwargs``.

    Raises:
        ValueError: If ``provider`` is not one of the supported providers.
    """

    def __init__(self, model, provider, prompt, memory, model_kwargs):
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs
        # Log only this constructor's own arguments, so the class does not
        # depend on names that exist only when the file runs as a script.
        logging.info("setting up new chain with params %s, %s, %s",
                     model, provider, model_kwargs)
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs["temperature"])
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model, model_kwargs=model_kwargs)
        else:
            # Fail with a clear message instead of an unbound `llm` below.
            raise ValueError(f"Unsupported provider: {provider!r}")
        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )
def format_mistral_prompt(message, history):
    """Render a chat transcript in Mistral-Instruct ``[INST]`` format.

    Args:
        message: The new user message to append as the final instruction.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest
            first.

    Returns:
        A single string that starts with ``<s>``. Each past turn is rendered
        as ``[INST] user [/INST] reply</s> ``, followed by the new
        ``[INST] message [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {asked} [/INST] {answered}</s> "
        for asked, answered in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " only the past three turns of conversation"
                 " are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")
    st.sidebar.title("Choose options")

    #### USER INPUT ######
    # NOTE(review): model_name and temp are kept as module-level names (not
    # locals of a function); check whether other code in the file reads them
    # before moving this into a function.
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)", # runs
                 # "google/flan-t5-xxl (HuggingFace)", # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )
    temp = st.sidebar.slider(
        label="Temperature",
        min_value=float(0),
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    # Options are labelled "<model id> (<provider>)"; split the two parts.
    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]
    # Decoding settings: the OpenAI chain reads only "temperature", while the
    # HuggingFace chain forwards the whole dict.
    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}
    # TODO: maybe expose more of these to the user

    # Per-session state is created once and reused on later runs.
    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # for openai
    if "history" not in st.session_state:
        st.session_state.history = []  # for mistral

    if "mistral" in model:
        # Mistral gets a fully pre-formatted prompt (format_mistral_prompt),
        # so this template just passes the input through unchanged.
        prompt = PromptTemplate(input_variables=["input"],
                                template="{input}")
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ],
            verbose=True
        )

    # NOTE(review): the LangChain LLM wrappers presumably validate API
    # credentials when constructed (raising a ValueError subclass); guard the
    # construction like the prediction call below, so a missing key shows the
    # warning instead of a traceback. Confirm for the pinned langchain version.
    try:
        if use_existing_chain(model, provider, model_kwargs):
            chain = st.session_state.current_chain
        else:
            chain = CurrentChain(model,
                                 provider,
                                 prompt,
                                 st.session_state.session_memory,
                                 model_kwargs)
            st.session_state.current_chain = chain
    except (AuthenticationError, ValueError):
        st.warning("Supply a valid API key", icon="⚠️")
        st.stop()  # nothing below can run without a chain
    conversation = chain.conversation

    if st.button("Clear history"):
        conversation.memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

    # Re-render the stored transcript before handling new input.
    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
        logging.info(text)
        try:
            if "mistral" in model:
                full_prompt = format_mistral_prompt(text, st.session_state.history)
                result = conversation.predict(input=full_prompt)
            else:
                result = conversation.predict(input=text)
        except (AuthenticationError, ValueError):
            st.warning("Supply a valid API key", icon="⚠️")
        else:
            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)