Spaces:
Runtime error
Runtime error
carolanderson
commited on
Commit
•
08b9e09
1
Parent(s):
cbeb966
use session_state instead of st.cache
Browse files
app.py
CHANGED
@@ -17,76 +17,6 @@ from openai.error import AuthenticationError
|
|
17 |
import streamlit as st
|
18 |
|
19 |
|
20 |
-
@st.cache_resource
|
21 |
-
class KeyManager():
|
22 |
-
"""
|
23 |
-
Stores the original API keys from environment variables, which
|
24 |
-
can be overwritten if user supplies keys.
|
25 |
-
Also stores the currently active API key for each model provider and updates
|
26 |
-
these based on user input.
|
27 |
-
"""
|
28 |
-
def __init__(self):
|
29 |
-
self.provider_names = {"OpenAI" : "OPENAI_API_KEY",
|
30 |
-
"HuggingFace" : "HUGGINGFACEHUB_API_TOKEN"}
|
31 |
-
self.original_keys = {k : os.environ.get(v) for k, v
|
32 |
-
in self.provider_names.items()}
|
33 |
-
self.current_keys = {k: os.environ.get(v) for k, v in self.provider_names.items()}
|
34 |
-
self.user_keys = {} # most recent key supplied by user for each provider
|
35 |
-
|
36 |
-
def set_key(self, api_key, model_provider, user_entered=False):
|
37 |
-
self.current_keys[model_provider] = api_key
|
38 |
-
os.environ[self.provider_names[model_provider]] = api_key
|
39 |
-
if user_entered:
|
40 |
-
self.user_keys[model_provider] = api_key
|
41 |
-
get_chain.clear()
|
42 |
-
|
43 |
-
def list_keys(self):
|
44 |
-
"""
|
45 |
-
For debugging purposes only. Do not use in deployed app.
|
46 |
-
"""
|
47 |
-
st.write("Active API keys:")
|
48 |
-
for k, v in self.provider_names.items():
|
49 |
-
st.write(k, " : ", os.environ.get(v))
|
50 |
-
st.write("Current API keys:")
|
51 |
-
for k, v in self.current_keys.items():
|
52 |
-
st.write(k, " : ", v)
|
53 |
-
st.write("User-supplied API keys:")
|
54 |
-
for k, v in self.user_keys.items():
|
55 |
-
st.write(k, " : ", v)
|
56 |
-
st.write("Original API keys:")
|
57 |
-
for k, v in self.original_keys.items():
|
58 |
-
st.write(k, " : ", v)
|
59 |
-
|
60 |
-
def configure_api_key(self, user_api_key, use_provided_key, model_provider):
|
61 |
-
"""
|
62 |
-
Set the currently active API key(s) based on user input.
|
63 |
-
"""
|
64 |
-
if user_api_key:
|
65 |
-
if use_provided_key:
|
66 |
-
st.warning("API key entered and 'use provided key' checked;"
|
67 |
-
" using the key you entered", icon="⚠️")
|
68 |
-
self.set_key(str(user_api_key), model_provider, user_entered=True)
|
69 |
-
return True
|
70 |
-
|
71 |
-
if use_provided_key:
|
72 |
-
self.set_key(self.original_keys[model_provider], model_provider)
|
73 |
-
return True
|
74 |
-
|
75 |
-
if not user_api_key and not use_provided_key:
|
76 |
-
# check if user previously supplied a key for this provider
|
77 |
-
if model_provider in self.user_keys:
|
78 |
-
self.set_key(self.user_keys[model_provider], model_provider)
|
79 |
-
st.warning("No key entered and 'use provided key' not checked;"
|
80 |
-
f" using previously entered {model_provider} key", icon="⚠️")
|
81 |
-
return True
|
82 |
-
|
83 |
-
else:
|
84 |
-
st.warning("Enter an API key or check 'use provided key'"
|
85 |
-
" to get started", icon="⚠️")
|
86 |
-
return False
|
87 |
-
|
88 |
-
|
89 |
-
@st.cache_resource
|
90 |
def setup_memory():
|
91 |
msgs = StreamlitChatMessageHistory(key="basic_chat_app")
|
92 |
memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
|
@@ -94,40 +24,49 @@ def setup_memory():
|
|
94 |
return_messages=True)
|
95 |
logging.info("setting up new chat memory")
|
96 |
return memory
|
|
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
if __name__ == "__main__":
|
129 |
logging.basicConfig(level=logging.INFO)
|
130 |
-
|
131 |
st.header("Basic chatbot")
|
132 |
st.write("On small screens, click the `>` at top left to get started")
|
133 |
with st.expander("How conversation history works"):
|
@@ -144,22 +83,11 @@ if __name__ == "__main__":
|
|
144 |
],
|
145 |
help="Which LLM to use",
|
146 |
)
|
147 |
-
|
148 |
-
user_api_key = st.sidebar.text_input(
|
149 |
-
'Enter your API Key',
|
150 |
-
type='password',
|
151 |
-
help="Enter an API key for the appropriate model provider",
|
152 |
-
value="")
|
153 |
-
|
154 |
-
use_provided_key = st.sidebar.checkbox(
|
155 |
-
"Or use provided key",
|
156 |
-
help="If you don't have a key, you can use mine; usage limits apply.",
|
157 |
-
)
|
158 |
|
159 |
st.sidebar.write("Set the decoding temperature. Higher temperatures give "
|
160 |
"more unpredictable outputs.")
|
161 |
|
162 |
-
|
163 |
label="Temperature",
|
164 |
min_value=float(0),
|
165 |
max_value=1.0,
|
@@ -168,29 +96,35 @@ if __name__ == "__main__":
|
|
168 |
help="Set the decoding temperature"
|
169 |
)
|
170 |
##########################
|
171 |
-
|
172 |
model = model_name.split("(")[0].rstrip() # remove name of model provider
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
chain =
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
|
196 |
|
|
|
17 |
import streamlit as st
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def setup_memory():
|
21 |
msgs = StreamlitChatMessageHistory(key="basic_chat_app")
|
22 |
memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
|
|
|
24 |
return_messages=True)
|
25 |
logging.info("setting up new chat memory")
|
26 |
return memory
|
27 |
+
|
28 |
|
29 |
+
def use_existing_chain(model, provider, temp):
|
30 |
+
if "current_chain" in st.session_state:
|
31 |
+
current_chain = st.session_state.current_chain
|
32 |
+
if (current_chain.model == model) \
|
33 |
+
and (current_chain.provider == provider) \
|
34 |
+
and (current_chain.temp == temp):
|
35 |
+
return True
|
36 |
+
return False
|
37 |
+
|
38 |
+
|
39 |
+
class CurrentChain():
|
40 |
+
def __init__(self, model, provider, memory, temp):
|
41 |
+
self.model = model
|
42 |
+
self.provider = provider
|
43 |
+
self.temp = temp
|
44 |
+
logging.info(f"setting up new chain with params {model_name}, {provider}, {temp}")
|
45 |
+
if provider == "OpenAI":
|
46 |
+
llm = ChatOpenAI(model_name=model, temperature=temp)
|
47 |
+
elif provider == "HuggingFace":
|
48 |
+
llm = HuggingFaceHub(repo_id=model,
|
49 |
+
model_kwargs={"temperature": temp, "max_length": 64})
|
50 |
+
prompt = ChatPromptTemplate(
|
51 |
+
messages=[
|
52 |
+
SystemMessagePromptTemplate.from_template(
|
53 |
+
"You are a nice chatbot having a conversation with a human."
|
54 |
+
),
|
55 |
+
MessagesPlaceholder(variable_name="chat_history"),
|
56 |
+
HumanMessagePromptTemplate.from_template("{input}")
|
57 |
+
]
|
58 |
+
)
|
59 |
+
self.conversation = LLMChain(
|
60 |
+
llm=llm,
|
61 |
+
prompt=prompt,
|
62 |
+
verbose=True,
|
63 |
+
memory=memory
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
if __name__ == "__main__":
|
68 |
logging.basicConfig(level=logging.INFO)
|
69 |
+
|
70 |
st.header("Basic chatbot")
|
71 |
st.write("On small screens, click the `>` at top left to get started")
|
72 |
with st.expander("How conversation history works"):
|
|
|
83 |
],
|
84 |
help="Which LLM to use",
|
85 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
st.sidebar.write("Set the decoding temperature. Higher temperatures give "
|
88 |
"more unpredictable outputs.")
|
89 |
|
90 |
+
temp = st.sidebar.slider(
|
91 |
label="Temperature",
|
92 |
min_value=float(0),
|
93 |
max_value=1.0,
|
|
|
96 |
help="Set the decoding temperature"
|
97 |
)
|
98 |
##########################
|
|
|
99 |
model = model_name.split("(")[0].rstrip() # remove name of model provider
|
100 |
+
provider = model_name.split("(")[-1].split(")")[0]
|
101 |
+
if "session_memory" not in st.session_state:
|
102 |
+
st.session_state.session_memory = setup_memory()
|
103 |
+
|
104 |
+
if use_existing_chain(model, provider, temp):
|
105 |
+
chain = st.session_state.current_chain
|
106 |
+
else:
|
107 |
+
chain = CurrentChain(model,
|
108 |
+
provider,
|
109 |
+
st.session_state.session_memory,
|
110 |
+
temp)
|
111 |
+
st.session_state.current_chain = chain
|
112 |
+
|
113 |
+
conversation = chain.conversation
|
114 |
+
if st.button("Clear history"):
|
115 |
+
conversation.memory.clear()
|
116 |
+
for message in conversation.memory.buffer: # display chat history
|
117 |
+
st.chat_message(message.type).write(message.content)
|
118 |
+
text = st.chat_input()
|
119 |
+
if text:
|
120 |
+
with st.chat_message("user"):
|
121 |
+
st.write(text)
|
122 |
+
try:
|
123 |
+
result = conversation.predict(input=text)
|
124 |
+
with st.chat_message("assistant"):
|
125 |
+
st.write(result)
|
126 |
+
except (AuthenticationError, ValueError):
|
127 |
+
st.warning("Enter a valid API key", icon="⚠️")
|
128 |
|
129 |
|
130 |
|