Spaces:
Sleeping
Sleeping
File size: 9,405 Bytes
975a927 5832f57 975a927 5832f57 975a927 5832f57 9ff00d4 975a927 5832f57 975a927 5832f57 9ff00d4 975a927 9ff00d4 975a927 9ff00d4 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 9ff00d4 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 9ff00d4 5832f57 9ff00d4 975a927 9ff00d4 5832f57 975a927 5832f57 9ff00d4 975a927 5832f57 9ff00d4 5832f57 13be379 5832f57 975a927 9ff00d4 975a927 9ff00d4 975a927 9ff00d4 5832f57 9ff00d4 975a927 5832f57 975a927 5832f57 975a927 5832f57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label
logger = get_logger(__name__)
# Fails fast with KeyError when the variable is not set in the environment.
openai_api_key = os.environ['OPENAI_API_KEY']

# Seed the per-session values that only need a cheap constant default.
# Existing values are never overwritten, so they survive Streamlit reruns.
for _key, _default in (
    ("sent_messages", 0),
    ("issue", ISSUES[0]),
    ("previous_sourceA", SOURCES[0]),
    ("previous_sourceB", SOURCES[0]),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
logger.info(f'sent messages {st.session_state["sent_messages"]}')

# Description of the three memories used by this page: one per compared
# model plus the shared history.  Key order matters: later code picks the
# first two keys as the model A / model B memories.
memories = {
    name: {"issue": st.session_state['issue'], "source": source}
    for name, source in (
        ('memoryA', st.session_state['previous_sourceA']),
        ('memoryB', st.session_state['previous_sourceB']),
        ('commonMemory', SOURCES[0]),
    )
}

# Values below are produced by function calls, so they are only computed
# when the session does not have them yet.
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
for _role in ("counselor_name", "texter_name"):
    if _role not in st.session_state:
        st.session_state[_role] = get_random_name(names_df=DEFAULT_NAMES_DF)
def delete_last_message(memory):
    """Drop the last two stored messages and return the earlier one's text.

    Reads ``memory.chat_memory.messages``, takes the ``content`` of the
    second-to-last entry and rebinds the message list to a copy without its
    final two entries.  Callers in this file feed the returned text back
    into a chain as the prompt to answer again.

    Raises ``IndexError`` when fewer than two messages are stored.
    """
    history = memory.chat_memory.messages
    prompt_text = history[-2].content
    # Rebind to a new, shorter list instead of mutating the list in place.
    memory.chat_memory.messages = history[:-2]
    return prompt_text
def replace_last_message(memory, new_message):
    """Swap the final stored message for ``new_message`` as an AI message.

    The message list on ``memory.chat_memory`` is rebound to a copy without
    its last entry, then ``new_message`` is appended through the chat
    memory's own ``add_ai_message`` method.
    """
    chat = memory.chat_memory
    chat.messages = chat.messages[:-1]
    chat.add_ai_message(new_message)
def regenerateA():
    """Re-answer the latest prompt with model A and show the new exchange.

    Removes the last prompt/response pair from ``memoryA``, asks
    ``llm_chainA`` to predict again for that prompt (using ``stopperA`` as
    the stop sequence), writes the prompt and the fresh answer into ``col1``
    and returns the fresh answer.
    """
    prompt_text = delete_last_message(memoryA)
    regenerated = llm_chainA.predict(input=prompt_text, stop=stopperA)
    # Echo the user turn first, then the assistant turn, in model A's column.
    for role, text in (("user", prompt_text), ("assistant", regenerated)):
        col1.chat_message(role).write(text)
    return regenerated
def regenerateB():
    """Re-answer the latest prompt with model B and show the new exchange.

    Removes the last prompt/response pair from ``memoryB``, asks
    ``llm_chainB`` to predict again for that prompt (using ``stopperB`` as
    the stop sequence), writes the prompt and the fresh answer into ``col2``
    and returns the fresh answer.
    """
    prompt_text = delete_last_message(memoryB)
    regenerated = llm_chainB.predict(input=prompt_text, stop=stopperB)
    # Echo the user turn first, then the assistant turn, in model B's column.
    for role, text in (("user", prompt_text), ("assistant", regenerated)):
        col2.chat_message(role).write(text)
    return regenerated
def replaceA():
    """Handle the "B is better" verdict: model B's last answer wins.

    Copies model B's latest reply over model A's latest reply so both
    memories continue from the same turn, appends the prompt and the winning
    reply to the shared ``commonMemory``, and reports the outcome through
    ``new_battle_result`` with ``winner='model_two'``.

    Does nothing when no prompt has been sent yet (``sent_messages == 0``),
    matching ``bothGood``: there is no exchange to judge at that point, and
    indexing the memories or the ``comparison_id`` / ``convo_id`` session
    keys could fail after memory had already been modified.
    """
    if st.session_state['sent_messages'] == 0:
        return

    last_prompt = memoryB.chat_memory.messages[-2].content
    new_message = memoryB.chat_memory.messages[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_two',
    )
def replaceB():
    """Handle the "A is better" verdict: model A's last answer wins.

    Copies model A's latest reply over model B's latest reply so both
    memories continue from the same turn, appends the prompt and the winning
    reply to the shared ``commonMemory``, and reports the outcome through
    ``new_battle_result`` with ``winner='model_one'``.

    Does nothing when no prompt has been sent yet (``sent_messages == 0``),
    matching ``bothGood``: there is no exchange to judge at that point, and
    indexing the memories or the ``comparison_id`` / ``convo_id`` session
    keys could fail after memory had already been modified.
    """
    if st.session_state['sent_messages'] == 0:
        return

    last_prompt = memoryA.chat_memory.messages[-2].content
    new_message = memoryA.chat_memory.messages[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context({"inputs": last_prompt}, {"outputs": new_message})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='model_one',
    )
def regenerateBoth():
    """Handle the "Both are bad" verdict: discard and regenerate both answers.

    Reports the outcome through ``new_battle_result`` with
    ``winner='both_bad'``, regenerates the latest answer of model A and of
    model B, and registers the new pair of answers through
    ``new_comparison`` together with the UTC timestamps taken before and
    after regeneration.

    Does nothing when no prompt has been sent yet (``sent_messages == 0``),
    matching ``bothGood``: there is nothing to regenerate at that point.
    """
    if st.session_state['sent_messages'] == 0:
        return

    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    # Take the prompt from memory (the same message delete_last_message hands
    # back for regeneration) before the regenerate calls pop it.  This avoids
    # depending on the module-level `prompt`, which is only bound by the
    # chat-input statement further down the script and is empty on runs where
    # no message was submitted.
    last_prompt = memoryA.chat_memory.messages[-2].content
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='both_bad',
    )
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(
        st.session_state['db_client'], prompt_ts, completion_ts,
        st.session_state['commonMemory'].buffer_as_str, last_prompt, responseA, responseB,
    )
def bothGood():
    """Handle the "Tie" verdict: both answers are acceptable.

    When no prompt has been sent yet this is a no-op.  Otherwise one of the
    two model memories is picked at random, its latest prompt/response pair
    is appended to the shared ``commonMemory``, and the outcome is reported
    through ``new_battle_result`` with ``winner='tie'``.
    """
    if st.session_state['sent_messages'] == 0:
        return

    picked = random.choice([memoryA, memoryB])
    history = picked.chat_memory.messages
    prompt_text = history[-2].content
    reply_text = history[-1].content
    st.session_state['commonMemory'].save_context({"inputs": prompt_text}, {"outputs": reply_text})
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username, sourceA, sourceB, winner='tie',
    )
# Sidebar: session settings, per-model settings and the verdict buttons.
# Widgets are created in display order, so statement order here is the layout.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    # Changing the issue runs clear_memory before the next rerun.
    # NOTE(review): this callback (and the language one below) passes the
    # literal "English", while the "Clear History" button passes the selected
    # language code ('en'/'es') -- confirm which form clear_memory expects.
    issue = st.selectbox("Select an Issue", ISSUES, index=0,
                         on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                         )
    # Spanish is only offered for the "Anxiety" issue.
    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            format_func=lambda x: "English" if x=="en" else "Spanish",
                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                            )
    # Per-model sampling temperature and source selection.
    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
                               format_func=source2label
                               )
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", SOURCES, index=0,
                               format_func=source2label
                               )
    st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")
    # Verdict buttons, laid out two per row.  Each returns True only on the
    # rerun triggered by its own click; disable_chat() reads these flags.
    # "A is better" overwrites B's last answer (replaceB) and vice versa.
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
    clear = st.button("Clear History", on_click=clear_memory, kwargs={"memories":memories, "username":username, "language":language})
# True when either model source or the issue differs from what the session
# last saw.
changed_source = any(
    st.session_state[state_key] != selected
    for state_key, selected in (
        ('previous_sourceA', sourceA),
        ('previous_sourceB', sourceB),
        ('issue', issue),
    )
)

if changed_source:
    print("changed something")
    # A changed set-up gets fresh participant names ...
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    # ... a restarted prompt counter ...
    st.session_state['sent_messages'] = 0
    # ... and the new selection becomes the baseline for the next rerun.
    st.session_state['issue'] = issue
    st.session_state['previous_sourceA'] = sourceA
    st.session_state['previous_sourceB'] = sourceB

# Called on every run; `changed_source` tells the helper whether the set-up
# changed since the previous run.
create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
# The first two keys of `memories` name the model A and model B memories
# kept in session state.
_memory_key_a, _memory_key_b = list(memories)[:2]
memoryA = st.session_state[_memory_key_a]
memoryB = st.session_state[_memory_key_b]

# One chain (and its stop sequence) per model; both share the texter name.
_texter_name = st.session_state["texter_name"]
llm_chainA, stopperA = get_chain(
    issue, language, sourceA, memoryA, temperatureA, texter_name=_texter_name
)
llm_chainB, stopperB = get_chain(
    issue, language, sourceB, memoryB, temperatureB, texter_name=_texter_name
)
# Render the shared conversation history kept in `commonMemory`.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance instead of an exact type comparison, so HumanMessage
    # subclasses are also shown as user turns.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# Side-by-side columns in which the two models' latest answers are shown.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Set the four verdict-button flags back to ``False``.

    The earlier loop form only rebound its own loop variable, so the
    module-level flags ``beta``, ``betb``, ``same`` and ``bbad`` were never
    changed.  They are assigned directly here.  The flags are rebound again
    by the button widgets the next time the script runs.
    """
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Return ``True`` when the chat input should be disabled.

    The input is enabled only on a run in which at least one of the verdict
    buttons (``beta``, ``betb``, ``same``, ``bbad``) reported a click.
    """
    verdict_given = any([beta, betb, same, bbad])
    return not verdict_given
# Main chat turn: runs only on a rerun in which the user submitted a message.
# The input is disabled unless a verdict button was clicked (see disable_chat).
if prompt := st.chat_input(disabled=disable_chat()):
    st.session_state['sent_messages'] += 1
    # No conversation id in the session yet: hand the conversation to
    # push_convo2db (presumably it stores the id -- confirm in utils.memory_utils).
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)
    # UTC timestamp taken before the models are queried.
    promt_ts = dt.datetime.now(tz=dt.timezone.utc)
    # Show the user's message in both columns.
    col1.chat_message("user").write(prompt)
    col2.chat_message("user").write(prompt)
    # Query model A, then model B, with the same prompt.
    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    # UTC timestamp taken after both answers are available.
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    # Register the pair of answers along with the shared history as text.
    new_comparison(st.session_state['db_client'], promt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str, prompt, responseA, responseB)
    col1.chat_message("assistant").write(responseA)
    col2.chat_message("assistant").write(responseB)
    reset_buttons()