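"""Streamlit chat UI where a counselor (the user) converses with an LLM-simulated texter."""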
import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain, custom_chain_predict
from app_config import ISSUES, SOURCES, source2label, issue2label

logger = get_logger(__name__)
openai_api_key = os.environ['OPENAI_API_KEY']  # raises KeyError at startup if the key is not set
temperature = 0.8
# username = "barb-chase" #"ivnban-ctl"

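# Initialize per-session state: sent-message counter, selected issue and source,
# the MongoDB client, and a randomly chosen name for the simulated texter.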
if "sent_messages" not in st.session_state:
    st.session_state['sent_messages'] = 0
if "issue" not in st.session_state:
    st.session_state['issue'] = ISSUES[0]
if 'previous_source' not in st.session_state:
    st.session_state['previous_source'] = SOURCES[0]
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()
if 'texter_name' not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")

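# The memory config records which issue and source the conversation memory is built for.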
memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}

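# Sidebar controls: counselor username, scenario, language, and texter model source.
# Changing the scenario or language clears the existing conversation memory.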
with st.sidebar:
    username = st.text_input("Username", value='Jasmyn', max_chars=30)
    if 'counselor_name' not in st.session_state:
        st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
    issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                        )
    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            format_func=lambda x: "English" if x=="en" else "Spanish",
                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                        )

    source = st.selectbox("Select a source Model A", SOURCES, index=0,
                          format_func=source2label, 
                        )

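# Start a fresh conversation whenever the model source, scenario, or username changes.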
changed_source = any([
    st.session_state['previous_source'] != source,
    st.session_state['issue'] != issue,
    st.session_state['counselor_name'] != username,
])
if changed_source:
    st.session_state["counselor_name"] = username
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.debug(f"texter name is {st.session_state['texter_name']}")
    st.session_state['previous_source'] = source
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0
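# (Re)create the conversation memory and seed it with the texter's initial message.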
create_memory_add_initial_message(memories,
                                  issue, 
                                  language, 
                                  changed_source=changed_source, 
                                  counselor_name=st.session_state["counselor_name"], 
                                  texter_name=st.session_state["texter_name"])
st.session_state['previous_source'] = source
memoryA = st.session_state[list(memories.keys())[0]]
# issue only without "." marker for model compatibility
llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])

st.title("💬 Simulator")

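# Replay the stored conversation so the chat history persists across Streamlit reruns.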
for msg in memoryA.buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

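# Handle a new counselor message: echo it, register the conversation in the DB on the
# first message, then let the chain generate and display the texter's response(s).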
if prompt := st.chat_input():
    st.session_state['sent_messages'] += 1
    st.chat_message("user").write(prompt)
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)
    responses = custom_chain_predict(llm_chain, prompt, stopper)
    # responses = llm_chain.predict(input=prompt, stop=stopper)
    # response = update_memory_completion(prompt, st.session_state["memory"], OA_engine, temperature)
    for response in responses:
        st.chat_message("assistant").write(response)

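# Running message counters displayed below the sidebar controls.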
with st.sidebar:
    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
    st.markdown(f"### Total Messages: :red[**{len(memoryA.chat_memory.messages)}**]")