# Source: convosim-ui / pages/comparisor.py
# Last commit 59d5667 by ivnban27-ctl: "changed roleplays for openai to GCT and SP"
# File size at that commit: 7.89 kB
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
from utils.app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
from app_config import ISSUES, SOURCES, source2label
# Module-level logger obtained from Streamlit's logging helper.
logger = get_logger(__name__)

# Read eagerly: a missing OPENAI_API_KEY raises KeyError as soon as the page runs.
openai_api_key = os.environ['OPENAI_API_KEY']

# Placeholder memory configuration. The sidebar widgets pass this dict to
# `clear_memory` before the user's actual issue/source selections are known;
# it is rebuilt from those selections further down the script.
memories = {
    memory_name: {"issue": ISSUES[0], "source": default_source}
    for memory_name, default_source in (
        ('memoryA', SOURCES[0]),
        ('memoryB', SOURCES[1]),
        ('commonMemory', SOURCES[0]),
    )
}

# Seed per-session state once; later reruns keep whatever is already stored.
if 'db_client' not in st.session_state:
    # Only call get_db_client() when the session has no client yet.
    st.session_state["db_client"] = get_db_client()

if 'previous_sourceA' not in st.session_state:
    st.session_state['previous_sourceA'] = SOURCES[0]

if 'previous_sourceB' not in st.session_state:
    st.session_state['previous_sourceB'] = SOURCES[1]
def delete_last_message(memory):
    """Remove the final two messages from ``memory`` and return the earlier one's text.

    The second-to-last message is treated as the prompt of the exchange that
    is being discarded, and the last message as its response.

    Args:
        memory: Object exposing ``chat_memory.messages``, a list of messages
            that each have a ``content`` attribute.

    Returns:
        The ``content`` of the second-to-last message.

    Raises:
        IndexError: If fewer than two messages are stored. The history is
            left untouched in that case.
    """
    messages = memory.chat_memory.messages
    # Fail with an explicit message before mutating anything, instead of the
    # bare "list index out of range" that messages[-2] would produce.
    if len(messages) < 2:
        raise IndexError(
            "cannot delete the last exchange: fewer than two messages in memory"
        )
    last_prompt = messages[-2].content
    # Rebind to a new list (rather than deleting in place), as the history
    # attribute is replaced wholesale.
    memory.chat_memory.messages = messages[:-2]
    return last_prompt
def replace_last_message(memory, new_message):
    """Swap the final stored message of ``memory`` for an AI message.

    Args:
        memory: Object exposing ``chat_memory.messages`` and
            ``chat_memory.add_ai_message``.
        new_message: Text passed to ``add_ai_message`` as the replacement.
    """
    history = memory.chat_memory
    # Drop the current last entry, then append the replacement as an AI turn.
    history.messages = history.messages[:-1]
    history.add_ai_message(new_message)
def regenerateA():
    """Discard Model A's last exchange, ask its chain again, and show the result.

    Uses the module-level ``memoryA``, ``llm_chainA``, ``stopperA`` and
    ``col1`` created later in the script.

    Returns:
        The newly generated response from ``llm_chainA.predict``.
    """
    prompt_text = delete_last_message(memoryA)
    answer = llm_chainA.predict(input=prompt_text, stop=stopperA)
    # Echo the re-sent prompt followed by the fresh reply in column A.
    for role, text in (("user", prompt_text), ("assistant", answer)):
        col1.chat_message(role).write(text)
    return answer
def regenerateB():
    """Discard Model B's last exchange, ask its chain again, and show the result.

    Uses the module-level ``memoryB``, ``llm_chainB``, ``stopperB`` and
    ``col2`` created later in the script.

    Returns:
        The newly generated response from ``llm_chainB.predict``.
    """
    prompt_text = delete_last_message(memoryB)
    answer = llm_chainB.predict(input=prompt_text, stop=stopperB)
    # Echo the re-sent prompt followed by the fresh reply in column B.
    for role, text in (("user", prompt_text), ("assistant", answer)):
        col2.chat_message(role).write(text)
    return answer
def replaceA():
    """Button callback for "B is better": adopt Model B's reply everywhere.

    Copies Model B's last response over Model A's last message, appends the
    winning prompt/response pair to the shared ``commonMemory`` history, and
    records the outcome with ``winner='model_two'``.

    Does nothing when Model B has no complete prompt/response pair yet, which
    is the state before the first prompt has been sent.
    """
    winning_messages = memoryB.chat_memory.messages
    # The chat box stays disabled until a verdict button is clicked, so this
    # callback can fire before any exchange exists. Skip instead of failing
    # on messages[-2], matching the early exit in bothGood.
    if len(winning_messages) < 2:
        return

    last_prompt = winning_messages[-2].content
    new_message = winning_messages[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": new_message}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_two',
    )
def replaceB():
    """Button callback for "A is better": adopt Model A's reply everywhere.

    Copies Model A's last response over Model B's last message, appends the
    winning prompt/response pair to the shared ``commonMemory`` history, and
    records the outcome with ``winner='model_one'``.

    Does nothing when Model A has no complete prompt/response pair yet, which
    is the state before the first prompt has been sent.
    """
    winning_messages = memoryA.chat_memory.messages
    # The chat box stays disabled until a verdict button is clicked, so this
    # callback can fire before any exchange exists. Skip instead of failing
    # on messages[-2], matching the early exit in bothGood.
    if len(winning_messages) < 2:
        return

    last_prompt = winning_messages[-2].content
    new_message = winning_messages[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": new_message}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_one',
    )
def regenerateBoth():
    """Button callback for "Both are bad": log the verdict and regenerate both replies.

    Records the outcome with ``winner='both_bad'``, regenerates the last
    response of each model via ``regenerateA`` / ``regenerateB``, and stores
    the new pair as a fresh comparison with its start and end timestamps.

    Does nothing when either model has no complete prompt/response pair yet,
    which is the state before the first prompt has been sent.
    """
    messages_a = memoryA.chat_memory.messages
    messages_b = memoryB.chat_memory.messages
    # The chat box stays disabled until a verdict button is clicked, so this
    # callback can fire before any exchange exists. Skip instead of failing
    # inside delete_last_message, matching the early exit in bothGood.
    if len(messages_a) < 2 or len(messages_b) < 2:
        return

    # Log the prompt that is actually re-sent to the models (the second-to-last
    # stored message, the same one regenerateA feeds back to its chain) rather
    # than the module-level `prompt`, which only reflects the latest chat input.
    last_prompt = messages_a[-2].content

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='both_bad',
    )
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(
        st.session_state['db_client'], prompt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str,
        last_prompt, responseA, responseB,
    )
def bothGood():
    """Button callback for "Tie": keep one of the two replies at random.

    When ``memoryA`` holds exactly one message there is nothing to judge and
    the callback returns without doing anything. Otherwise the last
    prompt/response pair of a randomly chosen model is appended to the shared
    ``commonMemory`` history and the outcome is recorded with ``winner='tie'``.
    """
    # Single stored message: no exchange to record yet.
    if len(memoryA.buffer_as_messages) == 1:
        return

    chosen_memory = random.choice([memoryA, memoryB])
    chosen_messages = chosen_memory.chat_memory.messages
    last_prompt = chosen_messages[-2].content
    last_response = chosen_messages[-1].content
    st.session_state['commonMemory'].save_context(
        {"inputs": last_prompt}, {"outputs": last_response}
    )
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='tie',
    )
# Sidebar: user identity, scenario selection, per-model settings and the
# verdict buttons. Widgets are created in the same order as before because
# Streamlit lays them out in creation order.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)

    # Changing the issue clears the memories. At this point `memories` is
    # still the placeholder configuration defined at the top of the script.
    issue = st.selectbox(
        "Select an Issue",
        ISSUES,
        index=0,
        on_change=clear_memory,
        kwargs={"memories":memories, "username":username, "language":"English"},
    )

    # Spanish is only offered for the "Anxiety" issue.
    if issue == "Anxiety":
        supported_languages = ['English', "Spanish"]
    else:
        supported_languages = ['English']

    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        on_change=clear_memory,
        kwargs={"memories":memories, "username":username, "language":"English"},
    )

    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox(
            "Select a source Model A",
            SOURCES,
            index=0,
            format_func=source2label,
        )

    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox(
            "Select a source Model B",
            SOURCES,
            index=1,
            format_func=source2label,
        )

    # Verdict buttons in a two-column grid. Note the crossover: picking A as
    # better runs replaceB (B's reply is overwritten with A's) and vice versa.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # Per-model "Regenerate A" / "Regenerate B" buttons are currently disabled.

    clear = st.button(
        "Clear History",
        on_click=clear_memory,
        kwargs={"memories":memories, "username":username, "language":language},
    )
# Rebuild the memory configuration from the sidebar selections. The shared
# history always uses the first source.
memories = {
    'memoryA': {"issue": issue, "source": sourceA},
    'memoryB': {"issue": issue, "source": sourceB},
    'commonMemory': {"issue": issue, "source": SOURCES[0]},
}

# True when either model's source differs from the value kept in session state.
source_a_changed = st.session_state['previous_sourceA'] != sourceA
source_b_changed = st.session_state['previous_sourceB'] != sourceB
changed_source = source_a_changed or source_b_changed

create_memory_add_initial_message(memories, username, language, changed_source=changed_source)

# The per-model memories live in session state under the first two keys of
# `memories`, i.e. 'memoryA' and 'memoryB'.
memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']

# One chain (plus its stop sequence) per model, each bound to its own memory
# and temperature.
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)
# Shared conversation history: every exchange that was accepted via a verdict
# button, rendered above the two simulator columns.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # Human turns are shown as the user; every other message type as the assistant.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side panes where the two candidate replies are displayed.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Set the four module-level verdict-button flags back to ``False``.

    Covers ``beta``, ``betb``, ``same`` and ``bbad``. The previous loop only
    rebound its own loop variable, so the flags were never changed.
    """
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Tell the chat input whether it should be disabled for this run.

    Returns:
        ``False`` when at least one of the verdict-button flags (``beta``,
        ``betb``, ``same``, ``bbad``) is truthy, ``True`` otherwise.
    """
    return not any((beta, betb, same, bbad))
# Main interaction: send a new prompt to both models and show the two replies.
if prompt := st.chat_input(disabled=disable_chat()):
    # No conversation id in the session yet: store the conversation first.
    # NOTE(review): presumably push_convo2db sets 'convo_id' - confirm in utils.
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    promt_ts = dt.datetime.now(tz=dt.timezone.utc)

    # Show the prompt in both panes before the (slow) model calls.
    for pane in (col1, col2):
        pane.chat_message("user").write(prompt)

    # Model A is queried first, then Model B.
    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)

    # Store the comparison together with the shared history as text.
    new_comparison(
        st.session_state['db_client'],
        promt_ts,
        completion_ts,
        st.session_state['commonMemory'].buffer_as_str,
        prompt,
        responseA,
        responseB,
    )

    for pane, reply in ((col1, responseA), (col2, responseB)):
        pane.chat_message("assistant").write(reply)

    reset_buttons()