Spaces:
Sleeping
Sleeping
File size: 3,359 Bytes
5832f57 975a927 5832f57 59d5667 975a927 5832f57 975a927 5832f57 975a927 59d5667 5832f57 975a927 5832f57 975a927 5832f57 59d5667 5832f57 59d5667 975a927 5832f57 59d5667 975a927 5832f57 975a927 59d5667 975a927 59d5667 5832f57 975a927 5832f57 975a927 5832f57 975a927 5832f57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client
from utils.app_utils import create_memory_add_initial_message, get_random_name
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label
logger = get_logger(__name__)

# Looked up eagerly so a missing key raises KeyError immediately; the value is
# not referenced anywhere else in the visible code.
openai_api_key = os.environ['OPENAI_API_KEY']

# Default memory spec (first issue / first source). The sidebar widgets below
# receive this object as callback kwargs; it is reassigned later from the
# actual sidebar selections.
memories = {'memory': {"issue": ISSUES[0], "source": SOURCES[0]}}

# Per-session values that must be created only once. Each factory is invoked
# only when its key is absent from st.session_state, so e.g. the DB client is
# not re-created on every Streamlit rerun. Insertion order = creation order.
_session_defaults = {
    'previous_source': lambda: SOURCES[0],
    'db_client': get_db_client,
    'counselor_name': get_random_name,
    'texter_name': get_random_name,
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)

    # Changing the issue or the language triggers clear_memory with these kwargs.
    # NOTE(review): language is always passed as "English" regardless of the
    # selection, and `memories` is the default spec at this point — confirm intended.
    reset_kwargs = {"memories": memories, "username": username, "language": "English"}

    issue = st.selectbox(
        "Select an Issue",
        ISSUES,
        index=0,
        on_change=clear_memory,
        kwargs=reset_kwargs,
    )

    # Spanish is offered only for the "Anxiety" issue.
    if issue == "Anxiety":
        supported_languages = ['en', "es"]
    else:
        supported_languages = ['en']

    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=lambda code: "Spanish" if code != "en" else "English",
        on_change=clear_memory,
        kwargs=reset_kwargs,
    )

    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
    )
# Rebuild the memory spec from the current sidebar selections.
memories = {'memory': {"issue": issue, "source": source}}

# When the selected source differs from the one used on the previous run,
# draw fresh persona names for both participants.
changed_source = st.session_state['previous_source'] != source
if changed_source:
    for _persona_key in ("counselor_name", "texter_name"):
        st.session_state[_persona_key] = get_random_name()

texter_name = create_memory_add_initial_message(
    memories,
    issue,
    language,
    changed_source=changed_source,
    counselor_name=st.session_state["counselor_name"],
    texter_name=st.session_state["texter_name"],
)
st.session_state['previous_source'] = source

# The memory object lives in session state under the spec's (single) key;
# presumably it is created/populated by create_memory_add_initial_message.
memory_key = next(iter(memories))
memoryA = st.session_state[memory_key]

# NOTE(review): the `texter_name` returned above is not used; the chain gets
# the session-state name instead — confirm which one is intended.
llm_chain, stopper = get_chain(
    issue,
    language,
    source,
    memoryA,
    temperature,
    texter_name=st.session_state["texter_name"],
)
st.title("💬 Simulator")

# Replay the stored conversation. Human messages (including any HumanMessage
# subclass) render as the user; every other message type renders as the
# assistant.
for msg in memoryA.buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

if prompt := st.chat_input():
    # Push the conversation to the DB only while no 'convo_id' is in session
    # state (presumably push_convo2db sets it, making this a one-time push).
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    st.chat_message("user").write(prompt)
    response = llm_chain.predict(input=prompt, stop=stopper)
    st.chat_message("assistant").write(response)