Spaces:
Sleeping
Sleeping
File size: 5,140 Bytes
5832f57 975a927 5832f57 9ff00d4 20b3b4a 1e91476 5832f57 975a927 5832f57 20b3b4a bdca921 975a927 9ff00d4 1e91476 9ff00d4 975a927 9ff00d4 1e91476 9ff00d4 5832f57 ad3d130 bdca921 20b3b4a 18f6362 975a927 5832f57 9ff00d4 5832f57 9ff00d4 975a927 5832f57 9ff00d4 975a927 5832f57 9ff00d4 bdca921 9ff00d4 1e91476 9ff00d4 1e91476 9ff00d4 1e91476 9ff00d4 975a927 bdca921 1e91476 5832f57 1e91476 975a927 5832f57 1e91476 9ff00d4 1e91476 975a927 20b3b4a 5832f57 20b3b4a 1e91476 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain, custom_chain_predict
from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
logger = get_logger(__name__)

# The OpenAI key must come from the environment. When it is missing, show a
# readable error in the app and halt this run instead of crashing with a bare
# KeyError traceback.
openai_api_key = os.environ.get('OPENAI_API_KEY')
if not openai_api_key:
    st.error("The OPENAI_API_KEY environment variable is not set.")
    st.stop()

# Sampling temperature passed to get_chain for every run.
temperature = 0.8
# Seed simple per-session values on the first run of a session.
for _state_key, _initial_value in (
    ("sent_messages", 0),
    ("total_messages", 0),
    ("issue", ISSUES[0]),
    ("previous_source", SOURCES[0]),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# These two defaults come from function calls, so the calls are made only
# when the corresponding key is still missing from the session.
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if 'texter_name' not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")

# Memory configuration handed to the memory helpers. Its single key ('memory')
# is also used to look up the conversation memory in st.session_state.
# The issue/source values are the ones stored in session state at the start
# of this run.
memories = {
    'memory': {
        "issue": st.session_state['issue'],
        "source": st.session_state['previous_source'],
    }
}
with st.sidebar:
    username = st.text_input("Username", value='Jasmyn', max_chars=30)
    # Seed the counselor name from the username only on a session's first run.
    if 'counselor_name' not in st.session_state:
        st.session_state["counselor_name"] = username

    # Arguments forwarded to clear_memory when the scenario or the language
    # selection changes.
    # NOTE(review): "language" is always "English" here, while the language
    # selector yields codes ("en"/"es") — confirm clear_memory expects this.
    _clear_kwargs = {"memories": memories, "username": username, "language": "English"}

    issue = st.selectbox(
        "Select a Scenario",
        ISSUES,
        index=0,
        format_func=issue2label,
        on_change=clear_memory,
        kwargs=_clear_kwargs,
    )

    # Spanish is offered only for the "Anxiety" scenario.
    supported_languages = ['en']
    if issue == "Anxiety":
        supported_languages.append("es")

    def _language_label(code):
        """Return the label shown in the language selector for *code*."""
        if code == "en":
            return "English"
        return "Spanish"

    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=_language_label,
        on_change=clear_memory,
        kwargs=_clear_kwargs,
    )
    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
    )
# True when the model source, the scenario, or the counselor (username)
# differs from what was stored on the previous run.
changed_source = any([
    st.session_state['previous_source'] != source,
    st.session_state['issue'] != issue,
    st.session_state['counselor_name'] != username,
])
if changed_source:
    # Record the new selections, draw a new texter name, and reset counters.
    st.session_state["counselor_name"] = username
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")
    st.session_state['previous_source'] = source
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0
    st.session_state['total_messages'] = 0

# Called on every run, not only when something changed. On a session's first
# run changed_source is False (all defaults match), yet the memory slot is
# read right after. Presumably this helper creates it when missing and
# rebuilds it when changed_source is True — confirm in utils.app_utils.
create_memory_add_initial_message(memories,
                                  issue,
                                  language,
                                  changed_source=changed_source,
                                  counselor_name=st.session_state["counselor_name"],
                                  texter_name=st.session_state["texter_name"])

# The conversation memory lives in st.session_state under the single key of
# `memories`. No extra previous_source write is needed here: after the block
# above it always equals `source`.
memoryA = st.session_state[next(iter(memories))]
# Build the LLM chain and its stop sequences for the current selections.
# NOTE(review): an earlier note said the issue is passed "without '.' marker
# for model compatibility"; this call does no such transformation, so any
# such handling presumably lives in get_chain — confirm.
_texter_name = st.session_state["texter_name"]
llm_chain, stopper = get_chain(
    issue,
    language,
    source,
    memoryA,
    temperature,
    texter_name=_texter_name,
)
st.title("💬 Simulator")

# Keep the displayed total in sync with the messages stored in memory.
st.session_state['total_messages'] = len(memoryA.chat_memory.messages)

# Replay the stored conversation. Human messages render as the user and
# everything else as the assistant. isinstance also matches HumanMessage
# subclasses, which a strict type() equality check would misroute.
for msg in memoryA.buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)
# Lock the input 4 messages before the hard limit so the next
# user/assistant exchange still fits (the original intent: "account for
# next interaction").
_input_locked = st.session_state['total_messages'] > MAX_MSG_COUNT - 4
prompt = st.chat_input(disabled=_input_locked)

if prompt:
    st.session_state['sent_messages'] += 1
    st.chat_message("user").write(prompt)

    # push_convo2db runs only while no 'convo_id' is in session state
    # (presumably it stores one on success — confirm in utils.memory_utils).
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    # Each item returned by custom_chain_predict is rendered as its own
    # assistant message.
    responses = custom_chain_predict(llm_chain, prompt, stopper)
    for response in responses:
        st.chat_message("assistant").write(response)

    _message_total = len(memoryA.chat_memory.messages)
    st.session_state['total_messages'] = _message_total
    if _message_total >= MAX_MSG_COUNT:
        st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
    elif _message_total >= WARN_MSG_COUT:
        st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
# Show the running message counters in the sidebar.
with st.sidebar:
    _sent_count = st.session_state['sent_messages']
    _total_count = st.session_state['total_messages']
    st.markdown(f"### Total Sent Messages: :red[**{_sent_count}**]")
    st.markdown(f"### Total Messages: :red[**{_total_count}**]")