Ashhar commited on
Commit
bfb639d
·
1 Parent(s): a94a1dc

fixed history context

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -283,12 +283,12 @@ def __setStartMsg(msg):
283
  st.session_state.startMsg = msg
284
 
285
 
 
 
 
286
  if "messages" not in st.session_state:
287
  st.session_state.messages = []
288
 
289
- if "history" not in st.session_state:
290
- st.session_state.history = []
291
-
292
  if "buttonValue" not in st.session_state:
293
  __resetButtonState()
294
 
@@ -296,27 +296,22 @@ if "startMsg" not in st.session_state:
296
  st.session_state.startMsg = ""
297
 
298
 
299
- def __getChatMessages(prompt: str):
300
- st.session_state.history.append({
301
- "role": "user",
302
- "content": prompt
303
- })
304
-
305
  def getContextSize():
306
- currContextSize = countTokens(SYSTEM_MSG) + countTokens(st.session_state.history) + 100
307
  pprint(f"{currContextSize=}")
308
  return currContextSize
309
 
310
  while getContextSize() > MAX_CONTEXT:
311
  pprint("Context size exceeded, removing first message")
312
- st.session_state.history.pop(0)
313
 
314
- return st.session_state.history
315
 
316
 
317
  def predict(prompt):
318
  messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
319
- messagesFormatted.extend(__getChatMessages(prompt))
320
  contextSize = countTokens(messagesFormatted)
321
  pprint(f"{contextSize=} | {MODEL}")
322
 
@@ -356,10 +351,10 @@ st.title("Kommuneity Story Creator 📖")
356
  if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
357
  st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
358
 
359
- for message in st.session_state.messages:
360
- role = message["role"]
361
- content = message["content"]
362
- imagePath = message.get("image")
363
  avatar = AI_ICON if role == "assistant" else USER_ICON
364
  with st.chat_message(role, avatar=avatar):
365
  st.markdown(content)
@@ -373,7 +368,8 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
373
  with st.chat_message("user", avatar=USER_ICON):
374
  st.markdown(prompt)
375
  pprint(f"{prompt=}")
376
- st.session_state.messages.append({"role": "user", "content": prompt })
 
377
 
378
  with st.chat_message("assistant", avatar=AI_ICON):
379
  responseContainer = st.empty()
@@ -454,5 +450,9 @@ if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_s
454
  st.session_state.messages.append({
455
  "role": "assistant",
456
  "content": response,
 
 
 
 
457
  "image": imagePath,
458
  })
 
283
  st.session_state.startMsg = msg
284
 
285
 
286
+ if "chatHistory" not in st.session_state:
287
+ st.session_state.chatHistory = []
288
+
289
  if "messages" not in st.session_state:
290
  st.session_state.messages = []
291
 
 
 
 
292
  if "buttonValue" not in st.session_state:
293
  __resetButtonState()
294
 
 
296
  st.session_state.startMsg = ""
297
 
298
 
299
+ def __getMessages():
 
 
 
 
 
300
  def getContextSize():
301
+ currContextSize = countTokens(SYSTEM_MSG) + countTokens(st.session_state.messages) + 100
302
  pprint(f"{currContextSize=}")
303
  return currContextSize
304
 
305
  while getContextSize() > MAX_CONTEXT:
306
  pprint("Context size exceeded, removing first message")
307
+ st.session_state.messages.pop(0)
308
 
309
+ return st.session_state.messages
310
 
311
 
312
  def predict(prompt):
313
  messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
314
+ messagesFormatted.extend(__getMessages())
315
  contextSize = countTokens(messagesFormatted)
316
  pprint(f"{contextSize=} | {MODEL}")
317
 
 
351
  if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
352
  st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
353
 
354
+ for chat in st.session_state.chatHistory:
355
+ role = chat["role"]
356
+ content = chat["content"]
357
+ imagePath = chat.get("image")
358
  avatar = AI_ICON if role == "assistant" else USER_ICON
359
  with st.chat_message(role, avatar=avatar):
360
  st.markdown(content)
 
368
  with st.chat_message("user", avatar=USER_ICON):
369
  st.markdown(prompt)
370
  pprint(f"{prompt=}")
371
+ st.session_state.messages.append({"role": "user", "content": prompt})
372
+ st.session_state.chatHistory.append({"role": "user", "content": prompt })
373
 
374
  with st.chat_message("assistant", avatar=AI_ICON):
375
  responseContainer = st.empty()
 
450
  st.session_state.messages.append({
451
  "role": "assistant",
452
  "content": response,
453
+ })
454
+ st.session_state.chatHistory.append({
455
+ "role": "assistant",
456
+ "content": response,
457
  "image": imagePath,
458
  })