|
import os |
|
import streamlit as st |
|
from streamlit.logger import get_logger |
|
from langchain.schema.messages import HumanMessage |
|
from utils.mongo_utils import get_db_client, update_convo |
|
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF |
|
from utils.memory_utils import clear_memory, push_convo2db |
|
from utils.chain_utils import get_chain, custom_chain_predict |
|
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT |
|
|
|
# Module-level logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)

# Read eagerly so a missing key fails fast with a KeyError at start-up.
openai_api_key = os.environ["OPENAI_API_KEY"]

# Sampling temperature passed to get_chain() further down in this script.
temperature = 0.8
|
|
|
|
|
# Seed per-session defaults. Each entry is only written when its key is still
# absent, so values already stored in st.session_state are never overwritten.
# The defaults are wrapped in callables so that nothing (in particular
# get_db_client) is evaluated unless the key is actually missing.
for _state_key, _make_default in (
    ("sent_messages", lambda: 0),
    ("total_messages", lambda: 0),
    ("issue", lambda: ISSUES[0]),
    ("previous_source", lambda: SOURCES[0]),
    ("db_client", get_db_client),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# The texter name is handled separately because its creation is also logged.
if "texter_name" not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")

# Single-entry registry of conversation memories. Its key ("memory") is the
# st.session_state slot the memory object is later read back from.
_memory_meta = {
    "issue": st.session_state["issue"],
    "source": st.session_state["previous_source"],
}
memories = {"memory": _memory_meta}
|
|
|
def _language_label(code):
    """Return the display label for a language code: 'en' -> English, else Spanish."""
    if code == "en":
        return "English"
    return "Spanish"


# Sidebar controls: counselor name, scenario, language and model source.
with st.sidebar:
    username = st.text_input("Username", value="Dani", max_chars=30)
    # Remember the first username seen in this session as the counselor name.
    if "counselor_name" not in st.session_state:
        st.session_state["counselor_name"] = username

    # Changing the scenario triggers clear_memory with the kwargs below.
    # NOTE(review): the callback always receives language="English", even
    # though a language picker exists - confirm this is intentional.
    issue = st.selectbox(
        "Select a Scenario",
        ISSUES,
        index=0,
        format_func=issue2label,
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    # Spanish is only offered when the selected issue equals "Anxiety".
    if issue == "Anxiety":
        supported_languages = ["en", "es"]
    else:
        supported_languages = ["en"]

    # Changing the language triggers the same clear_memory callback.
    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=_language_label,
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
    )
|
|
|
# A change of model source, scenario or counselor name starts a new conversation.
changed_source = (
    st.session_state["previous_source"] != source
    or st.session_state["issue"] != issue
    or st.session_state["counselor_name"] != username
)

if changed_source:
    # Record the new selection and draw a fresh texter persona.
    st.session_state["counselor_name"] = username
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")
    st.session_state["previous_source"] = source
    st.session_state["issue"] = issue
    # Both message counters restart with the new conversation.
    for _counter in ("sent_messages", "total_messages"):
        st.session_state[_counter] = 0

# Called on every run; changed_source is forwarded so the helper can decide
# what to do with it.
create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
st.session_state["previous_source"] = source

# The memory object is stored in session state under the first (and only)
# key of the `memories` registry.
memoryA = st.session_state[next(iter(memories))]

llm_chain, stopper = get_chain(
    issue,
    language,
    source,
    memoryA,
    temperature,
    texter_name=st.session_state["texter_name"],
)
|
|
|
# Page header followed by a replay of the conversation stored in memory.
st.title("💬 Simulator")

# Keep the sidebar counter in sync with the number of stored messages.
st.session_state["total_messages"] = len(memoryA.chat_memory.messages)

for msg in memoryA.buffer_as_messages:
    # isinstance (rather than an exact type comparison) so that subclasses of
    # HumanMessage are also rendered as the user instead of the assistant.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)
|
|
|
# Handle a newly submitted chat message, if there is one on this run.
prompt = st.chat_input()
if prompt:
    st.session_state["sent_messages"] += 1
    st.chat_message("user").write(prompt)

    # First message of a conversation: create its database record.
    # NOTE(review): push_convo2db is presumably what stores 'convo_id' in
    # st.session_state, since it is read right below - confirm.
    if "convo_id" not in st.session_state:
        push_convo2db(memories, username, language)

    responses = custom_chain_predict(llm_chain, prompt, stopper)
    for response in responses:
        st.chat_message("assistant").write(response)

    # Persist the transcript held in memory to the stored conversation.
    _memory_vars = memoryA.load_memory_variables({})
    transcript = _memory_vars[memoryA.memory_key]
    update_convo(
        st.session_state["db_client"],
        st.session_state["convo_id"],
        transcript,
    )
|
|
|
# Recount after any new exchange so the sidebar shows the current total.
st.session_state["total_messages"] = len(memoryA.chat_memory.messages)

# Message counters, rendered in the sidebar.
st.sidebar.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
st.sidebar.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")