Spaces:
Runtime error
Runtime error
File size: 6,373 Bytes
9814d4c a8a9ff0 9814d4c a8a9ff0 7081223 a8a9ff0 b5792ea 1927487 a8a9ff0 732d634 b5792ea 7081223 b5792ea 9814d4c 732d634 7081223 f2f3156 08b9e09 7081223 08b9e09 7081223 08b9e09 7081223 08b9e09 7081223 08b9e09 7081223 f2f3156 08b9e09 7081223 f2f3156 7081223 a8a9ff0 9814d4c 08b9e09 a8a9ff0 ee12bcf a8a9ff0 7081223 ee12bcf 1927487 9814d4c a8a9ff0 7081223 a8a9ff0 9814d4c 08b9e09 a8a9ff0 7081223 a8a9ff0 f2f3156 7081223 9814d4c 7081223 08b9e09 f2f3156 7081223 08b9e09 f2f3156 7081223 f2f3156 7081223 f2f3156 7081223 08b9e09 7081223 08b9e09 7081223 08b9e09 7081223 f2f3156 7081223 f2f3156 7081223 08b9e09 f2f3156 08b9e09 f2f3156 08b9e09 f2f3156 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import logging
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts.chat import (
PromptTemplate,
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from openai.error import AuthenticationError
import streamlit as st
def setup_memory():
    """Create the windowed chat memory used by the OpenAI chat prompt.

    Messages are persisted through a StreamlitChatMessageHistory stored under
    the "basic_chat_app" key. The returned ConversationBufferWindowMemory keeps
    a window of k=3 exchanges and exposes them as message objects under the
    "chat_history" memory key, matching the MessagesPlaceholder in the prompt.

    Returns:
        ConversationBufferWindowMemory: the freshly constructed memory.
    """
    history_store = StreamlitChatMessageHistory(key="basic_chat_app")
    window_memory = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",
        chat_memory=history_store,
        return_messages=True,
    )
    logging.info("setting up new chat memory")
    return window_memory
def use_existing_chain(model, provider, model_kwargs):
    """Decide whether the chain cached in session state can be reused.

    Mistral models always get a fresh chain. Otherwise the chain stored at
    ``st.session_state.current_chain`` is reused only when its model, provider
    and model kwargs all equal the requested ones.

    Returns:
        bool: True to reuse the cached chain, False to build a new one.
    """
    # TODO: consider whether prompt needs to be checked here
    if "mistral" in model:
        return False
    if "current_chain" not in st.session_state:
        return False
    cached = st.session_state.current_chain
    same_settings = (
        cached.model == model
        and cached.provider == provider
        and cached.model_kwargs == model_kwargs
    )
    return bool(same_settings)
class CurrentChain:
    """An LLMChain bundled with the settings it was built from.

    ``model``, ``provider`` and ``model_kwargs`` are kept as attributes so that
    ``use_existing_chain`` can compare them against the current UI selection
    and decide whether this chain can be reused.

    Args:
        model: model name (OpenAI) or repo id (HuggingFace).
        provider: "OpenAI" or "HuggingFace".
        prompt: prompt template passed to the LLMChain.
        memory: memory object passed to the LLMChain (may be None).
        model_kwargs: generation settings; only "temperature" is used for
            OpenAI, the whole dict is forwarded to HuggingFaceHub.

    Raises:
        ValueError: if ``provider`` is not one of the supported values.
    """

    def __init__(self, model, provider, prompt, memory, model_kwargs):
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs
        # Log this constructor's own arguments. The previous version referenced
        # the script-level globals `model_name`/`temp`, which only exist when
        # the file is run as __main__ (NameError otherwise).
        logging.info("setting up new chain with params %s, %s, %s",
                     model, provider, model_kwargs)
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model,
                             temperature=model_kwargs['temperature'])
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model,
                                 model_kwargs=model_kwargs)
        else:
            # Without this branch `llm` would be unbound and the failure would
            # surface as a confusing UnboundLocalError below.
            raise ValueError(f"unsupported provider: {provider!r}")
        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )
def format_mistral_prompt(message, history):
    """Render past turns plus a new user message in Mistral-instruct format.

    The result starts with ``<s>``. Each ``(user, assistant)`` pair in
    ``history`` becomes ``[INST] user [/INST] assistant</s> ``, and the new
    ``message`` is appended as ``[INST] message [/INST]``.

    Args:
        message: the new user message.
        history: iterable of (user_prompt, bot_response) pairs.

    Returns:
        str: the assembled prompt.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
if __name__ == "__main__":
    # Streamlit chatbot UI. Per-session state (window memory for OpenAI,
    # plain (user, assistant) history for Mistral, and the cached chain)
    # is kept in st.session_state.
    logging.basicConfig(level=logging.INFO)

    st.header("Basic chatbot")
    st.write("On small screens, click the `>` at top left to choose options")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " only the past three turns of conversation "
                 " are used for OpenAI models. Otherwise the entire chat history is used.")
        st.write("To clear all memory and start fresh, click 'Clear history'")
    st.sidebar.title("Choose options")

    #### USER INPUT ######
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo (OpenAI)",
                 # "bigscience/bloom (HuggingFace)", # runs
                 # "google/flan-t5-xxl (HuggingFace)", # runs
                 "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                 ],
        help="Which LLM to use",
    )
    temp = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=2.0,
        step=0.1,
        value=0.4,
        help="Set the decoding temperature. "
             "Higher temps give more unpredictable outputs."
    )
    ##########################

    # Option labels look like "<model> (<provider>)".
    model = model_name.split("(")[0].rstrip()  # remove name of model provider
    provider = model_name.split("(")[-1].split(")")[0]
    model_kwargs = {"temperature": temp,
                    "max_new_tokens": 256,
                    "repetition_penalty": 1.0,
                    "top_p": 0.95,
                    "do_sample": True,
                    "seed": 42}
    # TODO: maybe expose more of these to the user

    if "session_memory" not in st.session_state:
        st.session_state.session_memory = setup_memory()  # for openai
    if "history" not in st.session_state:
        st.session_state.history = []  # for mistral

    is_mistral = "mistral" in model
    if is_mistral:
        # The full Mistral prompt is built by format_mistral_prompt, so the
        # template just passes the input through.
        prompt = PromptTemplate(input_variables=["input"],
                                template="{input}")
    else:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ],
            verbose=True
        )

    # Only the chat prompt reads "chat_history". Attaching the shared window
    # memory to the Mistral chain would make LLMChain store each fully
    # formatted "[INST] ..." prompt as a human turn, which would then leak
    # into the OpenAI context if the user switches models.
    chain_memory = None if is_mistral else st.session_state.session_memory

    if use_existing_chain(model, provider, model_kwargs):
        chain = st.session_state.current_chain
    else:
        chain = CurrentChain(model,
                             provider,
                             prompt,
                             chain_memory,
                             model_kwargs)
        st.session_state.current_chain = chain
    conversation = chain.conversation

    if st.button("Clear history"):
        # Clear the session memory directly: the Mistral chain has no memory.
        st.session_state.session_memory.clear()  # for openai
        st.session_state.history = []  # for mistral
        logging.info("history cleared")

    for user_msg, asst_msg in st.session_state.history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.write(asst_msg)

    text = st.chat_input()
    if text:
        with st.chat_message("user"):
            st.write(text)
        logging.info(text)
        try:
            if is_mistral:
                full_prompt = format_mistral_prompt(text, st.session_state.history)
                result = conversation.predict(input=full_prompt)
            else:
                result = conversation.predict(input=text)
        except (AuthenticationError, ValueError):
            # Keep the user-facing hint, but record the real cause: ValueError
            # may be raised by the model call for reasons other than a bad key.
            logging.exception("model call failed")
            st.warning("Supply a valid API key", icon="⚠️")
        else:
            st.session_state.history.append((text, result))
            logging.info(repr(result))
            with st.chat_message("assistant"):
                st.write(result)
|