carolanderson committed on
Commit
08b9e09
1 Parent(s): cbeb966

use session_state instead of st.cache

Browse files
Files changed (1) hide show
  1. app.py +69 -135
app.py CHANGED
@@ -17,76 +17,6 @@ from openai.error import AuthenticationError
17
  import streamlit as st
18
 
19
 
20
- @st.cache_resource
21
- class KeyManager():
22
- """
23
- Stores the original API keys from environment variables, which
24
- can be overwritten if user supplies keys.
25
- Also stores the currently active API key for each model provider and updates
26
- these based on user input.
27
- """
28
- def __init__(self):
29
- self.provider_names = {"OpenAI" : "OPENAI_API_KEY",
30
- "HuggingFace" : "HUGGINGFACEHUB_API_TOKEN"}
31
- self.original_keys = {k : os.environ.get(v) for k, v
32
- in self.provider_names.items()}
33
- self.current_keys = {k: os.environ.get(v) for k, v in self.provider_names.items()}
34
- self.user_keys = {} # most recent key supplied by user for each provider
35
-
36
- def set_key(self, api_key, model_provider, user_entered=False):
37
- self.current_keys[model_provider] = api_key
38
- os.environ[self.provider_names[model_provider]] = api_key
39
- if user_entered:
40
- self.user_keys[model_provider] = api_key
41
- get_chain.clear()
42
-
43
- def list_keys(self):
44
- """
45
- For debugging purposes only. Do not use in deployed app.
46
- """
47
- st.write("Active API keys:")
48
- for k, v in self.provider_names.items():
49
- st.write(k, " : ", os.environ.get(v))
50
- st.write("Current API keys:")
51
- for k, v in self.current_keys.items():
52
- st.write(k, " : ", v)
53
- st.write("User-supplied API keys:")
54
- for k, v in self.user_keys.items():
55
- st.write(k, " : ", v)
56
- st.write("Original API keys:")
57
- for k, v in self.original_keys.items():
58
- st.write(k, " : ", v)
59
-
60
- def configure_api_key(self, user_api_key, use_provided_key, model_provider):
61
- """
62
- Set the currently active API key(s) based on user input.
63
- """
64
- if user_api_key:
65
- if use_provided_key:
66
- st.warning("API key entered and 'use provided key' checked;"
67
- " using the key you entered", icon="⚠️")
68
- self.set_key(str(user_api_key), model_provider, user_entered=True)
69
- return True
70
-
71
- if use_provided_key:
72
- self.set_key(self.original_keys[model_provider], model_provider)
73
- return True
74
-
75
- if not user_api_key and not use_provided_key:
76
- # check if user previously supplied a key for this provider
77
- if model_provider in self.user_keys:
78
- self.set_key(self.user_keys[model_provider], model_provider)
79
- st.warning("No key entered and 'use provided key' not checked;"
80
- f" using previously entered {model_provider} key", icon="⚠️")
81
- return True
82
-
83
- else:
84
- st.warning("Enter an API key or check 'use provided key'"
85
- " to get started", icon="⚠️")
86
- return False
87
-
88
-
89
- @st.cache_resource
90
  def setup_memory():
91
  msgs = StreamlitChatMessageHistory(key="basic_chat_app")
92
  memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
@@ -94,40 +24,49 @@ def setup_memory():
94
  return_messages=True)
95
  logging.info("setting up new chat memory")
96
  return memory
 
97
 
98
-
99
- @st.cache_resource
100
- def get_chain(model_name, model_provider, _memory, temperature):
101
- logging.info(f"setting up new chain with params {model_name}, {model_provider}, {temperature}")
102
- if model_provider == "OpenAI":
103
- llm = ChatOpenAI(model_name=model_name, temperature=temperature)
104
- elif model_provider == "HuggingFace":
105
- llm = HuggingFaceHub(repo_id=model_name,
106
- model_kwargs={"temperature": temperature, "max_length": 64})
107
- prompt = ChatPromptTemplate(
108
- messages=[
109
- SystemMessagePromptTemplate.from_template(
110
- "You are a nice chatbot having a conversation with a human."
111
- ),
112
- MessagesPlaceholder(variable_name="chat_history"),
113
- HumanMessagePromptTemplate.from_template("{input}")
114
- ]
115
- )
116
- conversation = LLMChain(
117
- llm=llm,
118
- prompt=prompt,
119
- verbose=True,
120
- memory=memory
121
- )
122
- return conversation
123
-
124
-
125
-
126
-
127
-
 
 
 
 
 
 
 
 
128
  if __name__ == "__main__":
129
  logging.basicConfig(level=logging.INFO)
130
-
131
  st.header("Basic chatbot")
132
  st.write("On small screens, click the `>` at top left to get started")
133
  with st.expander("How conversation history works"):
@@ -144,22 +83,11 @@ if __name__ == "__main__":
144
  ],
145
  help="Which LLM to use",
146
  )
147
-
148
- user_api_key = st.sidebar.text_input(
149
- 'Enter your API Key',
150
- type='password',
151
- help="Enter an API key for the appropriate model provider",
152
- value="")
153
-
154
- use_provided_key = st.sidebar.checkbox(
155
- "Or use provided key",
156
- help="If you don't have a key, you can use mine; usage limits apply.",
157
- )
158
 
159
  st.sidebar.write("Set the decoding temperature. Higher temperatures give "
160
  "more unpredictable outputs.")
161
 
162
- temperature = st.sidebar.slider(
163
  label="Temperature",
164
  min_value=float(0),
165
  max_value=1.0,
@@ -168,29 +96,35 @@ if __name__ == "__main__":
168
  help="Set the decoding temperature"
169
  )
170
  ##########################
171
-
172
  model = model_name.split("(")[0].rstrip() # remove name of model provider
173
- model_provider = model_name.split("(")[-1].split(")")[0]
174
- key_manager = KeyManager()
175
- if key_manager.configure_api_key(user_api_key, use_provided_key, model_provider):
176
- # key_manager.list_keys()
177
- memory = setup_memory()
178
- chain = get_chain(model, model_provider, memory, temperature)
179
- if st.button("Clear history"):
180
- chain.memory.clear()
181
- # st.cache_resource.clear()
182
- for message in chain.memory.buffer: # display chat history
183
- st.chat_message(message.type).write(message.content)
184
- text = st.chat_input()
185
- if text:
186
- with st.chat_message("user"):
187
- st.write(text)
188
- try:
189
- result = chain.predict(input=text)
190
- with st.chat_message("assistant"):
191
- st.write(result)
192
- except (AuthenticationError, ValueError):
193
- st.warning("Enter a valid API key", icon="⚠️")
 
 
 
 
 
 
 
194
 
195
 
196
 
 
17
  import streamlit as st
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def setup_memory():
21
  msgs = StreamlitChatMessageHistory(key="basic_chat_app")
22
  memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
 
24
  return_messages=True)
25
  logging.info("setting up new chat memory")
26
  return memory
27
+
28
 
29
def use_existing_chain(model, provider, temp):
    """Return True iff a chain already stored in session state was built
    with exactly these settings (model, provider, temperature)."""
    if "current_chain" not in st.session_state:
        return False
    cached = st.session_state.current_chain
    # Compare all three build parameters as a tuple; any mismatch means
    # the caller must construct a fresh chain.
    return (cached.model, cached.provider, cached.temp) == (model, provider, temp)
37
+
38
+
39
class CurrentChain():
    """
    Builds and holds an LLMChain for a given model, provider, and decoding
    temperature, recording those settings as attributes so that
    use_existing_chain() can later decide whether this chain can be reused
    or must be rebuilt.

    Attributes:
        model: model name/repo id the chain was built with.
        provider: "OpenAI" or "HuggingFace".
        temp: decoding temperature used at build time.
        conversation: the constructed LLMChain.
    """
    def __init__(self, model, provider, memory, temp):
        self.model = model
        self.provider = provider
        self.temp = temp
        # Fix: log the `model` parameter, not the global `model_name` —
        # referencing the global raised NameError if this class was used
        # outside the __main__ scope.
        logging.info(f"setting up new chain with params {model}, {provider}, {temp}")
        if provider == "OpenAI":
            llm = ChatOpenAI(model_name=model, temperature=temp)
        elif provider == "HuggingFace":
            llm = HuggingFaceHub(repo_id=model,
                                 model_kwargs={"temperature": temp, "max_length": 64})
        else:
            # Fail fast with a clear message instead of an opaque
            # NameError on `llm` below.
            raise ValueError(f"Unknown model provider: {provider}")
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(
                    "You are a nice chatbot having a conversation with a human."
                ),
                # Injects the windowed chat history kept by `memory`
                # (memory_key="chat_history").
                MessagesPlaceholder(variable_name="chat_history"),
                HumanMessagePromptTemplate.from_template("{input}")
            ]
        )
        self.conversation = LLMChain(
            llm=llm,
            prompt=prompt,
            verbose=True,
            memory=memory
        )
65
+
66
+
67
  if __name__ == "__main__":
68
  logging.basicConfig(level=logging.INFO)
69
+
70
  st.header("Basic chatbot")
71
  st.write("On small screens, click the `>` at top left to get started")
72
  with st.expander("How conversation history works"):
 
83
  ],
84
  help="Which LLM to use",
85
  )
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  st.sidebar.write("Set the decoding temperature. Higher temperatures give "
88
  "more unpredictable outputs.")
89
 
90
+ temp = st.sidebar.slider(
91
  label="Temperature",
92
  min_value=float(0),
93
  max_value=1.0,
 
96
  help="Set the decoding temperature"
97
  )
98
  ##########################
 
99
  model = model_name.split("(")[0].rstrip() # remove name of model provider
100
+ provider = model_name.split("(")[-1].split(")")[0]
101
+ if "session_memory" not in st.session_state:
102
+ st.session_state.session_memory = setup_memory()
103
+
104
+ if use_existing_chain(model, provider, temp):
105
+ chain = st.session_state.current_chain
106
+ else:
107
+ chain = CurrentChain(model,
108
+ provider,
109
+ st.session_state.session_memory,
110
+ temp)
111
+ st.session_state.current_chain = chain
112
+
113
+ conversation = chain.conversation
114
+ if st.button("Clear history"):
115
+ conversation.memory.clear()
116
+ for message in conversation.memory.buffer: # display chat history
117
+ st.chat_message(message.type).write(message.content)
118
+ text = st.chat_input()
119
+ if text:
120
+ with st.chat_message("user"):
121
+ st.write(text)
122
+ try:
123
+ result = conversation.predict(input=text)
124
+ with st.chat_message("assistant"):
125
+ st.write(result)
126
+ except (AuthenticationError, ValueError):
127
+ st.warning("Enter a valid API key", icon="⚠️")
128
 
129
 
130