lfoppiano commited on
Commit
9e0de2a
1 Parent(s): 3cd4af5

simplifying

Browse files
Files changed (1) hide show
  1. streamlit_app.py +16 -19
streamlit_app.py CHANGED
@@ -23,6 +23,13 @@ OPENAI_MODELS = ['chatgpt-3.5-turbo',
23
  "gpt-4",
24
  "gpt-4-1106-preview"]
25
 
 
 
 
 
 
 
 
26
  if 'rqa' not in st.session_state:
27
  st.session_state['rqa'] = {}
28
 
@@ -136,18 +143,14 @@ def init_qa(model, api_key=None):
136
  frequency_penalty=0.1)
137
  embeddings = OpenAIEmbeddings()
138
 
139
- elif model == 'mistral-7b-instruct-v0.1':
140
- chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
141
- model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
 
 
142
  embeddings = HuggingFaceEmbeddings(
143
  model_name="all-MiniLM-L6-v2")
144
- st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
145
-
146
- elif model == 'zephyr-7b-beta':
147
- chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
148
- model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
149
- embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
150
- st.session_state['memory'] = None
151
  else:
152
  st.error("The model was not loaded properly. Try reloading. ")
153
  st.stop()
@@ -212,14 +215,8 @@ def play_old_messages():
212
  with st.sidebar:
213
  st.session_state['model'] = model = st.selectbox(
214
  "Model:",
215
- options=[
216
- "chatgpt-3.5-turbo",
217
- "mistral-7b-instruct-v0.1",
218
- "zephyr-7b-beta",
219
- "gpt-4",
220
- "gpt-4-1106-preview"
221
- ],
222
- index=2,
223
  placeholder="Select model",
224
  help="Select the LLM model:",
225
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
@@ -228,7 +225,7 @@ with st.sidebar:
228
  st.markdown(
229
  ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
230
 
231
- if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
232
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
233
  api_key = st.text_input('Huggingface API Key', type="password")
234
 
 
23
  "gpt-4",
24
  "gpt-4-1106-preview"]
25
 
26
+ OPEN_MODELS = {
27
+ 'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
28
+ "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta'
29
+ }
30
+
31
+ DISABLE_MEMORY = ['zephyr-7b-beta']
32
+
33
  if 'rqa' not in st.session_state:
34
  st.session_state['rqa'] = {}
35
 
 
143
  frequency_penalty=0.1)
144
  embeddings = OpenAIEmbeddings()
145
 
146
+ elif model in OPEN_MODELS:
147
+ chat = HuggingFaceHub(
148
+ repo_id=OPEN_MODELS[model],
149
+ model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048}
150
+ )
151
  embeddings = HuggingFaceEmbeddings(
152
  model_name="all-MiniLM-L6-v2")
153
+ st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
 
 
 
 
 
 
154
  else:
155
  st.error("The model was not loaded properly. Try reloading. ")
156
  st.stop()
 
215
  with st.sidebar:
216
  st.session_state['model'] = model = st.selectbox(
217
  "Model:",
218
+ options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
219
+ index=4,
 
 
 
 
 
 
220
  placeholder="Select model",
221
  help="Select the LLM model:",
222
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
 
225
  st.markdown(
226
  ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
227
 
228
+ if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
229
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
230
  api_key = st.text_input('Huggingface API Key', type="password")
231