# CelebChat / app.py
# Author: lhzstar
# Commit: 8bc0535 ("new commits")
# File size at that commit: 6.37 kB
from celebbot import CelebBot
import streamlit as st
from streamlit_mic_recorder import speech_to_text
from utils import *
# Markdown shown in the sidebar "Disclaimer" expander.
_DISCLAIMER = """
* CelebChat may produce inaccurate information about people, places, or facts.
* If you have any concerns about your privacy or believe that the app infringes on your rights, please contact me at liuhaozhe2000@gmail.com. I am committed to addressing your concerns and taking any necessary corrective actions.
"""

# Canned questions offered as one-click buttons in the sidebar.
_EXAMPLE_QUESTIONS = (
    "Hello! Did you win an Oscar?",
    "Hi! What is your profession?",
    "Can you tell me about your family background?",
)

# Britannica biography slugs that do not follow the plain "First-Last" pattern.
_BRITANNICA_SLUG_OVERRIDES = {
    "Madonna": "Madonna-American-singer-and-actress",
    "Anne Hathaway": "Anne-Hathaway-American-actress",
}

# HTML for one assistant turn: the answer text followed by an inline WAV player.
# NOTE(review): "autoplay" is not a recognised controlsList token (browsers
# ignore unknown tokens); only "nodownload" has an effect here.
_ASSISTANT_HTML = """
<p>{text}</p>
<audio controls controlsList="autoplay nodownload">
<source src="data:audio/wav;base64,{audio_b64}" type="audio/wav">
Your browser does not support the audio element.
</audio>
"""


def _init_session_state():
    """Seed st.session_state with this page's defaults without overwriting existing values."""
    # Built per call so the "messages" list is never shared between sessions.
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        # "start_chat" and "_last_audio_id" are not read in this file; they are
        # still seeded so any other reader of session state finds them.
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        "_last_audio_id": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _text_submit():
    """Text-box on_change callback: queue the typed text as the next prompt and clear the box."""
    st.session_state["prompt_from_text"] = st.session_state.text_input
    st.session_state.text_input = ''


def _example_submit(text):
    """Example-button on_click callback: queue *text* as the next prompt."""
    st.session_state["prompt_from_text"] = text


def _clear_chat_history():
    """Celebrity-selector on_change callback: start a fresh conversation."""
    st.session_state["messages"] = []


def _britannica_slug(celeb_name):
    """Return the britannica.com/biography/<slug> path component for *celeb_name*."""
    override = _BRITANNICA_SLUG_OVERRIDES.get(celeb_name)
    if override is not None:
        return override
    # split() rather than split(" ") so doubled/stray spaces don't produce "A--B".
    return "-".join(celeb_name.split())


def _get_celeb_bot(celeb_name, celeb_gender):
    """Return the CelebBot for *celeb_name*, rebuilding it only when its inputs change.

    Building a bot calls get_article on a britannica.com URL and runs
    preprocess_text with the "en_core_web_lg" pipeline. The bot is therefore
    kept in st.session_state and reused until the celebrity or either model
    path changes, instead of being rebuilt on every rerun.
    """
    qa_path = st.session_state["QA_model_path"]
    sent_tr_path = st.session_state["sentTr_model_path"]
    cache_key = (celeb_name, qa_path, sent_tr_path)

    cached_bot = st.session_state["celeb_bot"]
    if cached_bot is not None and st.session_state.get("_celeb_bot_key") == cache_key:
        return cached_bot

    knowledge = get_article(
        f"https://www.britannica.com/biography/{_britannica_slug(celeb_name)}"
    )
    bot = CelebBot(
        celeb_name,
        celeb_gender,
        get_tokenizer(qa_path),
        # flan-t5 checkpoints load as seq2seq models; anything else as a causal LM.
        get_seq2seq_model(qa_path) if "flan-t5" in qa_path else get_causal_model(qa_path),
        get_tokenizer(sent_tr_path),
        get_auto_model(sent_tr_path),
        *preprocess_text(celeb_name, knowledge, "en_core_web_lg"),
    )
    st.session_state["celeb_bot"] = bot
    st.session_state["_celeb_bot_key"] = cache_key
    return bot


def _pending_prompt():
    """Return the queued question (voice first, then typed), stripped, or "" if there is none."""
    # `or` (not `is not None`) so an empty transcription doesn't hide a typed question.
    raw = (
        st.session_state["prompt_from_audio"]
        or st.session_state["prompt_from_text"]
        or ""
    )
    return raw.strip()


def _previous_turn(messages):
    """Format the Q/A pair preceding the newest user message, or return None if there isn't one.

    *messages* is the chat history with the newest user message already
    appended, so the previous pair sits at [-3] (user) and [-2] (assistant).
    """
    if len(messages) < 3:
        return None
    question, answer = messages[-3], messages[-2]
    # Guard against a misaligned history, e.g. when an earlier turn failed
    # after its user message was appended but before the reply was.
    if question["role"] != "user" or answer["role"] != "assistant":
        return None
    return "Question: {q}\n\nAnswer: {a}\n\n".format(
        q=question["content"], a=answer["content"]
    )


def _assistant_html(response, audio_b64):
    """Build the assistant message HTML, escaping the model text before embedding it."""
    from html import escape

    # The markup is rendered with unsafe_allow_html=True, so the generated
    # answer (influenced by user input and scraped web text) must not inject HTML.
    return _ASSISTANT_HTML.format(text=escape(response), audio_b64=audio_b64)


def main():
    """Render the CelebChat page and answer the queued question, if any.

    Streamlit re-executes the script on each widget interaction, so this
    function lays out the whole page every run. Conversation state (history,
    queued prompts, the chat bot) lives in st.session_state.
    """
    st.set_page_config(initial_sidebar_state="expanded")
    hide_footer()
    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data("data.json")

    st.sidebar.header("CelebChat")
    with st.sidebar.expander('About the app'):
        st.markdown("Experience the ultimate celebrity chats with this app!")
    with st.sidebar.expander('Disclaimer'):
        st.markdown(_DISCLAIMER)

    _init_session_state()

    st.sidebar.selectbox(
        'Choose your celebrity crush',
        key="celeb_name",
        options=sorted(celeb_data.keys()),
        on_change=_clear_chat_history,
    )
    model_id = st.sidebar.selectbox("Choose Your model", options=model_list)
    st.session_state["QA_model_path"] = (
        f"google/{model_id}" if "flan-t5" in model_id else model_id
    )

    celeb_name = st.session_state["celeb_name"]
    bot = _get_celeb_bot(celeb_name, celeb_data[celeb_name]["gender"])

    # Replay the stored conversation as plain markdown (audio players are not re-rendered).
    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.write("You can record your question...")
        # NOTE(review): presumably None when no new recording arrived this run
        # (just_once=True) — confirm against streamlit_mic_recorder.
        st.session_state["prompt_from_audio"] = speech_to_text(
            start_prompt="Start Recording",
            stop_prompt="Stop Recording",
            language='en',
            use_container_width=True,
            just_once=True,
            key='STT',
        )
        st.text_input('Or text something...', key='text_input', on_change=_text_submit)
        st.write("Example questions:")
        for question in _EXAMPLE_QUESTIONS:
            st.button(question, on_click=_example_submit, args=[question])

    prompt = _pending_prompt()
    if not prompt:
        return

    bot.text = prompt
    # Display the user message and add it to the chat history.
    with dialogue_container:
        st.chat_message("user").markdown(prompt)
    messages = st.session_state["messages"]
    messages.append({"role": "user", "content": prompt})

    # Ask the bot, passing the previous exchange as context when one exists.
    chat_his = _previous_turn(messages)
    if chat_his is None:
        response = bot.question_answer()
    else:
        response = bot.question_answer(chat_his=chat_his)

    # autoplay=False: the returned audio (embedded below as base64 WAV data)
    # is played through the HTML <audio> controls instead.
    b64 = bot.text_to_speech(autoplay=False)
    with dialogue_container:
        st.chat_message("assistant").markdown(
            _assistant_html(response, b64),
            unsafe_allow_html=True,
        )

    # Store the plain-text answer and mark both queued prompts as consumed.
    messages.append({"role": "assistant", "content": response})
    st.session_state["prompt_from_audio"] = ""
    st.session_state["prompt_from_text"] = ""


if __name__ == "__main__":
    main()