"""Streamlit page that runs two simulator models (A and B) side by side on the
same prompt and records the user's judgment of which reply was better."""
import os
import random
import datetime as dt

import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage

from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label

logger = get_logger(__name__)
# Read eagerly so the page fails fast when the key is missing.
openai_api_key = os.environ['OPENAI_API_KEY']

# ---------------------------------------------------------------------------
# Session-state defaults (only set on the first run of a session).
# ---------------------------------------------------------------------------
if "sent_messages" not in st.session_state:
    st.session_state['sent_messages'] = 0
logger.info(f'sent messages {st.session_state["sent_messages"]}')

if "issue" not in st.session_state:
    st.session_state['issue'] = ISSUES[0]
if 'previous_sourceA' not in st.session_state:
    st.session_state['previous_sourceA'] = SOURCES[0]
if 'previous_sourceB' not in st.session_state:
    st.session_state['previous_sourceB'] = SOURCES[0]

# Specs for the three conversation memories. The keys double as the
# st.session_state keys under which the memory objects are stored
# (see the 'memoryA' / 'memoryB' / 'commonMemory' lookups below).
memories = {
    'memoryA': {"issue": st.session_state['issue'], "source": st.session_state['previous_sourceA']},
    'memoryB': {"issue": st.session_state['issue'], "source": st.session_state['previous_sourceB']},
    'commonMemory': {"issue": st.session_state['issue'], "source": SOURCES[0]}
}

if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if 'counselor_name' not in st.session_state:
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
if 'texter_name' not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)


# ---------------------------------------------------------------------------
# Memory helpers.
# ---------------------------------------------------------------------------
def delete_last_message(memory):
    """Remove the last two messages (prompt + reply) from *memory*.

    Returns:
        The content of the removed prompt message.
    """
    last_prompt = memory.chat_memory.messages[-2].content
    memory.chat_memory.messages = memory.chat_memory.messages[:-2]
    return last_prompt


def replace_last_message(memory, new_message):
    """Drop the last message of *memory* and append *new_message* as an AI message."""
    memory.chat_memory.messages = memory.chat_memory.messages[:-1]
    memory.chat_memory.add_ai_message(new_message)


def _no_turn_to_judge():
    """Return True when no prompt has been sent since the conversation (re)started."""
    return st.session_state['sent_messages'] == 0


# ---------------------------------------------------------------------------
# Button callbacks.
#
# These reference module-level names assigned further down the script
# (llm_chainA/B, stopperA/B, memoryA/B, col1/col2, username, sourceA/B).
# A callback fires at the start of the rerun after its button is clicked,
# and the names are resolved in the globals of the run that registered it.
# ---------------------------------------------------------------------------
def regenerateA():
    """Discard model A's last turn, re-ask the same prompt, and render the new reply.

    Returns:
        Model A's regenerated response text.
    """
    last_prompt = delete_last_message(memoryA)
    new_response = llm_chainA.predict(input=last_prompt, stop=stopperA)
    col1.chat_message("user").write(last_prompt)
    col1.chat_message("assistant").write(new_response)
    return new_response


def regenerateB():
    """Discard model B's last turn, re-ask the same prompt, and render the new reply.

    Returns:
        Model B's regenerated response text.
    """
    last_prompt = delete_last_message(memoryB)
    new_response = llm_chainB.predict(input=last_prompt, stop=stopperB)
    col2.chat_message("user").write(last_prompt)
    col2.chat_message("assistant").write(new_response)
    return new_response


def replaceA():
    """"B is better": copy model B's last reply into A and the shared history, then log the vote."""
    if _no_turn_to_judge():
        # Nothing generated yet; comparison_id / convo_id may not exist either.
        return
    last_prompt = memoryB.chat_memory.messages[-2].content
    new_message = memoryB.chat_memory.messages[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(st.session_state['db_client'], st.session_state['comparison_id'],
                      st.session_state['convo_id'], username, sourceA, sourceB,
                      winner='model_two')


def replaceB():
    """"A is better": copy model A's last reply into B and the shared history, then log the vote."""
    if _no_turn_to_judge():
        return
    last_prompt = memoryA.chat_memory.messages[-2].content
    new_message = memoryA.chat_memory.messages[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(st.session_state['db_client'], st.session_state['comparison_id'],
                      st.session_state['convo_id'], username, sourceA, sourceB,
                      winner='model_one')


def regenerateBoth():
    """"Both are bad": log the vote, regenerate both replies, and record a new comparison."""
    if _no_turn_to_judge():
        return
    # Take the prompt from memory instead of the module-level `prompt`, which is
    # None whenever the registering run was not a chat submission (e.g. a second
    # consecutive "Both are bad" click). Both memories share the same last prompt.
    last_prompt = memoryA.chat_memory.messages[-2].content
    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_battle_result(st.session_state['db_client'], st.session_state['comparison_id'],
                      st.session_state['convo_id'], username, sourceA, sourceB,
                      winner='both_bad')
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str,
                   last_prompt, responseA, responseB)


def bothGood():
    """"Tie": add one randomly chosen model's last turn to the shared history and log the vote."""
    if _no_turn_to_judge():
        return
    chosen_memory = random.choice([memoryA, memoryB])
    last_prompt = chosen_memory.chat_memory.messages[-2].content
    last_response = chosen_memory.chat_memory.messages[-1].content
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": last_response})
    new_battle_result(st.session_state['db_client'], st.session_state['comparison_id'],
                      st.session_state['convo_id'], username, sourceA, sourceB,
                      winner='tie')


# ---------------------------------------------------------------------------
# Sidebar controls.
# ---------------------------------------------------------------------------
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    # NOTE(review): both on_change handlers below pass language="English", while the
    # "Clear History" button passes the selected code ('en'/'es'). Confirm which form
    # clear_memory expects; `language` is not yet defined when this selectbox renders.
    issue = st.selectbox("Select an Issue", ISSUES, index=0,
                         on_change=clear_memory,
                         kwargs={"memories": memories, "username": username, "language": "English"})
    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            format_func=lambda x: "English" if x == "en" else "Spanish",
                            on_change=clear_memory,
                            kwargs={"memories": memories, "username": username, "language": "English"})

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
                               format_func=source2label)
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", SOURCES, index=0,
                               format_func=source2label)

    st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")

    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)

    clear = st.button("Clear History", on_click=clear_memory,
                      kwargs={"memories": memories, "username": username, "language": language})

# ---------------------------------------------------------------------------
# Reset the conversation when a model source or the issue changes.
# ---------------------------------------------------------------------------
changed_source = (
    st.session_state['previous_sourceA'] != sourceA
    or st.session_state['previous_sourceB'] != sourceB
    or st.session_state['issue'] != issue
)

if changed_source:
    logger.info("changed something")
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state['previous_sourceA'] = sourceA
    st.session_state['previous_sourceB'] = sourceB
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0

create_memory_add_initial_message(memories,
                                  issue,
                                  language,
                                  changed_source=changed_source,
                                  counselor_name=st.session_state["counselor_name"],
                                  texter_name=st.session_state["texter_name"])

memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']

llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA,
                                 texter_name=st.session_state["texter_name"])
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB,
                                 texter_name=st.session_state["texter_name"])

# ---------------------------------------------------------------------------
# Main area: shared history, then the two simulator columns.
# ---------------------------------------------------------------------------
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")


def disable_chat():
    """Return True (disable chat input) unless one of the judgment buttons was just clicked.

    This forces the user to vote on the current pair of replies before sending
    the next prompt. Before the first prompt, a click on any of them enables input.
    """
    return not any([beta, betb, same, bbad])


if prompt := st.chat_input(disabled=disable_chat()):
    st.session_state['sent_messages'] += 1
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)

    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)

    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)

    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)

    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)
    # Streamlit buttons report True only on the run right after a click, so no
    # manual reset is needed (the former reset_buttons() loop was a no-op).